import pandas as pd import os from langchain_community.llms import QianfanLLMEndpoint,Tongyi from langchain.chains import LLMChain from langchain_core.prompts import ChatPromptTemplate from langchain.prompts import PromptTemplate import re api = "xxxxxx" sk = "xxxxxx" llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K",qianfan_ak=api,qianfan_sk=sk) def classify(text): prompt = PromptTemplate( input_variables=["prompt"], template='''我国国民经济行业共有20个门类,如下: 农、林、牧、渔业;采矿业;制造业;电力、热力、燃气及水生产和供应业;建筑业;批发和零售业;交通运输、仓储和邮政业; 住宿和餐饮业;信息传输、软件和信息技术服务业;金融业;房地产业;租赁和商务服务业;科学研究和技术服务业; 水利、环境和公共设施管理业;居民服务、修理和其他服务业;教育;卫生和社会工作;文化、体育和娱乐业;公共管理、社会保障和社会组织;国际组织。 输出下面内容属于哪个行业类别:{prompt}。 答案用#表示,输出格式如下: #科学研究和技术服务业# ''' ) chain = LLMChain(llm = llm,prompt=prompt) response = chain.run( {"prompt":text} ) # print(response) return response def normalize(text): clean_text = re.sub(r'[\r\n]+', '', text) pattern = r'(?<=#)(.*?)(?=#)' matches = re.findall(pattern, clean_text) return matches[-1] if __name__ == "__main__": data = pd.read_excel("/home/zhangxj/WorkFile/LCA-GPT/QA/QA.xlsx") class_list = [] for idx,item in data.iterrows(): ques = item['question'] ans = item['answer'] query = "问题:"+ques+"\n答案:"+ans res = classify(query) res = normalize(res) print(res) class_list.append(res) question_list = data['question'].tolist() ans_list = data['answer'].tolist() df = pd.DataFrame({'question': question_list, 'answer': ans_list, '类别': class_list}) df.to_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/classify.csv",index = False)