import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma

# Remember the most recent 6 chat messages
memory = 6
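# Note: the window is applied with a plain slice; history[-memory:] keeps the
# six most recent HumanMessage/AIMessage entries, i.e. the last three
# question/answer turns.
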
def init_chain():
    """
    Initialize the chat model and the retrieval chain.

    This function loads environment variables, sets up the OpenAI-compatible
    model, creates a ChromaDB retriever, and composes the final chain used
    for question answering.

    Returns:
        Runnable: The initialized retrieval chain.
    """

    # Load the environment variables and initialize the model.
    load_dotenv(".env")
    # The local OpenAI-compatible server ignores the key, but the client
    # requires a non-empty string, so a placeholder is used.
    key = "EMPTY"
    llm = ChatOpenAI(streaming=True,
                     verbose=True,
                     api_key=key,
                     openai_api_base="http://localhost:8000/v1",
                     model="Qwen1.5-32b-int4")

    embedding = HuggingFaceEmbeddings(model_name="/home/zhangxj/models/acge_text_embedding")
    # Create a retriever backed by the persisted ChromaDB collection.
    vectorstore = Chroma(persist_directory="chroma-laws", embedding_function=embedding)
    retriever = vectorstore.as_retriever()

    # Instruction template; {context} is filled with the retrieved documents.
    instruct_system_prompt = (
        "You are an experienced expert in the life cycle assessment domain. "
        "Use the retrieved context to answer the question. If the context "
        "does not contain enough information, say so. "
        "If you do not understand the question, ask about the part that is "
        "unclear to you. "
        "Use as few sentences as possible, but quote the source material in "
        "full when necessary. "
        "Try to keep the conversation interesting and witty.\n\n"
        "{context}")

    instruct_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", instruct_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])

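    # The assembled prompt is: the system instructions (with retrieved
    # context), then the prior turns injected via
    # MessagesPlaceholder("chat_history"), then the new "{input}" question.
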
    # Compose the final chain and return it.
    qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
    rag_chain = create_retrieval_chain(retriever, qa_chain)

    return rag_chain

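# A minimal standalone sketch of invoking the chain built above (assumes the
# local model server and the persisted "chroma-laws" store are available; the
# question text is illustrative):
#
#     chain = init_chain()
#     out = chain.invoke({"input": "What is a life cycle assessment?",
#                         "chat_history": []})
#     print(out["answer"])
#
# create_retrieval_chain returns a dict with "input", "context", and "answer".
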
def user_in(uin, rag_chain, history):
    """
    Return the response generated by the LLM.

    Args:
        uin (str): The user input.
        rag_chain (Runnable): The RAG chain.
        history (list): The chat history.

    Returns:
        str: The generated response.
    """
    result = rag_chain.invoke({"input": uin, "chat_history": history})["answer"]

    # The caller updates the chat history with the new question and response,
    # so it is not extended here:
    # history.extend([HumanMessage(content=uin), AIMessage(content=result)])
    return result


def main(memory=memory):
    st.title("LCA-GPT")

    # Create the state objects and store them in the Streamlit session cache.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    if "rag" not in st.session_state:
        st.session_state.rag = init_chain()

    if "messages" not in st.session_state:
        st.session_state.messages = []
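    # Streamlit reruns this script top-to-bottom on every interaction, so the
    # chain and the histories live in st.session_state; otherwise init_chain()
    # would rebuild the model and retriever on every message.
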
    # Re-render the stored chat messages on each rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Enter your question:"):

        with st.chat_message("user"):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})
        # Answer with only the most recent `memory` history messages.
        response = user_in(prompt, st.session_state.rag,
                           st.session_state.chat_history[-memory:])
        st.session_state.chat_history.extend([HumanMessage(content=prompt),
                                              AIMessage(content=response)])

        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
        print("History:", st.session_state.chat_history[-memory:])


if __name__ == "__main__":
    main()
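
# To launch the app (assuming this file is saved as app.py):
#
#     streamlit run app.py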