114 lines
3.9 KiB
Python
114 lines
3.9 KiB
Python
|
import streamlit as st
|
|||
|
from dotenv import load_dotenv
|
|||
|
from langchain_core.prompts import ChatPromptTemplate
|
|||
|
from langchain.prompts import MessagesPlaceholder
|
|||
|
from langchain_core.messages import AIMessage, HumanMessage
|
|||
|
from langchain.chains import create_retrieval_chain
|
|||
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
|||
|
from langchain_openai import ChatOpenAI
|
|||
|
from langchain.embeddings import HuggingFaceEmbeddings
|
|||
|
from langchain_chroma import Chroma
|
|||
|
|
|||
|
# 记住最近的6个聊天元素
|
|||
|
memory = 6
|
|||
|
|
|||
|
def init_chain():
|
|||
|
"""
|
|||
|
初始化聊天模型和检索链。
|
|||
|
|
|||
|
此函数设置环境变量,加载OpenAI模型,创建ChromaDB检索器,并编译用于问题回答的最终链。
|
|||
|
|
|||
|
返回:
|
|||
|
Runnable Sequence: 初始化的检索链。
|
|||
|
"""
|
|||
|
|
|||
|
# 加载环境变量并初始化模型
|
|||
|
load_dotenv(".env")
|
|||
|
key = 1
|
|||
|
llm = ChatOpenAI(streaming=True,
|
|||
|
verbose=True,
|
|||
|
api_key=key,
|
|||
|
openai_api_base="http://localhost:8000/v1",
|
|||
|
model="Qwen1.5-32b-int4")
|
|||
|
|
|||
|
embedding = HuggingFaceEmbeddings(model_name = "/home/zhangxj/models/acge_text_embedding")
|
|||
|
# 通过加载ChromaDB创建一个数据库检索器
|
|||
|
retriever = Chroma(persist_directory="chroma-laws", embedding_function=embedding)
|
|||
|
retriever = retriever.as_retriever()
|
|||
|
|
|||
|
# 指令模板
|
|||
|
instruct_system_prompt = (
|
|||
|
"你是生命周期领域富有经验的专家。"
|
|||
|
"利用检索到的上下文来回答问题。如果上下文没有足够的信息,请说明。"
|
|||
|
"如果你不明白,请询问你不清楚的部分。"
|
|||
|
"尽量少用句子,但如果有必要可以完整引用文献资料。"
|
|||
|
"尽量保持对话有趣和机智。\n\n"
|
|||
|
"{context}")
|
|||
|
|
|||
|
instruct_prompt = ChatPromptTemplate.from_messages(
|
|||
|
[
|
|||
|
("system", instruct_system_prompt),
|
|||
|
MessagesPlaceholder("chat_history"),
|
|||
|
("human", "{input}"),
|
|||
|
])
|
|||
|
|
|||
|
# 编译最终链并返回
|
|||
|
qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
|
|||
|
rag_chain = create_retrieval_chain(retriever, qa_chain)
|
|||
|
|
|||
|
return rag_chain
|
|||
|
|
|||
|
def user_in(uin, rag_chain, history):
|
|||
|
"""
|
|||
|
返回GPT生成的输出。
|
|||
|
|
|||
|
参数:
|
|||
|
uin (str): 用户输入。
|
|||
|
rag_chain (Runnable Sequence): llm链。
|
|||
|
history (list): 聊天历史列表。
|
|||
|
返回:
|
|||
|
str: 生成的响应
|
|||
|
"""
|
|||
|
result = rag_chain.invoke({"input": uin, "chat_history" : history})["answer"]
|
|||
|
|
|||
|
# 用新问题和响应更新聊天历史
|
|||
|
# history.extend([HumanMessage(content=uin),AIMessage(content=result)])
|
|||
|
return result
|
|||
|
|
|||
|
|
|||
|
def main(memory=memory):
|
|||
|
st.title("LCA-GPT")
|
|||
|
|
|||
|
# 创建并存储在streamlit缓存中
|
|||
|
if "chat_history" not in st.session_state:
|
|||
|
st.session_state.chat_history = []
|
|||
|
|
|||
|
if "rag" not in st.session_state:
|
|||
|
st.session_state.rag = init_chain()
|
|||
|
|
|||
|
if "messages" not in st.session_state:
|
|||
|
st.session_state.messages = []
|
|||
|
|
|||
|
# 更新聊天消息
|
|||
|
for message in st.session_state.messages:
|
|||
|
with st.chat_message(message["role"]):
|
|||
|
st.markdown(message["content"])
|
|||
|
|
|||
|
if prompt := st.chat_input("请输入你的问题:"):
|
|||
|
|
|||
|
with st.chat_message("user"):
|
|||
|
st.markdown(prompt)
|
|||
|
|
|||
|
st.session_state.messages.append({"role":"user", "content": prompt})
|
|||
|
response = user_in(prompt, st.session_state.rag, st.session_state.chat_history[-memory:])
|
|||
|
st.session_state.chat_history.extend([HumanMessage(content=prompt),AIMessage(content=response)])
|
|||
|
|
|||
|
with st.chat_message("assistant"):
|
|||
|
st.markdown(response)
|
|||
|
st.session_state.messages.append({"role" : "assistant", "content": response})
|
|||
|
print("历史:",st.session_state.chat_history[-memory:])
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
main()
|