LCA-LLM/LCA_RAG/RAGnoUI.py

129 lines
4.3 KiB
Python
Raw Normal View History

2024-12-29 17:33:02 +08:00
'''
Supports uploading files for analysis and keeps the historical conversation context.
'''
import time
import tempfile
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
# from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models import ChatZhipuAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter
import requests
import os
import re
import pandas as pd
# Restrict the process to GPU 1; must be set before any CUDA library initializes.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda"  # target device for the HuggingFace embedding model below
memory = 6  # NOTE(review): appears unused in this file — confirm before removing
def init_chain():
    """Initialize the retrieval-augmented generation (RAG) chain.

    Builds a ZhipuAI glm-4 chat model and a Chroma vector store backed by a
    local HuggingFace embedding model, then wires them into a
    retrieval -> stuff-documents QA chain.

    Returns:
        The assembled retrieval chain, or None if any initialization step fails.
    """
    try:
        # ZhipuAI is the active backend; Tongyi / Qianfan were earlier alternatives.
        os.environ["ZHIPUAI_API_KEY"] = "xxxxx"
        llm = ChatZhipuAI(
            streaming=True,
            model="glm-4",
        )
        # NOTE(review): embedding model path is machine-specific — confirm it exists.
        embedding = HuggingFaceEmbeddings(
            model_name="/home/zhangxj/models/acge_text_embedding",
            model_kwargs={"device": device},
        )
        vector_store = Chroma(persist_directory="chromaDB", embedding_function=embedding)
        retrieval_chunks = 20  # top-k similar chunks to retrieve
        # Fix: use retrieval_chunks here instead of a duplicated literal 20,
        # so the knob above actually controls the retriever.
        retriever = vector_store.as_retriever(
            search_type="mmr", search_kwargs={"k": retrieval_chunks}
        )
        instruct_system_prompt = (
            "你是生命周期领域富有经验和知识的专家。"
            "使用以下检索到的上下文来回答问题。"
            "{context}"
            "答案最多使用1句话并保持非常简洁，不能换行。"
        )
        instruct_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", instruct_system_prompt),
                ("human", "{input}"),
            ]
        )
        # Chain that stuffs the retrieved documents into the prompt's {context}.
        qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
        rag_chain = create_retrieval_chain(retriever, qa_chain)
        return rag_chain
    except Exception as e:
        # Boundary handler: report the failure and signal it to the caller via None.
        print(f"Error in init_chain: {e}")
        return None
def user_in(uin, rag_chain):
    """Send one user question through the RAG chain and return the answer text.

    On any failure, prints the error and returns a generic apology string
    instead of raising, so the batch caller keeps going.
    """
    try:
        response = rag_chain.invoke({"input": uin})
        return response['answer']
    except Exception as e:
        print(f"Error in user_in: {e}")
        return "An error occurred while processing your request."
def main():
    """Batch-evaluate the RAG chain on a question file and persist the answers.

    Reads one question per line, queries the chain for each, then writes the
    answers both as a whitespace-stripped text file and as a one-column CSV.
    """
    with open("/home/zhangxj/WorkFile/LCA-GPT/QA/split/question/ques0.txt", "r", encoding="utf-8") as file:
        question = [line.strip() for line in file]

    rag = init_chain()

    answers = []
    for ques in question:
        response = user_in(ques, rag)
        print(response)
        answers.append(response.strip())

    # Plain-text dump: collapse all internal whitespace so each answer is one line.
    with open("/home/zhangxj/WorkFile/LCA-GPT/QA/eval/RAGpred.txt", "w", encoding="utf-8") as file:
        for ans in answers:
            file.write(re.sub(r'\s+', '', ans) + '\n')

    # CSV copy of the raw (stripped) answers for downstream evaluation.
    pd.DataFrame({"ans": answers}).to_csv("/home/zhangxj/WorkFile/LCA-GPT/QA/eval/rag.csv", index=False)


if __name__ == "__main__":
    main()