121 lines
4.5 KiB
Python
121 lines
4.5 KiB
Python
from langchain.agents import AgentType
|
||
from langchain_experimental.agents import create_csv_agent
|
||
from langchain.callbacks import StreamlitCallbackHandler
|
||
from langchain.chat_models import ChatOpenAI
|
||
from langchain_community.chat_models import ChatZhipuAI
|
||
from langchain_core.prompts import ChatPromptTemplate
|
||
from langchain.prompts import MessagesPlaceholder
|
||
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
|
||
from langchain_core.messages import AIMessage, HumanMessage
|
||
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||
from langchain_openai import ChatOpenAI
|
||
from langchain_community.chat_models.tongyi import ChatTongyi
|
||
from langchain_community.chat_models import QianfanChatEndpoint
|
||
from langchain_community.chat_models import ChatZhipuAI
|
||
from langchain.retrievers import ContextualCompressionRetriever
|
||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||
from langchain_chroma import Chroma
|
||
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter
|
||
import requests
|
||
import re
|
||
import pandas as pd
|
||
import os
|
||
import streamlit as st
|
||
from dotenv import load_dotenv
|
||
|
||
|
||
# --- Runtime configuration ---------------------------------------------------

# Pin the process to GPU 1 (must be set before any CUDA library initialises).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Device string handed to the embedding model loader in get_retriever().
device = "cuda"

# NOTE(review): `memory` is never read anywhere in this file — presumably a
# leftover tuning knob; confirm before deleting.
memory = 6

# SECURITY: the DashScope API key was previously hardcoded here and committed
# to source control — that key must be considered leaked and rotated.
# Load credentials from the environment / a .env file instead.
load_dotenv()
if not os.getenv("DASHSCOPE_API_KEY"):
    raise RuntimeError(
        "DASHSCOPE_API_KEY is not set; add it to your environment or .env file"
    )
def get_retriever():
    """Build an MMR retriever over the persisted Chroma vector store.

    Loads a local HuggingFace embedding model onto the module-level `device`
    and re-opens the vector store persisted under ./chromaDB.

    Returns:
        A LangChain retriever configured for Maximal Marginal Relevance
        search returning the top 5 documents, or None if initialisation
        fails (callers must handle the None case).
    """
    try:
        embedding = HuggingFaceEmbeddings(
            model_name="/home/zhangxj/models/acge_text_embedding",
            model_kwargs={"device": device},
        )
        store = Chroma(persist_directory="chromaDB", embedding_function=embedding)
        # MMR balances relevance against diversity among the k=5 hits.
        return store.as_retriever(search_type="mmr", search_kwargs={"k": 5})
    except Exception as e:
        # Fix: original message referenced the wrong function ("init_chain").
        print(f"Error in get_retriever: {e}")
        return None
def retrieve_and_output(retriever, query):
    """Retrieve documents relevant to `query` and concatenate their text.

    All whitespace is stripped from each document's content — the corpus is
    Chinese, where spaces and newlines are layout noise, not word separators.

    Args:
        retriever: A LangChain retriever exposing get_relevant_documents().
        query: The user's question, used as the retrieval query.

    Returns:
        The concatenated, whitespace-free text of the retrieved documents,
        or "" if retrieval fails.
    """
    try:
        docs = retriever.get_relevant_documents(query)
        # "".join is O(n); the original built the result with repeated +=.
        return "".join(re.sub(r"\s+", "", doc.page_content) for doc in docs)
    except Exception as e:
        print(f"Error in retrieve_and_output: {e}")
        # Fix: the original fell through and returned None on error, which
        # the caller then interpolated into the prompt as the string "None".
        return ""
# --- Streamlit page scaffolding ----------------------------------------------

st.set_page_config(page_title="Chat with your csv!!", page_icon="random")
st.title(":male-student: :book: Chat with your csv!!")

# CSV upload widget; the agent below analyses this file.
uploaded_file = st.file_uploader(
    "请上传你需要分析的数据",
    type="csv",
    help="你需要上传的格式为csv",
)

if not uploaded_file:
    st.warning("您必须上传一个文件从而进行数据分析")
    # Fix: without st.stop() the script kept running and later handed `None`
    # to create_csv_agent, crashing as soon as the user sent a message.
    st.stop()
# --- Conversation state ------------------------------------------------------

# (Re)initialise the chat history on first load or when the user clears it.
if "messages" not in st.session_state or st.sidebar.button("Clear conversation history"):
    # Fix: capitalised the pronoun in the user-facing greeting ("i" -> "I").
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

# Replay the history so the conversation persists across Streamlit reruns.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])
# --- Chat turn: retrieve context, run the CSV agent, render the answer --------

if query := st.chat_input(placeholder="What is this data about?"):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": query})
    st.chat_message("user").write(query)

    # System prompt; the retrieved corpus text is injected via {context}.
    instruct_system_prompt = (
        "你是生命周期领域富有经验的专家。"
        "你要利用检索到的上下文来回答问题。如果上下文没有足够的信息,请说明。"
        "如果你有不明白的地方,请向用户询问。"
        "涉及生命后期评价领域的问题,你应该完整地引用文献资料。\n\n"
        "{context}"
    )

    instruct_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", instruct_system_prompt),
            ("human", "{input}"),
        ]
    )

    # Qwen served through DashScope's OpenAI-compatible endpoint.
    llm = ChatOpenAI(
        temperature=0,
        model="qwen-plus",
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    # Fix: get_retriever() returns None on failure — fall back to an empty
    # context instead of crashing inside retrieve_and_output.
    retriever = get_retriever()
    context = retrieve_and_output(retriever, query) if retriever is not None else ""

    messages = instruct_prompt.format_messages(context=context or "", input=query)

    # NOTE(review): the agent re-reads the uploaded CSV on every message, and
    # allow_dangerous_code lets it execute generated Python — acceptable for a
    # local demo, not for untrusted deployments.
    csv_agent = create_csv_agent(
        llm,
        uploaded_file,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        allow_dangerous_code=True,
        agent_executor_kwargs={"handle_parsing_errors": True},
    )

    with st.chat_message("assistant"):
        # NOTE(review): st_cb is created but never passed to invoke(), so no
        # streaming output is rendered — confirm whether that was intended.
        st_cb = StreamlitCallbackHandler(st.container())
        response = csv_agent.invoke(messages)
        answer = response["output"]
        # Fix: store only the answer text — the original appended the whole
        # response dict, so history replay rendered a raw dict on rerun.
        st.session_state.messages.append({"role": "assistant", "content": answer})
        st.markdown(answer)