'''
LCA-GPT: a Streamlit RAG chat application.
Supports uploading files for analysis and records the chat history as conversational context.
'''

import time
import streamlit as st
import tempfile
from dotenv import load_dotenv

from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.chat_models import ChatZhipuAI

from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter
import json
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # pin the app to the second GPU
device = "cuda"
memory = 6  # number of recent chat-history messages passed back into the chain


def chroma_save_upload(path):
    try:
        # Pick a loader based on the file extension
        file_type = os.path.splitext(path)[1].lstrip(".").lower()
        loader = None
        doc = None
        if file_type == "txt":
            loader = TextLoader(path, encoding="utf-8")
        elif file_type == "pdf":
            loader = PyPDFLoader(path)
        elif file_type == "csv":
            loader = CSVLoader(path, encoding="utf-8")
        elif file_type == "json":
            # JSON files are split structurally rather than by character count;
            # `path` is a local temp file, so read it directly instead of over HTTP
            with open(path, encoding="utf-8") as f:
                json_data = json.load(f)
            splitter = RecursiveJsonSplitter(max_chunk_size=300)
            doc = splitter.create_documents(texts=[json_data])

        if doc is None:
            doc = loader.load()

        # Split the document content into overlapping chunks
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(doc)

        # Embed the chunks and persist them to the Chroma store
        # (from_documents already inserts `splits`; a second add_documents call
        # would store every chunk twice)
        embedding = HuggingFaceEmbeddings(model_name="/home/zhangxj/models/acge_text_embedding", model_kwargs={"device": device})
        Chroma.from_documents(documents=splits, embedding=embedding, persist_directory="chromaDB")
        print("Uploaded file saved: " + str(path))
    except Exception as e:
        print(f"Error in chroma_save_upload: {e}")
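

# Hedged usage sketch (not called anywhere in the app): a minimal, standalone
# way to exercise chroma_save_upload and index a local document. The path below
# is hypothetical; any local .txt/.pdf/.csv/.json file should work.
def _demo_save_upload():
    chroma_save_upload("example_docs/sample.txt")
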
@st.cache_resource(ttl="1h")
def configure_retriever(uploaded_files):
    try:
        # Write the uploaded documents into a temporary directory on disk
        temp_dir = tempfile.TemporaryDirectory(dir="/home/zhangxj/WorkFile/LCA-GPT/LCA_RAG/tmp")
        for file in uploaded_files:
            temp_filepath = os.path.join(temp_dir.name, file.name)
            print("Document path:", temp_filepath)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            chroma_save_upload(path=temp_filepath)
    except Exception as e:
        print(f"Error in configure_retriever: {e}")


def init_chain():
    try:
        # Load environment variables and initialize the chat model
        load_dotenv(".env")
        os.environ["DASHSCOPE_API_KEY"] = 'sk-c5f441f863f44094b0ddb96c831b5002'
        llm = ChatTongyi(
            streaming=True,
            model='qwen-plus'
        )

        # Open the persisted vector store and expose it as a retriever
        embedding = HuggingFaceEmbeddings(model_name="/home/zhangxj/models/acge_text_embedding", model_kwargs={"device": device})
        vectorstore = Chroma(persist_directory="chromaDB", embedding_function=embedding)
        retriever = vectorstore.as_retriever()

        instruct_system_prompt = (
            "You are an experienced expert in the life cycle assessment (LCA) field. "
            "Use the retrieved context to answer the question. If the context lacks "
            "sufficient information, say so. If anything is unclear to you, ask the user. "
            "For questions in the life cycle assessment domain, cite the source material in full.\n\n"
            "{context}"
        )

        instruct_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", instruct_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        # Chain that stuffs the retrieved documents into the prompt
        qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
        rag_chain = create_retrieval_chain(retriever, qa_chain)
        return rag_chain
    except Exception as e:
        print(f"Error in init_chain: {e}")
        return None
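

# Hedged sketch of the chain's I/O contract (assumes the chromaDB store is
# already populated and is not called by the app): the retrieval chain's
# .invoke takes {"input", "chat_history"} and returns a dict that includes
# the retrieved "context" documents and the generated "answer".
def _demo_invoke_schema():
    chain = init_chain()
    if chain is not None:
        out = chain.invoke({"input": "What is a life cycle assessment?", "chat_history": []})
        print(out["answer"])
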
def user_in(uin, rag_chain, history):
    try:
        result = rag_chain.invoke({"input": uin, "chat_history": history})["answer"]
        print(result)
        return result
    except Exception as e:
        print(f"Error in user_in: {e}")
        return "An error occurred while processing your request."
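

# Hedged sketch of a two-turn exchange outside the Streamlit UI, showing how
# chat_history is threaded through user_in with the same `memory` truncation
# that main() applies; the questions are illustrative only, and this function
# is never called by the app.
def _demo_two_turns():
    chain = init_chain()
    history = []
    for question in ["What is LCA?", "Which ISO standards govern it?"]:
        answer = user_in(question, chain, history[-memory:])
        history.extend([HumanMessage(content=question), AIMessage(content=answer)])
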
def get_code(text):
    # Placeholder: code-extraction helper, not yet implemented.
    return


def main(memory=memory):
    start_time = time.time()
    st.set_page_config(page_title="LCA-GPT", layout="wide")
    st.title("LCA-GPT")

    uploaded_files = st.sidebar.file_uploader(
        label="Upload files", type=None, accept_multiple_files=True
    )

    if uploaded_files:
        configure_retriever(uploaded_files)

    if "messages" not in st.session_state or st.sidebar.button("Clear chat history"):
        st.session_state["messages"] = [{"role": "assistant", "content": "Hello, I am LCA-GPT, helping you solve problems in the LCA field"}]

    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    if "rag" not in st.session_state:
        st.session_state.rag = init_chain()

    print("Prompt start ...")
    if prompt := st.chat_input("Please enter your question:"):
        with st.chat_message("user"):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})
        # Only the most recent `memory` messages are passed back as context
        response = user_in(prompt, st.session_state.rag, st.session_state.chat_history[-memory:])

        st.session_state.chat_history.extend([HumanMessage(content=prompt), AIMessage(content=response)])
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")


if __name__ == "__main__":
    main()