78 lines
3.7 KiB
Python
78 lines
3.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
|
|
import pandas as pd
|
|
import os
|
|
from langchain_community.llms import QianfanLLMEndpoint,Tongyi
|
|
from langchain.chains import LLMChain
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain.prompts import PromptTemplate
|
|
import re
|
|
import numpy as np
|
|
|
|
api = "xxxxx"
|
|
sk = "xxxxx"
|
|
|
|
llm = QianfanLLMEndpoint(model="ERNIE-3.5-8K",qianfan_ak=api,qianfan_sk=sk)
|
|
|
|
def classify(text):
|
|
prompt = PromptTemplate(
|
|
input_variables=["prompt"],
|
|
template='''下面给出国民经济行业分类及其子类.
|
|
农、林、牧、渔业包括:农业;林业;畜牧业;渔业.
|
|
科学研究和技术服务业包括:研究和试验发展;专业技术服务业;科技推广和应用服务业;
|
|
制造业包括:农副食品加工业;食品制造业;酒、饮料和精制茶制造业;烟草制品业;纺织业;纺织服装、服饰业;皮革、毛皮、羽毛及其制品和制鞋业;木材加工和木、竹、藤、棕、草制品业;家具制造业;造纸和纸制品业;印刷和记录媒介复制业;文教、工美、体育和娱乐用品制造业;石油、煤炭及其他燃料加工业;化学原料和化学制品制造业;医药制造业;化学纤维制造业;橡胶和塑料制品业;非金属矿物制品业;黑色金属冶炼和压延加工业;有色金属冶炼和压延加工业;金属制品业;通用设备制造业;专用设备制造业;汽车制造业; 铁路、船舶、航空航天和其他运输设备制造业;电气机械和器材制造业;计算机、通信和其他电子设备制造业;仪器仪表制造业;废弃资源综合利用业;金属制品、机械和设备修理业.
|
|
水利、环境和公共设施管理业包括:水利管理业、生态保护和环境治理业;公共设施管理业;土地管理业;
|
|
电力、热力、燃气及水生产和供应业包括:电力、热力生产和供应业;燃气生产和供应业;水的生产和供应业;
|
|
建筑业包括:房屋建筑业;土木工程建筑业;建筑安装业;建筑装饰、装修和其他建筑业
|
|
|
|
输出下面内容属于哪个具体的子类:{prompt}。
|
|
答案用#表示,输出格式如下:
|
|
#农副食品加工业#
|
|
'''
|
|
)
|
|
chain = LLMChain(llm = llm,prompt=prompt)
|
|
response = chain.run(
|
|
{"prompt":text}
|
|
)
|
|
# print(response)
|
|
|
|
return response
|
|
|
|
|
|
def normalize(text):
|
|
|
|
clean_text = re.sub(r'[\r\n]+', '', text)
|
|
pattern = r'(?<=#)(.*?)(?=#)'
|
|
matches = re.findall(pattern, clean_text)
|
|
|
|
if len(matches) == 0:
|
|
return clean_text
|
|
else:
|
|
return matches[-1]
|
|
|
|
if __name__ == "__main__":
|
|
data = pd.read_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/classify.csv")
|
|
tar = ["农、林、牧、渔业","科学研究和技术服务业","制造业","水利、环境和公共设施管理业","电力、热力、燃气及水生产和供应业","建筑业"]
|
|
class_list = []
|
|
for idx,item in data.iterrows():
|
|
ques = item['question']
|
|
ans = item['answer']
|
|
c = item["type"]
|
|
# print(c_next)
|
|
if c in tar:
|
|
query = "问题:"+ques+"\n答案:"+ans
|
|
res = classify(query)
|
|
res = normalize(res)
|
|
print(res)
|
|
class_list.append(res)
|
|
else:
|
|
class_list.append(c)
|
|
question_list = data['question'].tolist()
|
|
ans_list = data['answer'].tolist()
|
|
c_list = data['type'].tolist() # question,answer,type
|
|
df = pd.DataFrame({'question': question_list,
|
|
'answer': ans_list,
|
|
'类别':c_list,
|
|
'子类别': class_list})
|
|
df.to_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/classify_new.csv",index = False) |