LCA-GPT/LCArag/main.py

114 lines
3.9 KiB
Python
Raw Normal View History

2024-07-30 10:56:08 +08:00
import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
# 记住最近的6个聊天元素
memory = 6
def init_chain():
"""
初始化聊天模型和检索链
此函数设置环境变量加载OpenAI模型创建ChromaDB检索器并编译用于问题回答的最终链
返回
Runnable Sequence: 初始化的检索链
"""
# 加载环境变量并初始化模型
load_dotenv(".env")
key = 1
llm = ChatOpenAI(streaming=True,
verbose=True,
api_key=key,
openai_api_base="http://localhost:8000/v1",
model="Qwen1.5-32b-int4")
embedding = HuggingFaceEmbeddings(model_name = "/home/zhangxj/models/acge_text_embedding")
# 通过加载ChromaDB创建一个数据库检索器
retriever = Chroma(persist_directory="chroma-laws", embedding_function=embedding)
retriever = retriever.as_retriever()
# 指令模板
instruct_system_prompt = (
"你是生命周期领域富有经验的专家。"
"利用检索到的上下文来回答问题。如果上下文没有足够的信息,请说明。"
"如果你不明白,请询问你不清楚的部分。"
"尽量少用句子,但如果有必要可以完整引用文献资料。"
"尽量保持对话有趣和机智。\n\n"
"{context}")
instruct_prompt = ChatPromptTemplate.from_messages(
[
("system", instruct_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
])
# 编译最终链并返回
qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
rag_chain = create_retrieval_chain(retriever, qa_chain)
return rag_chain
def user_in(uin, rag_chain, history):
"""
返回GPT生成的输出
参数:
uin (str): 用户输入
rag_chain (Runnable Sequence): llm链
history (list): 聊天历史列表
返回:
str: 生成的响应
"""
result = rag_chain.invoke({"input": uin, "chat_history" : history})["answer"]
# 用新问题和响应更新聊天历史
# history.extend([HumanMessage(content=uin),AIMessage(content=result)])
return result
def main(memory=memory):
st.title("LCA-GPT")
# 创建并存储在streamlit缓存中
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
if "rag" not in st.session_state:
st.session_state.rag = init_chain()
if "messages" not in st.session_state:
st.session_state.messages = []
# 更新聊天消息
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("请输入你的问题:"):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role":"user", "content": prompt})
response = user_in(prompt, st.session_state.rag, st.session_state.chat_history[-memory:])
st.session_state.chat_history.extend([HumanMessage(content=prompt),AIMessage(content=response)])
with st.chat_message("assistant"):
st.markdown(response)
st.session_state.messages.append({"role" : "assistant", "content": response})
print("历史:",st.session_state.chat_history[-memory:])
if __name__ == "__main__":
main()