LCA-GPT/QA/classify.py

59 lines
2.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import os
from langchain_community.llms import QianfanLLMEndpoint,Tongyi
from langchain.chains import LLMChain
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import PromptTemplate
import re
# Qianfan credentials. They are read from the environment so real keys never
# have to be written into (and committed with) this file; the "xxxxxx"
# placeholders remain as the fallback when the variables are unset.
api = os.environ.get("QIANFAN_AK", "xxxxxx")
sk = os.environ.get("QIANFAN_SK", "xxxxxx")
# Shared ERNIE-4.0-8K endpoint; classify() sends its prompts to this model.
llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", qianfan_ak=api, qianfan_sk=sk)
# Prompt listing the 20 top-level sections (门类) of the national economic
# industry classification. The model is told to wrap its answer in '#'.
# It is built once at import time instead of on every classify() call.
_CLASSIFY_PROMPT = PromptTemplate(
input_variables=["prompt"],
template='''我国国民经济行业共有20个门类如下
农、林、牧、渔业;采矿业;制造业;电力、热力、燃气及水生产和供应业;建筑业;批发和零售业;交通运输、仓储和邮政业;
住宿和餐饮业;信息传输、软件和信息技术服务业;金融业;房地产业;租赁和商务服务业;科学研究和技术服务业;
水利、环境和公共设施管理业;居民服务、修理和其他服务业;教育;卫生和社会工作;文化、体育和娱乐业;公共管理、社会保障和社会组织;国际组织。
输出下面内容属于哪个行业类别:{prompt}
答案用#表示,输出格式如下:
#科学研究和技术服务业#
'''
)


def classify(text, model=None):
    """Ask the LLM which industry section *text* belongs to.

    Args:
        text: Content to classify. It is substituted for ``{prompt}`` in the
            fixed prompt above.
        model: Optional LangChain LLM to query instead of the module-level
            ``llm``. It is assumed to be a text-completion LLM whose
            ``invoke`` returns a string.

    Returns:
        The model's raw reply. The prompt asks for the label wrapped in
        ``#``, e.g. ``#科学研究和技术服务业#``. Use ``normalize`` to extract it.
    """
    # `prompt | llm` + invoke() replaces the deprecated LLMChain(...).run();
    # for a text LLM both return the completion string.
    chain = _CLASSIFY_PROMPT | (llm if model is None else model)
    return chain.invoke({"prompt": text})
def normalize(text, default=None):
    """Extract the category label from an LLM reply formatted as ``#label#``.

    Line breaks are removed first. The text is then split on ``#``, and every
    piece that sits between two ``#`` characters is a candidate. The last
    candidate that is non-blank after stripping whitespace is returned. This
    means stray markup such as ``###`` or a trailing ``##`` does not produce
    an empty label.

    Args:
        text: Raw model output, e.g. ``"#科学研究和技术服务业#"``.
        default: Value returned when *text* is empty or contains no
            non-blank ``#...#`` segment.

    Returns:
        The stripped label, or *default* if none was found.
    """
    if not text:
        return default
    clean_text = re.sub(r'[\r\n]+', '', text)
    # split('#') leaves the text before the first '#' and after the last '#'
    # at the two ends; only the inner pieces are enclosed by a pair of '#'.
    segments = clean_text.split('#')[1:-1]
    labels = [seg.strip() for seg in segments if seg.strip()]
    return labels[-1] if labels else default
def main(input_path="/home/zhangxj/WorkFile/LCA-GPT/QA/QA.xlsx",
         output_path="/home/zhangxj/WorkFile/LCA-GPT/QA/classify.csv"):
    """Classify every question/answer row of *input_path* and save a CSV.

    Reads an Excel sheet with ``question`` and ``answer`` columns. For each
    row it sends ``问题:<question>\\n答案:<answer>`` to ``classify``, extracts
    the label with ``normalize``, and prints it. It then writes ``question``,
    ``answer`` and ``类别`` columns to *output_path* (no index column).

    If one row fails (request error or unparseable reply), the error is
    printed and ``None`` is recorded for that row. The run continues, so
    results already obtained are still written.
    """
    data = pd.read_excel(input_path)
    class_list = []
    for idx, ques, ans in zip(data.index, data['question'], data['answer']):
        # Blank Excel cells arrive as NaN (a float); str + float would raise.
        ques = "" if pd.isna(ques) else str(ques)
        ans = "" if pd.isna(ans) else str(ans)
        query = "问题:" + ques + "\n答案:" + ans
        try:
            res = normalize(classify(query))
        except Exception as exc:  # per-row boundary: log and keep going
            print(f"row {idx}: classification failed: {exc!r}")
            res = None
        print(res)
        class_list.append(res)
    df = pd.DataFrame({
        'question': data['question'].tolist(),
        'answer': data['answer'].tolist(),
        '类别': class_list,
    })
    df.to_csv(output_path, index=False)


if __name__ == "__main__":
    main()