"""Support uploading files for analysis while keeping the chat history as context."""
import json
import os
import tempfile
import time

import requests
import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.chat_models import ChatZhipuAI
from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda"
memory = 6

# Locations shared by the ingestion path and the chain initialisation.
_EMBEDDING_MODEL_PATH = "/home/zhangxj/models/acge_text_embedding"
_CHROMA_DIR = "chromaDB"
_UPLOAD_TMP_ROOT = "/home/zhangxj/WorkFile/LCA-GPT/LCA_RAG/tmp"


def _build_embedding():
    """Return the HuggingFace embedding model used for both indexing and retrieval."""
    return HuggingFaceEmbeddings(
        model_name=_EMBEDDING_MODEL_PATH,
        model_kwargs={"device": device},
    )


def _read_json(path):
    """Return the parsed JSON found at *path* (a local file, or an http(s) URL)."""
    source = str(path)
    if source.startswith(("http://", "https://")):
        response = requests.get(source, timeout=30)
        response.raise_for_status()
        return response.json()
    with open(source, encoding="utf-8") as fh:
        return json.load(fh)


def _load_documents(path):
    """Load *path* into a list of documents, choosing the loader by file extension.

    Raises:
        ValueError: if the extension is not one of txt, pdf, csv or json.
    """
    # splitext takes the LAST extension, so "report.v2.pdf" is a pdf and a
    # name without any dot yields "" instead of raising IndexError.
    file_type = os.path.splitext(os.path.basename(str(path)))[1].lstrip(".").lower()

    if file_type == "json":
        json_data = _read_json(path)
        splitter = RecursiveJsonSplitter(max_chunk_size=300)
        return splitter.create_documents(texts=[json_data])

    if file_type == "txt":
        loader = TextLoader(path, encoding="utf-8")
    elif file_type == "pdf":
        loader = PyPDFLoader(path)
    elif file_type == "csv":
        loader = CSVLoader(path, encoding="utf-8")
    else:
        raise ValueError(f"Unsupported file type: {file_type!r}")
    return loader.load()


def chroma_save_upload(path):
    """Split the file at *path* into chunks and store them in the Chroma DB.

    Errors are reported on stdout and swallowed, so one bad file does not
    abort the processing of the remaining uploads.
    """
    try:
        # Load the docs
        doc = _load_documents(path)

        # Split the doc content
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(doc)
        if not splits:
            raise ValueError("no content could be extracted from the file")

        # Store the content. from_documents already inserts the chunks, so
        # there is no follow-up add_documents call (it would duplicate them).
        Chroma.from_documents(
            documents=splits,
            embedding=_build_embedding(),
            persist_directory=_CHROMA_DIR,
        )
        print("Upload Files saved: " + str(path))
    except Exception as e:
        print(f"Error in chroma_save_upload: {e}")


@st.cache_resource(ttl="1h")
def configure_retriever(uploaded_files):
    """Write each uploaded file to a temporary directory and index it.

    The result is cached by Streamlit with a one hour TTL (see the decorator).
    Errors are reported on stdout and swallowed.
    """
    try:
        os.makedirs(_UPLOAD_TMP_ROOT, exist_ok=True)
        # Read the uploaded documents and write them into a temporary
        # directory, which is removed as soon as indexing has finished.
        with tempfile.TemporaryDirectory(dir=_UPLOAD_TMP_ROOT) as temp_dir:
            for file in uploaded_files:
                # basename() keeps a crafted upload name from escaping temp_dir.
                temp_filepath = os.path.join(temp_dir, os.path.basename(file.name))
                print("文档路径:", temp_filepath)
                with open(temp_filepath, "wb") as f:
                    f.write(file.getvalue())
                chroma_save_upload(path=temp_filepath)
    except Exception as e:
        print(f"Error in configure_retriever: {e}")


def init_chain():
    """Build the retrieval-augmented chat chain.

    Returns:
        The chain produced by ``create_retrieval_chain``, or ``None`` when
        initialisation fails (the error is printed).
    """
    try:
        # Load environment variables and initialise the model. The API key is
        # read from .env / the environment; it must never be hard-coded here.
        load_dotenv(".env")
        if not os.environ.get("DASHSCOPE_API_KEY"):
            raise RuntimeError(
                "DASHSCOPE_API_KEY is not set; define it in .env or the environment"
            )
        llm = ChatTongyi(
            streaming=True,
            model='qwen-plus'
        )

        vector_store = Chroma(
            persist_directory=_CHROMA_DIR,
            embedding_function=_build_embedding(),
        )
        retriever = vector_store.as_retriever()

        # NOTE(review): "生命后期评价" below looks like a typo for "生命周期评价"
        # (life cycle assessment); left untouched because it is runtime text.
        instruct_system_prompt = (
            "你是生命周期领域富有经验的专家。"
            "你要利用检索到的上下文来回答问题。如果上下文没有足够的信息,请说明。"
            "如果你有不明白的地方,请向用户询问。"
            "涉及生命后期评价领域的问题,你应该完整地引用文献资料。\n\n"
            "{context}"
        )
        instruct_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", instruct_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        # Create a chain for passing a list of Documents to a model.
        qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
        # Put the retriever in front of it; the result exposes the "answer"
        # key that user_in() reads.
        rag_chain = create_retrieval_chain(retriever, qa_chain)
        return rag_chain
    except Exception as e:
        print(f"Error in init_chain: {e}")
        return None


def user_in(uin, rag_chain, history):
    """Send the user input *uin* and the chat *history* through *rag_chain*.

    Returns:
        The chain's "answer" value, or a fixed error message if the call
        raises (for example because *rag_chain* is ``None``).
    """
    try:
        result = rag_chain.invoke({"input": uin, "chat_history": history})["answer"]
        print(result)
        return result
    except Exception as e:
        print(f"Error in user_in: {e}")
        return "An error occurred while processing your request."
def main(memory=memory):
    """Render the LCA-GPT Streamlit page and handle one script run.

    Args:
        memory: number of most recent chat-history messages passed to the
            chain with each question; ``0`` or less passes no history.
    """
    start_time = time.time()
    st.set_page_config(page_title="LCA-GPT", layout="wide")
    st.title("LCA-GPT")

    uploaded_files = st.sidebar.file_uploader(
        label="Upload files", type=None, accept_multiple_files=True
    )
    if uploaded_files:
        configure_retriever(uploaded_files)

    # Evaluate the button unconditionally so it is rendered on the first run
    # too (inside an `or` it was skipped by short-circuiting).
    clear_requested = st.sidebar.button("Clear chat history")
    if "messages" not in st.session_state or clear_requested:
        st.session_state["messages"] = [
            {
                "role": "assistant",
                "content": "Hello, I am LCA-GPT, helping you solve problems in the LCA field",
            }
        ]
        # Clearing must also drop the context sent to the model, not only
        # the messages shown on screen.
        st.session_state["chat_history"] = []

    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Retry initialisation if an earlier attempt failed and stored None.
    if st.session_state.get("rag") is None:
        st.session_state.rag = init_chain()

    print("Prompt start ......")

    if prompt := st.chat_input("Please enter your question:"):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # history[-0:] would be the WHOLE list, so handle memory <= 0 explicitly.
        history = st.session_state.chat_history
        recent_history = history[-memory:] if memory > 0 else []

        response = user_in(prompt, st.session_state.rag, recent_history)
        st.session_state.chat_history.extend(
            [HumanMessage(content=prompt), AIMessage(content=response)]
        )
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")


if __name__ == "__main__":
    main()