'''
RAG-based QA over a persisted Chroma vector store.
Supports uploading files for analysis and recording conversation history as context.
'''
import os
import re
import time
import tempfile

import pandas as pd
import requests
from dotenv import load_dotenv

from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import MessagesPlaceholder
from langchain.retrievers import ContextualCompressionRetriever
from langchain_chroma import Chroma
from langchain_community.chat_models import ChatZhipuAI
# from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter

# Pin the process to GPU 1; the local embedding model runs on CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda"
memory = 6  # intended conversation-history window; currently unused in this script

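# --- Hedged sketch (not part of the original pipeline) ----------------------
# The module docstring says uploaded files can be analysed, and TextLoader /
# PyPDFLoader / CSVLoader plus RecursiveCharacterTextSplitter are imported but
# never used below. This is one plausible way to build the "chromaDB" index
# that init_chain() reads from; build_vector_store, the chunk sizes, and the
# loader-per-extension mapping are assumptions, not the author's code.
def build_vector_store(file_paths, persist_directory="chromaDB"):
    docs = []
    for path in file_paths:
        if path.endswith(".pdf"):
            docs.extend(PyPDFLoader(path).load())
        elif path.endswith(".csv"):
            docs.extend(CSVLoader(path).load())
        else:
            docs.extend(TextLoader(path, encoding="utf-8").load())
    # Overlapping chunks give the MMR retriever fine-grained passages to rank.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    splits = splitter.split_documents(docs)
    embedding = HuggingFaceEmbeddings(
        model_name="/home/zhangxj/models/acge_text_embedding",
        model_kwargs={"device": device},
    )
    # Persist to disk so init_chain() can reopen the same collection later.
    return Chroma.from_documents(splits, embedding, persist_directory=persist_directory)
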
def init_chain():
    try:
        # Load environment variables and initialize the chat model.
        load_dotenv()

        # Option 1: Tongyi (DashScope)
        # os.environ["DASHSCOPE_API_KEY"] = 'sk-c5f441f863f44094b0ddb96c831b5002'
        # llm = ChatTongyi(
        #     streaming=True,
        #     model='llama3-70b-instruct'
        # )

        # Option 2: Qianfan (Baidu ERNIE)
        # ak = "beiDmKGviUVjTyGqb8U0dc2u"
        # sk = "i3GexGV2B5Poi00RVAFUoZ61Ylj0P7tI"
        # llm = QianfanChatEndpoint(
        #     streaming=True,
        #     model='ERNIE-3.5-8K',
        #     api_key=ak,
        #     secret_key=sk
        # )

        # Option 3: Zhipu AI
        # os.environ["ZHIPUAI_API_KEY"] = "434790cf952335f18b6347e7b6de9777.V50p55zfk8Ye4ojV"
        # llm = ChatZhipuAI(
        #     streaming=True,
        #     model="glm-4"
        # )

        # Option 4 (active): DeepSeek via its OpenAI-compatible endpoint
        # pip3 install langchain_openai
        llm = ChatOpenAI(
            model='deepseek-chat',
            openai_api_key='sk-c47f0552d66549ac92c72d64ebdc4d05',
            openai_api_base='https://api.deepseek.com',
        )

        # Embed with the local acge_text_embedding model and open the persisted Chroma index.
        embedding = HuggingFaceEmbeddings(model_name="/home/zhangxj/models/acge_text_embedding",
                                          model_kwargs={"device": device})
        vectordb = Chroma(persist_directory="chromaDB", embedding_function=embedding)

        # MMR retrieval: top 10 chunks, balancing relevance against diversity.
        retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 10})

        # System prompt (Chinese): "You are an expert with rich experience and
        # knowledge in the life-cycle domain. Answer the question using the
        # retrieved context below. Use at most one sentence, stay very concise,
        # and do not insert line breaks."
        instruct_system_prompt = (
            "你是生命周期领域富有经验和知识的专家。"
            "使用以下检索到的上下文来回答问题。"
            "{context}"
            "答案最多使用1句话并保持非常简洁,不能换行。"
        )
        # Alternative instruction (unused): answer in one sentence from your own
        # knowledge, without bullet points or line breaks.
        # instr = "你是生命周期领域富有经验和知识的专家。根据你所掌握的知识只用1句话回答问题。不要列出几点来回答,不需要换行"
        instruct_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", instruct_system_prompt),
                ("human", "{input}"),
            ]
        )

        # Stuff all retrieved documents into a single prompt for the model.
        qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
        rag_chain = create_retrieval_chain(retriever, qa_chain)

        return rag_chain, retriever

    except Exception as e:
        print(f"Error in init_chain: {e}")
        # Return a pair so callers can still unpack the result.
        return None, None

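# --- Hedged sketch (not part of the original pipeline) ----------------------
# The docstring also promises conversation history, and MessagesPlaceholder /
# AIMessage / HumanMessage are imported but unused. This shows one standard
# LangChain wiring for that: a history-aware retriever that first rewrites the
# question in light of the chat history. init_history_chain and both prompt
# texts are assumptions, not the author's code.
def init_history_chain(llm, retriever):
    from langchain.chains import create_history_aware_retriever

    contextualize_prompt = ChatPromptTemplate.from_messages([
        ("system", "Given the chat history and the latest user question, "
                   "rewrite the question so it is understandable on its own."),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    history_retriever = create_history_aware_retriever(llm, retriever, contextualize_prompt)

    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "使用以下检索到的上下文来回答问题。{context}"),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    qa_chain = create_stuff_documents_chain(llm, qa_prompt)
    return create_retrieval_chain(history_retriever, qa_chain)

# Example call: the chain takes the running history alongside the new input, e.g.
# chain.invoke({"input": new_question,
#               "chat_history": [HumanMessage(prev_question), AIMessage(prev_answer)]})
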
def user_in(uin, rag_chain):
    """Run one user question through the RAG chain and return the answer text."""
    try:
        result = rag_chain.invoke({"input": uin})['answer']
        return result
    except Exception as e:
        print(f"Error in user_in: {e}")
        return "An error occurred while processing your request."

def retrieve_and_output(retriever, query):
    """Print the chunks the retriever returns for a query, for inspection."""
    try:
        retrieved_docs = retriever.invoke(query)  # get_relevant_documents() is deprecated
        print("Retrieved content:")
        for doc in retrieved_docs:
            print(doc.page_content)
    except Exception as e:
        print(f"Error in retrieve_and_output: {e}")

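# --- Hedged sketch (not part of the original pipeline) ----------------------
# ContextualCompressionRetriever is imported at the top but never used. One way
# it could slot in: wrap the MMR retriever so an LLM extracts only the passages
# relevant to the query before they are stuffed into the prompt. The
# LLMChainExtractor import and the compress_retriever helper are assumptions.
def compress_retriever(llm, retriever):
    from langchain.retrievers.document_compressors import LLMChainExtractor

    compressor = LLMChainExtractor.from_llm(llm)
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever,
    )
# A compressed retriever is a drop-in replacement: pass it to
# create_retrieval_chain() in init_chain() instead of the raw MMR retriever.
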
def main():
    # Read the evaluation questions, one per line.
    question = []
    with open("/home/zhangxj/WorkFile/LCA-GPT/QA/split/question/ques0.txt", "r", encoding="utf-8") as file:
        for line in file:
            question.append(line.strip())

    rag, retriever = init_chain()
    if rag is None:
        return

    answers = []
    for ques in question:
        response = user_in(ques, rag)
        retrieve_and_output(retriever, ques)
        print(response)
        answers.append(response.strip())

    # Write one answer per line; strip all internal whitespace so each answer
    # stays on a single line, as the prompt demands.
    with open("/home/zhangxj/WorkFile/LCA-GPT/QA/eval/RAGpred.txt", "w", encoding="utf-8") as file:
        for ans in answers:
            line = re.sub(r'\s+', '', ans)
            file.write(line + '\n')

    df = pd.DataFrame({"ans": answers})
    df.to_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/eval/rag.csv", index=False)


if __name__ == "__main__":
    main()