# LCA-GPT/QA/classify.py
# Last modified: 2024-12-29 16:18:16 +08:00
import pandas as pd
import os
from langchain_community.llms import QianfanLLMEndpoint,Tongyi
from langchain.chains import LLMChain
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import PromptTemplate
import re
# Qianfan API credentials. They are read from the QIANFAN_AK / QIANFAN_SK
# environment variables when set, so real keys never have to be committed to
# the repository; the "xxxxxx" placeholders remain as the fallback and must be
# replaced (or overridden through the environment) before running.
api = os.environ.get("QIANFAN_AK", "xxxxxx")
sk = os.environ.get("QIANFAN_SK", "xxxxxx")

# Module-level ERNIE client; classify() builds its chain on top of it.
llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", qianfan_ak=api, qianfan_sk=sk)
def classify(text):
    """Ask the LLM which industry category ``text`` belongs to.

    The prompt enumerates the national-economy industry categories, appends
    ``text`` and instructs the model to wrap its answer in ``#`` markers.

    Args:
        text: Content to classify; substituted for the ``{prompt}``
            placeholder of the template.

    Returns:
        The model response exactly as returned by ``LLMChain.run``; no
        post-processing is applied here.
    """
    # The template is sent to the model verbatim, so its lines must stay
    # unindented inside the literal.
    instruction = '''我国国民经济行业共有20个门类如下
渔业采矿业制造业电力热力燃气及水生产和供应业建筑业批发和零售业交通运输仓储和邮政业
住宿和餐饮业信息传输软件和信息技术服务业金融业房地产业租赁和商务服务业科学研究和技术服务业
水利环境和公共设施管理业居民服务修理和其他服务业教育卫生和社会工作文化体育和娱乐业公共管理社会保障和社会组织国际组织
输出下面内容属于哪个行业类别{prompt}
答案用#表示,输出格式如下:
#科学研究和技术服务业#
'''
    industry_prompt = PromptTemplate(
        template=instruction,
        input_variables=["prompt"],
    )
    llm_chain = LLMChain(prompt=industry_prompt, llm=llm)
    chain_inputs = {"prompt": text}
    return llm_chain.run(chain_inputs)
def normalize(text):
    """Extract the category name from a ``#``-delimited model response.

    Line breaks are removed first, then every span enclosed between two ``#``
    characters is collected and the last non-blank one is returned.

    Args:
        text: Raw model response, expected to contain ``#category#``.

    Returns:
        The last non-blank span found between two ``#`` characters. When the
        response contains no such span (for example the closing ``#`` is
        missing), the cleaned response with surrounding ``#`` and whitespace
        trimmed is returned instead of raising ``IndexError``.
    """
    clean_text = re.sub(r'[\r\n]+', '', text)
    pattern = r'(?<=#)(.*?)(?=#)'
    # The lookarounds do not consume '#', so the gap between two adjacent
    # answers (or an empty '##') also matches; discard those blank spans.
    matches = [span for span in re.findall(pattern, clean_text) if span.strip()]
    if matches:
        return matches[-1]
    # No complete '#...#' pair: keep whatever the model said rather than
    # aborting the whole batch on one malformed reply.
    return clean_text.strip('# \t')
if __name__ == "__main__":
    # Batch job: label every question/answer pair in the workbook with an
    # industry category and write the labelled pairs out as CSV.
    qa_table = pd.read_excel("/home/zhangxj/WorkFile/LCA-GPT/QA/QA.xlsx")

    categories = []
    for _, row in qa_table.iterrows():
        # The pair is presented to the model as one "question / answer" text.
        query = "问题:" + row['question'] + "\n答案:" + row['answer']
        category = normalize(classify(query))
        print(category)
        categories.append(category)

    # Only the two source columns plus the new category column are exported.
    labelled = pd.DataFrame({
        'question': qa_table['question'].tolist(),
        'answer': qa_table['answer'].tolist(),
        '类别': categories,
    })
    labelled.to_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/classify.csv", index=False)