59 lines
2.2 KiB
Python
59 lines
2.2 KiB
Python
|
import pandas as pd
|
|||
|
import os
|
|||
|
from langchain_community.llms import QianfanLLMEndpoint,Tongyi
|
|||
|
from langchain.chains import LLMChain
|
|||
|
from langchain_core.prompts import ChatPromptTemplate
|
|||
|
from langchain.prompts import PromptTemplate
|
|||
|
import re
|
|||
|
|
|||
|
# Qianfan credentials. Prefer the environment so real keys never have to be
# committed to the source file; the literals are the original placeholders and
# are used only when the variables are unset.
api = os.environ.get("QIANFAN_AK", "xxxxxx")
sk = os.environ.get("QIANFAN_SK", "xxxxxx")

# Module-level ERNIE client that classify() passes to its LLMChain.
llm = QianfanLLMEndpoint(model="ERNIE-4.0-8K", qianfan_ak=api, qianfan_sk=sk)
|
|||
|
|
|||
|
def classify(text):
    """Ask the LLM which of the 20 listed industry categories ``text`` falls in.

    Args:
        text: Content to classify; it is substituted into the ``{prompt}``
            placeholder of the template.

    Returns:
        Whatever the chain's ``run`` call returns. The template instructs the
        model to wrap its answer in ``#`` markers.
    """
    industry_prompt = PromptTemplate(
        input_variables=["prompt"],
        template='''我国国民经济行业共有20个门类,如下:
农、林、牧、渔业;采矿业;制造业;电力、热力、燃气及水生产和供应业;建筑业;批发和零售业;交通运输、仓储和邮政业;
住宿和餐饮业;信息传输、软件和信息技术服务业;金融业;房地产业;租赁和商务服务业;科学研究和技术服务业;
水利、环境和公共设施管理业;居民服务、修理和其他服务业;教育;卫生和社会工作;文化、体育和娱乐业;公共管理、社会保障和社会组织;国际组织。

输出下面内容属于哪个行业类别:{prompt}。
答案用#表示,输出格式如下:
#科学研究和技术服务业#
''',
    )
    # A fresh chain is built per call, binding the module-level client to the
    # prompt above.
    return LLMChain(llm=llm, prompt=industry_prompt).run({"prompt": text})
|
|||
|
|
|||
|
|
|||
|
def normalize(text):
    """Extract the category label that the model wrapped in ``#`` markers.

    Line breaks are removed first, then the response is split on ``#``. The
    pieces lying between two markers are the candidate labels; the last
    non-blank one is returned with surrounding whitespace stripped.

    Args:
        text: Response string, expected to contain ``#label#``.

    Returns:
        The last marked label, or the whole cleaned response (stripped) when
        it contains no non-blank marked label.
    """
    clean_text = re.sub(r'[\r\n]+', '', text)

    # Splitting "a#b#c" yields ["a", "b", "c"]; only the interior pieces have a
    # marker on both sides, so the first and last pieces are discarded.
    enclosed = clean_text.split('#')[1:-1]
    labels = [piece.strip() for piece in enclosed if piece.strip()]
    if labels:
        return labels[-1]

    # No usable marker pair: hand back the cleaned response instead of raising
    # IndexError, so the caller can still record and inspect the row.
    return clean_text.strip()
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Load the question/answer pairs that need an industry label.
    qa_frame = pd.read_excel("/home/zhangxj/WorkFile/LCA-GPT/QA/QA.xlsx")

    labels = []
    for question, answer in zip(qa_frame['question'], qa_frame['answer']):
        # Question and answer are sent together as one piece of text.
        query_text = "问题:" + question + "\n答案:" + answer
        label = normalize(classify(query_text))
        print(label)
        labels.append(label)

    # Write the original columns alongside the predicted category.
    result = pd.DataFrame(
        {
            'question': qa_frame['question'].tolist(),
            'answer': qa_frame['answer'].tolist(),
            '类别': labels,
        }
    )
    result.to_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/classify.csv", index=False)
|