59 lines
2.2 KiB
Python
59 lines
2.2 KiB
Python
import pandas as pd
|
||
import os
|
||
from langchain_community.llms import QianfanLLMEndpoint,Tongyi
|
||
from langchain.chains import LLMChain
|
||
from langchain_core.prompts import ChatPromptTemplate
|
||
from langchain.prompts import PromptTemplate
|
||
import re
|
||
|
||
# Qianfan credentials. Prefer the QIANFAN_AK / QIANFAN_SK environment variables
# so that real keys never have to be written into the source file; the literals
# below are placeholders that are only used when the variables are unset.
api = os.getenv("QIANFAN_AK", "xxxxxx")
sk = os.getenv("QIANFAN_SK", "xxxxxx")

# Shared ERNIE-4.0-8K endpoint; classify() sends every request through it.
llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", qianfan_ak=api, qianfan_sk=sk)
|
||
|
||
def classify(text):
    """Ask the LLM which of the 20 listed industry categories fits *text*.

    Args:
        text: Content to classify; it is substituted into the ``{prompt}``
            placeholder of the instruction template.

    Returns:
        The raw model response. The instruction asks the model to wrap its
        answer in ``#`` characters, e.g. ``#科学研究和技术服务业#``.
    """
    # The template lines sit at column 0 on purpose: any indentation inside the
    # triple-quoted literal would become part of the text sent to the model.
    instruction = '''我国国民经济行业共有20个门类,如下:
农、林、牧、渔业;采矿业;制造业;电力、热力、燃气及水生产和供应业;建筑业;批发和零售业;交通运输、仓储和邮政业;
住宿和餐饮业;信息传输、软件和信息技术服务业;金融业;房地产业;租赁和商务服务业;科学研究和技术服务业;
水利、环境和公共设施管理业;居民服务、修理和其他服务业;教育;卫生和社会工作;文化、体育和娱乐业;公共管理、社会保障和社会组织;国际组织。

输出下面内容属于哪个行业类别:{prompt}。
答案用#表示,输出格式如下:
#科学研究和技术服务业#
'''
    industry_prompt = PromptTemplate(
        input_variables=["prompt"],
        template=instruction,
    )
    classifier = LLMChain(llm=llm, prompt=industry_prompt)
    return classifier.run({"prompt": text})
|
||
|
||
|
||
def normalize(text):
    """Extract the ``#``-delimited category label from a raw model response.

    Line breaks are removed first. Every span enclosed between two ``#``
    characters is then collected, and the last non-blank one is returned with
    its surrounding whitespace stripped.

    Args:
        text: Raw response text, expected to contain ``#label#``.

    Returns:
        The last non-blank ``#``-delimited span. When the response contains no
        such span, the cleaned response itself (line breaks, surrounding
        whitespace and stray ``#`` characters removed) is returned instead of
        raising ``IndexError``, so a single malformed reply does not abort the
        caller's loop.
    """
    clean_text = re.sub(r'[\r\n]+', '', text)

    # The lookarounds leave the '#' delimiters unconsumed, so text lying
    # between two consecutive answers is matched too; picking the last span
    # keeps the final answer the model gave.
    matches = re.findall(r'(?<=#)(.*?)(?=#)', clean_text)
    labels = [span for span in map(str.strip, matches) if span]
    if labels:
        return labels[-1]

    # No well-formed '#...#' answer: fall back to whatever the model said.
    return clean_text.strip().strip('#').strip()
|
||
|
||
if __name__ == "__main__":
    # Load the question/answer pairs that need an industry label.
    qa_frame = pd.read_excel("/home/zhangxj/WorkFile/LCA-GPT/QA/QA.xlsx")
    questions = qa_frame['question'].tolist()
    answers = qa_frame['answer'].tolist()

    labels = []
    for question, answer in zip(questions, answers):
        # One LLM request per pair; keep only the '#'-delimited label.
        raw_reply = classify("问题:" + question + "\n答案:" + answer)
        label = normalize(raw_reply)
        print(label)
        labels.append(label)

    # Store each pair next to its predicted category.
    labelled = pd.DataFrame(
        {'question': questions, 'answer': answers, '类别': labels}
    )
    labelled.to_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/classify.csv", index=False)