# LCA-LLM/LCA_RAG/csv_agent.py
# (repository web-view residue removed; last modified 2025-02-06 16:10:41 +08:00)
from langchain.agents import AgentType
from langchain_experimental.agents import create_csv_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models import ChatZhipuAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter
import requests
import re
import pandas as pd
import os
import streamlit as st
from dotenv import load_dotenv
# --- Runtime configuration --------------------------------------------------
# Pin the process to GPU 1; must be set before any CUDA-using library loads.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda"  # device string handed to the embedding model below
memory = 6  # NOTE(review): unused in this file — confirm before removing
# SECURITY: a live-looking DashScope key was hard-coded here. Prefer the value
# already present in the environment (or a .env loaded via dotenv); the literal
# is kept only as a last-resort fallback and should be rotated and removed.
os.environ.setdefault("DASHSCOPE_API_KEY", 'sk-c5f441f863f44094b0ddb96c831b5002')
def get_retriever():
    """Build a Chroma-backed MMR retriever over the persisted vector store.

    Uses a local HuggingFace embedding model on the module-level ``device``
    and the ``chromaDB`` persist directory created elsewhere.

    Returns:
        A LangChain retriever yielding the top-5 MMR-diverse documents,
        or None if initialization fails (e.g. model path or DB missing).
    """
    try:
        embedding = HuggingFaceEmbeddings(
            model_name="/home/zhangxj/models/acge_text_embedding",
            model_kwargs={"device": device},
        )
        store = Chroma(persist_directory="chromaDB", embedding_function=embedding)
        # MMR balances relevance against diversity among the 5 results.
        return store.as_retriever(search_type="mmr", search_kwargs={"k": 5})
    except Exception as e:
        # Bug fix: message previously said "init_chain", which is not this
        # function's name — corrected so logs point at the right place.
        print(f"Error in get_retriever: {e}")
        return None
def retrieve_and_output(retriever, query):
    """Retrieve documents for ``query`` and return their concatenated text.

    All whitespace is stripped from each document's content (the corpus is
    Chinese text, where inter-character whitespace is extraction noise).

    Args:
        retriever: A LangChain retriever exposing ``get_relevant_documents``.
        query: The user's question.

    Returns:
        The cleaned, concatenated page contents; "" if retrieval fails.
    """
    try:
        docs = retriever.get_relevant_documents(query)
        # Collapse each document to a whitespace-free string and join at C
        # speed instead of quadratic `+=` concatenation.
        return "".join(re.sub(r"\s+", "", doc.page_content) for doc in docs)
    except Exception as e:
        print(f"Error in retrieve_and_output: {e}")
        # Bug fix: previously fell through and returned None, which the
        # caller would interpolate into the prompt as the literal "None".
        return ""
# --- Streamlit UI ------------------------------------------------------------
st.set_page_config(page_title="Chat with your csv!!", page_icon="random")
st.title(":male-student: :book: Chat with your csv!!")

# The CSV the pandas agent will analyze.
uploaded_file = st.file_uploader(
    "请上传你需要分析的数据",
    type="csv",
    help="你需要上传的格式为csv",
)
if not uploaded_file:
    st.warning("您必须上传一个文件从而进行数据分析")

# Conversation state: (re)initialize on first run or when the user clears it.
if "messages" not in st.session_state or st.sidebar.button("Clear conversation history"):
    st.session_state["messages"] = [{"role": "assistant", "content": "How can i help you?"}]

# Replay the conversation so far.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

if query := st.chat_input(placeholder="What is this data about?"):
    st.session_state.messages.append({"role": "user", "content": query})
    st.chat_message("user").write(query)

    # Robustness fix: without an uploaded CSV, create_csv_agent below would
    # crash on a None file — halt this script run until a file is provided.
    if not uploaded_file:
        st.stop()

    # System prompt instructing the LCA-domain assistant; {context} is filled
    # with the retrieved passages, {input} with the user's question.
    instruct_system_prompt = (
        "你是生命周期领域富有经验的专家。"
        "你要利用检索到的上下文来回答问题。如果上下文没有足够的信息,请说明。"
        "如果你有不明白的地方,请向用户询问。"
        "涉及生命后期评价领域的问题,你应该完整地引用文献资料。\n\n"
        "{context}"
    )
    instruct_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", instruct_system_prompt),
            ("human", "{input}"),
        ]
    )

    # Qwen served through DashScope's OpenAI-compatible endpoint.
    llm = ChatOpenAI(
        temperature=0,
        model="qwen-plus",
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    # RAG: pull context from the Chroma store and fold it into the prompt.
    retriever = get_retriever()
    context = retrieve_and_output(retriever, query)
    messages = instruct_prompt.format_messages(context=context, input=query)

    csv_agent = create_csv_agent(
        llm,
        uploaded_file,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        allow_dangerous_code=True,  # the agent executes generated Python over the CSV
        agent_executor_kwargs={"handle_parsing_errors": True},
    )

    with st.chat_message("assistant"):
        st_cb = StreamlitCallbackHandler(st.container())
        response = csv_agent.invoke(messages)
        # Bug fix: store only the answer text, not the whole response dict —
        # otherwise the history-replay loop above would render a raw dict.
        st.session_state.messages.append({"role": "assistant", "content": response["output"]})
        st.markdown(response["output"])