import streamlit as st from dotenv import load_dotenv from langchain_core.prompts import ChatPromptTemplate from langchain.prompts import MessagesPlaceholder from langchain_core.messages import AIMessage, HumanMessage from langchain.chains import create_retrieval_chain from langchain.chains.combine_documents import create_stuff_documents_chain from langchain_openai import ChatOpenAI from langchain.embeddings import HuggingFaceEmbeddings from langchain_chroma import Chroma # 记住最近的6个聊天元素 memory = 6 def init_chain(): """ 初始化聊天模型和检索链。 此函数设置环境变量,加载OpenAI模型,创建ChromaDB检索器,并编译用于问题回答的最终链。 返回: Runnable Sequence: 初始化的检索链。 """ # 加载环境变量并初始化模型 load_dotenv(".env") key = 1 llm = ChatOpenAI(streaming=True, verbose=True, api_key=key, openai_api_base="http://localhost:8000/v1", model="Qwen1.5-32b-int4") embedding = HuggingFaceEmbeddings(model_name = "/home/zhangxj/models/acge_text_embedding") # 通过加载ChromaDB创建一个数据库检索器 retriever = Chroma(persist_directory="chroma-laws", embedding_function=embedding) retriever = retriever.as_retriever() # 指令模板 instruct_system_prompt = ( "你是生命周期领域富有经验的专家。" "利用检索到的上下文来回答问题。如果上下文没有足够的信息,请说明。" "如果你不明白,请询问你不清楚的部分。" "尽量少用句子,但如果有必要可以完整引用文献资料。" "尽量保持对话有趣和机智。\n\n" "{context}") instruct_prompt = ChatPromptTemplate.from_messages( [ ("system", instruct_system_prompt), MessagesPlaceholder("chat_history"), ("human", "{input}"), ]) # 编译最终链并返回 qa_chain = create_stuff_documents_chain(llm, instruct_prompt) rag_chain = create_retrieval_chain(retriever, qa_chain) return rag_chain def user_in(uin, rag_chain, history): """ 返回GPT生成的输出。 参数: uin (str): 用户输入。 rag_chain (Runnable Sequence): llm链。 history (list): 聊天历史列表。 返回: str: 生成的响应 """ result = rag_chain.invoke({"input": uin, "chat_history" : history})["answer"] # 用新问题和响应更新聊天历史 # history.extend([HumanMessage(content=uin),AIMessage(content=result)]) return result def main(memory=memory): st.title("LCA-GPT") # 创建并存储在streamlit缓存中 if "chat_history" not in st.session_state: st.session_state.chat_history = [] if "rag" not in st.session_state: st.session_state.rag = init_chain() if "messages" not in st.session_state: st.session_state.messages = [] # 更新聊天消息 for message in st.session_state.messages: with st.chat_message(message["role"]): st.markdown(message["content"]) if prompt := st.chat_input("请输入你的问题:"): with st.chat_message("user"): st.markdown(prompt) st.session_state.messages.append({"role":"user", "content": prompt}) response = user_in(prompt, st.session_state.rag, st.session_state.chat_history[-memory:]) st.session_state.chat_history.extend([HumanMessage(content=prompt),AIMessage(content=response)]) with st.chat_message("assistant"): st.markdown(response) st.session_state.messages.append({"role" : "assistant", "content": response}) print("历史:",st.session_state.chat_history[-memory:]) if __name__ == "__main__": main()