LCA-LLM/LCArag/main.py

114 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
# 记住最近的6个聊天元素
memory = 6
def init_chain():
"""
初始化聊天模型和检索链。
此函数设置环境变量加载OpenAI模型创建ChromaDB检索器并编译用于问题回答的最终链。
返回:
Runnable Sequence: 初始化的检索链。
"""
# 加载环境变量并初始化模型
load_dotenv(".env")
key = 1
llm = ChatOpenAI(streaming=True,
verbose=True,
api_key=key,
openai_api_base="http://localhost:8000/v1",
model="Qwen1.5-32b-int4")
embedding = HuggingFaceEmbeddings(model_name = "/home/zhangxj/models/acge_text_embedding")
# 通过加载ChromaDB创建一个数据库检索器
retriever = Chroma(persist_directory="chroma-laws", embedding_function=embedding)
retriever = retriever.as_retriever()
# 指令模板
instruct_system_prompt = (
"你是生命周期领域富有经验的专家。"
"利用检索到的上下文来回答问题。如果上下文没有足够的信息,请说明。"
"如果你不明白,请询问你不清楚的部分。"
"尽量少用句子,但如果有必要可以完整引用文献资料。"
"尽量保持对话有趣和机智。\n\n"
"{context}")
instruct_prompt = ChatPromptTemplate.from_messages(
[
("system", instruct_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
])
# 编译最终链并返回
qa_chain = create_stuff_documents_chain(llm, instruct_prompt)
rag_chain = create_retrieval_chain(retriever, qa_chain)
return rag_chain
def user_in(uin, rag_chain, history):
"""
返回GPT生成的输出。
参数:
uin (str): 用户输入。
rag_chain (Runnable Sequence): llm链。
history (list): 聊天历史列表。
返回:
str: 生成的响应
"""
result = rag_chain.invoke({"input": uin, "chat_history" : history})["answer"]
# 用新问题和响应更新聊天历史
# history.extend([HumanMessage(content=uin),AIMessage(content=result)])
return result
def main(memory=memory):
st.title("LCA-GPT")
# 创建并存储在streamlit缓存中
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
if "rag" not in st.session_state:
st.session_state.rag = init_chain()
if "messages" not in st.session_state:
st.session_state.messages = []
# 更新聊天消息
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("请输入你的问题:"):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role":"user", "content": prompt})
response = user_in(prompt, st.session_state.rag, st.session_state.chat_history[-memory:])
st.session_state.chat_history.extend([HumanMessage(content=prompt),AIMessage(content=response)])
with st.chat_message("assistant"):
st.markdown(response)
st.session_state.messages.append({"role" : "assistant", "content": response})
print("历史:",st.session_state.chat_history[-memory:])
if __name__ == "__main__":
main()