LCA-LLM/LCA_RAG/RAGnoUI.py

149 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

'''
支持上传文件进行分析,记录历史上下文,
'''
import time
import tempfile
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
# from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models import ChatZhipuAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter
import requests
import os
import re
import pandas as pd
# Expose only CUDA device index 1 to this process.
# NOTE(review): this runs at import time, after the third-party imports above;
# confirm none of them has already initialised CUDA before this line executes.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})
# `device` is the torch device string handed to the HuggingFace embedding model
# in init_chain(); `memory` appears to be unused anywhere in this script.
device, memory = "cuda", 6
def init_chain():
    """Build the retrieval-augmented QA chain used by this script.

    Steps:
      1. Load a ``.env`` file (if present) and read ``DEEPSEEK_API_KEY``.
      2. Create a ``deepseek-chat`` client through ``ChatOpenAI`` pointed at
         ``https://api.deepseek.com``.
      3. Open the persisted Chroma store in ``chromaDB`` using the local
         ``acge_text_embedding`` model on ``device``.
      4. Wrap it in an MMR retriever returning 10 chunks and combine it with
         a "stuff documents" QA prompt.

    Returns:
        ``(rag_chain, retriever)`` on success. On any failure the error is
        printed and ``(None, None)`` is returned, so the result can always be
        unpacked into two names.
    """
    try:
        # Read credentials from the environment (optionally populated from a
        # .env file) instead of hard-coding them in source. Keys that were
        # previously committed in this function should be considered leaked
        # and rotated.
        load_dotenv()
        api_key = os.environ.get("DEEPSEEK_API_KEY")
        if not api_key:
            raise RuntimeError(
                "DEEPSEEK_API_KEY is not set; export it or add it to a .env file"
            )
        # NOTE: ChatTongyi, QianfanChatEndpoint and ChatZhipuAI are imported at
        # module level as alternative backends; only DeepSeek is wired up here.
        llm = ChatOpenAI(
            model='deepseek-chat',
            openai_api_key=api_key,
            openai_api_base='https://api.deepseek.com',
        )
        embedding = HuggingFaceEmbeddings(
            model_name="/home/zhangxj/models/acge_text_embedding",
            model_kwargs={"device": device},
        )
        vectorstore = Chroma(persist_directory="chromaDB", embedding_function=embedding,)
        # MMR search favours diversity among the k=10 returned chunks rather
        # than a pure similarity ranking.
        retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 10})
        # System prompt (Chinese): "You are an experienced, knowledgeable expert
        # in the life-cycle domain. Use the retrieved context below to answer
        # the question. {context} Answer in at most one sentence, very
        # concisely, without line breaks."
        instruct_system_prompt = (
            "你是生命周期领域富有经验和知识的专家。"
            "使用以下检索到的上下文来回答问题。"
            "{context}"
            "答案最多使用1句话并保持非常简洁不能换行。"
        )
        instruct_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", instruct_system_prompt),
                ("human", "{input}"),
            ]
        )
        # Stuff all retrieved documents into {context} of a single LLM call,
        # then put the retriever in front of it.
        qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
        rag_chain = create_retrieval_chain(retriever, qa_chain)
        return rag_chain, retriever
    except Exception as e:
        print(f"Error in init_chain: {e}")
        # Same 2-tuple shape as the success path so callers can unpack it.
        return None, None
def user_in(uin, rag_chain):
    """Ask ``rag_chain`` one question and return only the answer text.

    Args:
        uin: the question; sent to the chain under the ``"input"`` key.
        rag_chain: an object whose ``invoke`` accepts ``{"input": ...}`` and
            returns a mapping that contains ``"answer"``.

    Returns:
        The chain's ``"answer"`` value. If anything raises, the error is
        printed and a fixed apology string is returned instead.
    """
    request = {"input": uin}
    try:
        reply = rag_chain.invoke(request)
        return reply['answer']
    except Exception as exc:
        print(f"Error in user_in: {exc}")
        return "An error occurred while processing your request."
def retrieve_and_output(retriever, query):
    """Retrieve documents for ``query``, print their contents and return them.

    This is a debugging aid that shows which chunks the retriever returns for
    a question. Errors are printed rather than raised, so one bad query does
    not abort a batch run.

    Args:
        retriever: a LangChain retriever, i.e. anything with ``invoke(query)``
            returning objects that expose ``page_content``.
        query: the question text.

    Returns:
        The retrieved documents, or ``[]`` if retrieval or printing failed.
    """
    try:
        # ``invoke`` is the Runnable entry point; ``get_relevant_documents`` is
        # deprecated in langchain-core in its favour.
        retrieved_docs = retriever.invoke(query)
        print("Retrieved content:")
        for doc in retrieved_docs:
            print(doc.page_content)
        return retrieved_docs
    except Exception as e:
        print(f"Error in retrieve_and_output: {e}")
        return []
def main(
    question_path="/home/zhangxj/WorkFile/LCA-GPT/QA/split/question/ques0.txt",
    pred_txt_path="/home/zhangxj/WorkFile/LCA-GPT/QA/eval/RAGpred.txt",
    pred_csv_path="/home/zhangxj/WorkFile/LCA-GPT/QA/eval/rag.csv",
):
    """Answer every question in ``question_path`` with the RAG chain and save results.

    Reads one question per line, then for each question prints the retrieved
    chunks and the model's answer. The answers are written twice:

    * ``pred_txt_path``: one answer per line with *all* whitespace removed;
    * ``pred_csv_path``: a single ``ans`` column holding the end-stripped answers.

    The defaults are the paths this script has always used; pass other paths
    to run it on a different question split. If the chain cannot be built,
    a message is printed and nothing is written.
    """
    questions = _read_questions(question_path)

    chain_parts = init_chain()
    # init_chain signals failure with None or (None, None); stop cleanly
    # instead of crashing on tuple unpacking or querying a missing chain.
    if not chain_parts or chain_parts[0] is None:
        print("Could not initialise the RAG chain; no predictions written.")
        return
    rag, retriever = chain_parts

    answers = []
    for ques in questions:
        response = user_in(ques, rag)
        retrieve_and_output(retriever, ques)
        print(response)
        answers.append(response.strip())

    _write_predictions(answers, pred_txt_path, pred_csv_path)


def _read_questions(path):
    """Return the stripped lines of the UTF-8 file at ``path``, one question per line."""
    with open(path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file]


def _write_predictions(answers, txt_path, csv_path):
    """Write ``answers`` as a whitespace-free text file and as a one-column CSV."""
    with open(txt_path, "w", encoding="utf-8") as file:
        # Remove every whitespace character (not just the ends) so that each
        # answer, even one the model split across lines, stays on one line.
        file.writelines(re.sub(r'\s+', '', ans) + '\n' for ans in answers)
    pd.DataFrame({"ans": answers}).to_csv(csv_path, index=False)


# Run the batch job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()