LCA-LLM/LCA_RAG/csv_agent.py

121 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain.agents import AgentType
from langchain_experimental.agents import create_csv_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain_community.document_loaders import TextLoader, PyPDFLoader, CSVLoader, JSONLoader
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models import ChatZhipuAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter
import requests
import re
import pandas as pd
import os
import streamlit as st
from dotenv import load_dotenv
# Pin this process to GPU 1 and select the device used by the embedding model.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda"
memory = 6  # NOTE(review): defined but not used anywhere in this file — confirm intent

# SECURITY fix: the DashScope API key was hard-coded here and committed to
# source control. Load secrets from the environment / a local .env file
# instead (put DASHSCOPE_API_KEY=... in .env); `load_dotenv` was already
# imported above but never called.
load_dotenv()
def get_retriever():
    """Build an MMR retriever over the persisted Chroma vector store.

    Loads a local HuggingFace embedding model (path is machine-specific)
    and opens the on-disk Chroma index in ``chromaDB``.

    Returns:
        A retriever yielding the top-5 MMR-diverse chunks, or ``None``
        if initialisation fails (e.g. model files missing).
    """
    try:
        embedding = HuggingFaceEmbeddings(
            model_name="/home/zhangxj/models/acge_text_embedding",
            model_kwargs={"device": device},
        )
        store = Chroma(persist_directory="chromaDB", embedding_function=embedding)
        # Maximal Marginal Relevance keeps the five returned chunks diverse.
        return store.as_retriever(search_type="mmr", search_kwargs={"k": 5})
    except Exception as e:
        # Bug fix: the message previously said "Error in init_chain", which is
        # not this function's name and misled debugging.
        print(f"Error in get_retriever: {e}")
        return None
def retrieve_and_output(retriever, query):
    """Retrieve documents relevant to ``query`` and concatenate their text.

    Every run of whitespace (spaces, tabs, newlines) is removed from each
    chunk — the corpus is Chinese, where inter-character whitespace is
    extraction noise — and the cleaned chunks are joined into one string.

    Args:
        retriever: Any object exposing ``get_relevant_documents(query)``
            returning documents with a ``page_content`` attribute.
        query: The user question to search the vector store with.

    Returns:
        The concatenated, whitespace-free text of all retrieved documents.
        Returns ``""`` on any retrieval error (bug fix: the except branch
        previously fell through and returned ``None``, which would later be
        interpolated into the prompt as the literal text "None").
    """
    try:
        retrieved_docs = retriever.get_relevant_documents(query)
        return "".join(
            re.sub(r"\s+", "", doc.page_content) for doc in retrieved_docs
        )
    except Exception as e:
        print(f"Error in retrieve_and_output: {e}")
        return ""
# --- Page chrome and file upload -------------------------------------------
st.set_page_config(page_title="Chat with your csv!!", page_icon="random")
st.title(":male-student: :book: Chat with your csv!!")

uploaded_file = st.file_uploader(
    "请上传你需要分析的数据",
    type="csv",
    help="你需要上传的格式为csv",
)
if not uploaded_file:
    st.warning("您必须上传一个文件从而进行数据分析")
    # Bug fix: without st.stop() the script kept executing and later crashed
    # inside create_csv_agent with uploaded_file=None once the user typed a
    # question. Halt this rerun until a file is provided.
    st.stop()

# Create the transcript on first run, or wipe it when the user clicks the
# sidebar "Clear conversation history" button.
if "messages" not in st.session_state or st.sidebar.button("Clear conversation history"):
    st.session_state["messages"] = [{"role": "assistant", "content": "How can i help you?"}]

# Replay the transcript so the conversation survives Streamlit reruns.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])
# --- One chat turn: record the question, run RAG + CSV agent, show answer ---
if query := st.chat_input(placeholder="What is this data about?"):
    st.session_state.messages.append({"role": "user", "content": query})
    st.chat_message("user").write(query)

    # System prompt grounding the model in the LCA (life-cycle assessment)
    # domain. Bug fix: "生命后期评价" was a typo for "生命周期评价"
    # (life-cycle assessment), matching the wording on the first line.
    instruct_system_prompt = (
        "你是生命周期领域富有经验的专家。"
        "你要利用检索到的上下文来回答问题。如果上下文没有足够的信息,请说明。"
        "如果你有不明白的地方,请向用户询问。"
        "涉及生命周期评价领域的问题,你应该完整地引用文献资料。\n\n"
        "{context}"
    )
    instruct_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", instruct_system_prompt),
            ("human", "{input}"),
        ]
    )

    # qwen-plus served through DashScope's OpenAI-compatible endpoint.
    llm = ChatOpenAI(
        temperature=0,
        model="qwen-plus",
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    # RAG step: pull domain context from the vector store for this query.
    retriever = get_retriever()
    context = retrieve_and_output(retriever, query)
    messages = instruct_prompt.format_messages(context=context, input=query)

    # NOTE(review): allow_dangerous_code lets the agent execute generated
    # pandas/Python on the host — acceptable only for trusted users.
    csv_agent = create_csv_agent(
        llm,
        uploaded_file,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        allow_dangerous_code=True,
        agent_executor_kwargs={"handle_parsing_errors": True},
    )

    with st.chat_message("assistant"):
        # Bug fix: the StreamlitCallbackHandler was created but never wired
        # into the agent, so intermediate reasoning steps never streamed to
        # the UI. Pass it via the invoke config.
        st_cb = StreamlitCallbackHandler(st.container())
        response = csv_agent.invoke(messages, {"callbacks": [st_cb]})
        # Bug fix: store only the answer text. Appending the whole response
        # dict made the history-replay loop render a raw dict on the next
        # rerun instead of the assistant's message.
        st.session_state.messages.append(
            {"role": "assistant", "content": response["output"]}
        )
        st.markdown(response["output"])