LCA-GPT/QA/classify_son.py

78 lines
3.7 KiB
Python

# -*- coding: utf-8 -*-
import pandas as pd
import os
from langchain_community.llms import QianfanLLMEndpoint,Tongyi
from langchain.chains import LLMChain
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import PromptTemplate
import re
import numpy as np
api = "xxxxx"
sk = "xxxxx"
llm = QianfanLLMEndpoint(model="ERNIE-3.5-8K",qianfan_ak=api,qianfan_sk=sk)
def classify(text):
prompt = PromptTemplate(
input_variables=["prompt"],
template='''下面给出国民经济行业分类及其子类.
农、林、牧、渔业包括:农业;林业;畜牧业;渔业.
科学研究和技术服务业包括:研究和试验发展;专业技术服务业;科技推广和应用服务业;
制造业包括:农副食品加工业;食品制造业;酒、饮料和精制茶制造业;烟草制品业;纺织业;纺织服装、服饰业;皮革、毛皮、羽毛及其制品和制鞋业;木材加工和木、竹、藤、棕、草制品业;家具制造业;造纸和纸制品业;印刷和记录媒介复制业;文教、工美、体育和娱乐用品制造业;石油、煤炭及其他燃料加工业;化学原料和化学制品制造业;医药制造业;化学纤维制造业;橡胶和塑料制品业;非金属矿物制品业;黑色金属冶炼和压延加工业;有色金属冶炼和压延加工业;金属制品业;通用设备制造业;专用设备制造业;汽车制造业; 铁路、船舶、航空航天和其他运输设备制造业;电气机械和器材制造业;计算机、通信和其他电子设备制造业;仪器仪表制造业;废弃资源综合利用业;金属制品、机械和设备修理业.
水利、环境和公共设施管理业包括:水利管理业、生态保护和环境治理业;公共设施管理业;土地管理业;
电力、热力、燃气及水生产和供应业包括:电力、热力生产和供应业;燃气生产和供应业;水的生产和供应业;
建筑业包括:房屋建筑业;土木工程建筑业;建筑安装业;建筑装饰、装修和其他建筑业
输出下面内容属于哪个具体的子类:{prompt}
答案用#表示,输出格式如下:
#农副食品加工业#
'''
)
chain = LLMChain(llm = llm,prompt=prompt)
response = chain.run(
{"prompt":text}
)
# print(response)
return response
def normalize(text):
clean_text = re.sub(r'[\r\n]+', '', text)
pattern = r'(?<=#)(.*?)(?=#)'
matches = re.findall(pattern, clean_text)
if len(matches) == 0:
return clean_text
else:
return matches[-1]
if __name__ == "__main__":
data = pd.read_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/classify.csv")
tar = ["农、林、牧、渔业","科学研究和技术服务业","制造业","水利、环境和公共设施管理业","电力、热力、燃气及水生产和供应业","建筑业"]
class_list = []
for idx,item in data.iterrows():
ques = item['question']
ans = item['answer']
c = item["type"]
# print(c_next)
if c in tar:
query = "问题:"+ques+"\n答案:"+ans
res = classify(query)
res = normalize(res)
print(res)
class_list.append(res)
else:
class_list.append(c)
question_list = data['question'].tolist()
ans_list = data['answer'].tolist()
c_list = data['type'].tolist() # question,answer,type
df = pd.DataFrame({'question': question_list,
'answer': ans_list,
'类别':c_list,
'子类别': class_list})
df.to_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/classify_new.csv",index = False)