# LCA-LLM/LCA_RAG/UploadMain.py
'''
Supports uploading files for analysis and records the conversation history as context.
'''
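# This is a Streamlit app; launch it from the repo root, e.g.:
#   streamlit run UploadMain.py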
import json
import os
import tempfile
import time

import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter

# Pin the process to GPU 1 and run the embedding model on CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda"
# Number of past messages replayed to the chain as conversational memory
# (6 messages = the last 3 user/assistant exchanges).
memory = 6
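
# A sketch of what the memory window keeps (names are illustrative): with
# memory = 6, only the last three user/assistant exchanges reach the model.
#
#   history = [HumanMessage("q1"), AIMessage("a1"), ...,
#              HumanMessage("q4"), AIMessage("a4")]
#   history[-memory:]  # -> the six messages for q2..q4 only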

def chroma_save_upload(path):
    try:
        # Choose a loader based on the file extension.
        file_type = os.path.splitext(path)[1].lstrip(".").lower()
        loader = None
        doc = None
        if file_type == "txt":
            loader = TextLoader(path, encoding="utf-8")
        elif file_type == "pdf":
            loader = PyPDFLoader(path)
        elif file_type == "csv":
            loader = CSVLoader(path, encoding="utf-8")
        elif file_type == "json":
            # JSON documents are chunked structurally rather than by characters.
            with open(path, encoding="utf-8") as f:
                json_data = json.load(f)
            splitter = RecursiveJsonSplitter(max_chunk_size=300)
            doc = splitter.create_documents(texts=[json_data])
        if doc is None:
            doc = loader.load()
        # Split the document content into overlapping chunks.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(doc)
        # Embed the chunks and persist them in the local Chroma store.
        embedding = HuggingFaceEmbeddings(model_name="/home/zhangxj/models/acge_text_embedding",
                                          model_kwargs={"device": device})
        Chroma.from_documents(documents=splits, embedding=embedding, persist_directory="chromaDB")
        print("Uploaded file saved: " + str(path))
    except Exception as e:
        print(f"Error in chroma_save_upload: {e}")

@st.cache_resource(ttl="1h")
def configure_retriever(uploaded_files):
    try:
        # Write the uploaded files into a temporary directory, then index each one.
        temp_dir = tempfile.TemporaryDirectory(dir="/home/zhangxj/WorkFile/LCA-GPT/LCA_RAG/tmp")
        for file in uploaded_files:
            temp_filepath = os.path.join(temp_dir.name, file.name)
            print("Document path:", temp_filepath)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            chroma_save_upload(path=temp_filepath)
    except Exception as e:
        print(f"Error in configure_retriever: {e}")

def init_chain():
    try:
        # Load environment variables and initialize the chat model.
        load_dotenv(".env")
        os.environ["DASHSCOPE_API_KEY"] = 'sk-c5f441f863f44094b0ddb96c831b5002'
        llm = ChatTongyi(
            streaming=True,
            model='qwen-plus'
        )
        embedding = HuggingFaceEmbeddings(model_name="/home/zhangxj/models/acge_text_embedding",
                                          model_kwargs={"device": device})
        vectorstore = Chroma(persist_directory="chromaDB", embedding_function=embedding)
        retriever = vectorstore.as_retriever()
        instruct_system_prompt = (
            "You are an experienced expert in the life cycle assessment domain. "
            "Use the retrieved context to answer the question; if the context "
            "does not contain enough information, say so. "
            "If anything is unclear to you, ask the user. "
            "For questions in the life cycle assessment domain, quote the "
            "reference material in full.\n\n"
            "{context}"
        )
        instruct_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", instruct_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        # Chain that stuffs the retrieved documents into the prompt,
        # wrapped in a retrieval chain that fetches them first.
        qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
        rag_chain = create_retrieval_chain(retriever, qa_chain)
        return rag_chain
    except Exception as e:
        print(f"Error in init_chain: {e}")
        return None
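
# The chain's I/O contract, for reference: create_retrieval_chain takes
# {"input", "chat_history"} and returns a dict whose "answer" key holds the
# model's reply. A minimal sketch, assuming the store is already populated:
#
#   chain = init_chain()
#   out = chain.invoke({"input": "What is a life cycle inventory?", "chat_history": []})
#   print(out["answer"])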

def user_in(uin, rag_chain, history):
    try:
        result = rag_chain.invoke({"input": uin, "chat_history": history})["answer"]
        print(result)
        return result
    except Exception as e:
        print(f"Error in user_in: {e}")
        return "An error occurred while processing your request."

def get_code(text):
    # Placeholder: extracting code blocks from a response is not implemented yet.
    return

def main(memory=memory):
    start_time = time.time()
    st.set_page_config(page_title="LCA-GPT", layout="wide")
    st.title("LCA-GPT")
    uploaded_files = st.sidebar.file_uploader(
        label="Upload files", type=None, accept_multiple_files=True
    )
    if uploaded_files:
        configure_retriever(uploaded_files)
    if "messages" not in st.session_state or st.sidebar.button("Clear chat history"):
        st.session_state["messages"] = [{"role": "assistant", "content": "Hello, I am LCA-GPT, helping you solve problems in the LCA field"}]
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "rag" not in st.session_state:
        st.session_state.rag = init_chain()
    print("Prompt start ......")
    if prompt := st.chat_input("Please enter your question:"):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Answer using only the last `memory` messages as conversational context.
        response = user_in(prompt, st.session_state.rag, st.session_state.chat_history[-memory:])
        st.session_state.chat_history.extend([HumanMessage(content=prompt), AIMessage(content=response)])
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")


if __name__ == "__main__":
    main()