# LCA-LLM/LCA_RAG/RAGnoUI.py
#
# Repository viewer metadata: 149 lines, 5.1 KiB, Python.
# Last change: 2025-02-06 16:10:41 +08:00
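'''
Batch retrieval-augmented question answering without a UI.

Original note (translated): supports uploading files for analysis and keeping
the conversation history as context.
NOTE(review): the code visible in this file only reads questions from a text
file, answers them through the RAG chain and writes the answers out; file
upload and history handling are not implemented here.
'''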
import time
import tempfile
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
# from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models import ChatZhipuAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter
import requests
import os
import re
import pandas as pd
# Set CUDA_VISIBLE_DEVICES for this process before any model is loaded.
# NOTE(review): presumably this pins the script to the GPU with index 1 on the
# original machine -- confirm against the target host.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Device string handed to the HuggingFace embedding model in init_chain().
device = "cuda"
# Not referenced anywhere in the visible part of this file.
# NOTE(review): possibly a leftover history-window size -- confirm or remove.
memory = 6
def init_chain():
    """Build the retrieval-augmented QA chain and the retriever behind it.

    The chat model is DeepSeek's ``deepseek-chat``, reached through the
    OpenAI-compatible ``ChatOpenAI`` client. Documents are retrieved from the
    Chroma store persisted in ``chromaDB`` using MMR search (``k=10``) with a
    locally stored HuggingFace embedding model.

    The API key is read from the ``DEEPSEEK_API_KEY`` environment variable
    (a ``.env`` file is honoured through ``load_dotenv``) rather than being
    embedded in the source.

    Returns:
        A ``(rag_chain, retriever)`` tuple on success, or ``None`` when any
        step fails; in that case the error is printed.
    """
    try:
        # Credentials must never live in the source tree: pull the key from
        # the environment (optionally populated from a .env file).
        load_dotenv()
        api_key = os.environ.get("DEEPSEEK_API_KEY")
        if not api_key:
            raise RuntimeError(
                "DEEPSEEK_API_KEY is not set; export it or add it to a .env file"
            )

        llm = ChatOpenAI(
            model='deepseek-chat',
            openai_api_key=api_key,
            openai_api_base='https://api.deepseek.com',
        )

        embedding = HuggingFaceEmbeddings(
            model_name="/home/zhangxj/models/acge_text_embedding",
            model_kwargs={"device": device},
        )
        # Keep the vector store and the retriever under separate names instead
        # of rebinding one variable to two different kinds of object.
        vector_store = Chroma(
            persist_directory="chromaDB",
            embedding_function=embedding,
        )
        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 10},
        )

        # The system prompt (Chinese) tells the model to act as a life-cycle
        # domain expert, to answer from the retrieved {context}, and to reply
        # with at most one concise sentence without line breaks.
        instruct_system_prompt = (
            "你是生命周期领域富有经验和知识的专家。"
            "使用以下检索到的上下文来回答问题。"
            "{context}"
            "答案最多使用1句话并保持非常简洁不能换行。"
        )
        instruct_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", instruct_system_prompt),
                ("human", "{input}"),
            ]
        )

        # Stuff the retrieved documents into the prompt, then wire retrieval
        # in front of that QA step.
        qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
        rag_chain = create_retrieval_chain(retriever, qa_chain)
        return rag_chain, retriever
    except Exception as e:
        print(f"Error in init_chain: {e}")
        return None
def user_in(uin, rag_chain):
    """Ask the RAG chain one question and return its answer text.

    Args:
        uin: The user's question.
        rag_chain: An object whose ``invoke`` accepts ``{"input": ...}`` and
            returns a mapping that contains an ``'answer'`` entry.

    Returns:
        The value stored under ``'answer'``, or a fixed English error message
        when invoking the chain or reading the answer raises.
    """
    payload = {"input": uin}
    try:
        response = rag_chain.invoke(payload)
        return response['answer']
    except Exception as exc:
        # Best effort: report the problem and hand back a placeholder so a
        # batch run can continue with the next question.
        print(f"Error in user_in: {exc}")
        return "An error occurred while processing your request."
def retrieve_and_output(retriever, query):
    """Print the documents the retriever returns for ``query``.

    Uses ``retriever.invoke`` (the Runnable interface) instead of the
    deprecated ``get_relevant_documents``.

    Args:
        retriever: A retriever exposing ``invoke(query)`` that yields
            documents with a ``page_content`` attribute.
        query: The question to look up.

    Returns:
        None. Output goes to stdout; any exception is caught and printed.
    """
    try:
        retrieved_docs = retriever.invoke(query)
        # Show what the retriever found for this query.
        print("Retrieved content:")
        for doc in retrieved_docs:
            print(doc.page_content)
    except Exception as e:
        print(f"Error in retrieve_and_output: {e}")
def main(
    question_path="/home/zhangxj/WorkFile/LCA-GPT/QA/split/question/ques0.txt",
    pred_path="/home/zhangxj/WorkFile/LCA-GPT/QA/eval/RAGpred.txt",
    csv_path="/home/zhangxj/WorkFile/LCA-GPT/QA/eval/rag.csv",
):
    """Answer every question in a text file and save the answers.

    Args:
        question_path: UTF-8 text file with one question per line.
        pred_path: Output text file; one answer per line with all whitespace
            removed from each answer.
        csv_path: Output CSV with a single ``ans`` column holding the answers
            (stripped at both ends only).

    Raises:
        RuntimeError: If ``init_chain()`` fails and returns ``None``.
    """
    with open(question_path, "r", encoding="utf-8") as file:
        # Blank lines are kept so answers stay aligned with input lines.
        question = [line.strip() for line in file]

    chain = init_chain()
    if chain is None:
        # init_chain() has already printed the cause; stop with a clear error
        # instead of failing on tuple unpacking.
        raise RuntimeError("init_chain() failed; see the error printed above")
    rag, retriever = chain

    answers = []
    for ques in question:
        response = user_in(ques, rag)
        retrieve_and_output(retriever, ques)
        print(response)
        answers.append(response.strip())

    with open(pred_path, "w", encoding="utf-8") as file:
        # Remove every whitespace run so each answer occupies exactly one line.
        file.writelines(re.sub(r'\s+', '', ans) + '\n' for ans in answers)

    pd.DataFrame({"ans": answers}).to_csv(csv_path, index=False)
# Run the batch evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()