74 lines
2.2 KiB
Python
74 lines
2.2 KiB
Python
import pandas as pd
|
||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||
from langchain_openai.chat_models import ChatOpenAI
|
||
from langchain_community.chat_message_histories import ChatMessageHistory
|
||
from langchain_core.chat_history import BaseChatMessageHistory
|
||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||
|
||
|
||
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Look up the chat history registered under *session_id*.

    Histories are kept in the module-level ``store`` dict. The first call for a
    given id registers a fresh, empty ``ChatMessageHistory``; later calls with
    the same id return that same instance.
    """
    try:
        return store[session_id]
    except KeyError:
        fresh_history = ChatMessageHistory()
        store[session_id] = fresh_history
        return fresh_history
|
||
|
||
|
||
def read_txt_file(file_path: str, encoding: str = "ISO-8859-1") -> str:
    """Return the entire contents of the text file at *file_path*.

    Args:
        file_path: Path of the file to read.
        encoding: Codec used to decode the file. Defaults to ``"ISO-8859-1"``
            for backward compatibility. ISO-8859-1 maps every byte to some
            character, so it never raises ``UnicodeDecodeError``, but it also
            silently garbles multi-byte text (for example UTF-8 or GBK encoded
            Chinese). Pass the file's real encoding, e.g. ``encoding="utf-8"``.

    Returns:
        The decoded file contents (text mode, so line endings are normalised
        to ``"\\n"``).

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the bytes are invalid for a non-Latin-1 *encoding*.
    """
    with open(file_path, "r", encoding=encoding) as file:
        return file.read()
|
||
|
||
|
||
def read_csv_file(file_path: str) -> pd.DataFrame:
    """Parse the CSV file at *file_path* into a DataFrame with pandas' default options."""
    frame = pd.read_csv(filepath_or_buffer=file_path)
    return frame
|
||
|
||
|
||
def format_csv_data(df: pd.DataFrame, start: int, end: int) -> str:
    """Render rows ``start:end`` of *df* as prompt text, one line per row.

    Each line starts with a Chinese "time step" prefix followed by the row's
    index label plus one (for a default RangeIndex that is the 1-based row
    number), then ``Sp=<price>, Sl=<load>, St=<temperature>,
    Si=<irradiance>, Sw=<wind_speed>``. Lines are joined with newlines; an
    empty slice yields ``""``.

    Args:
        df: Frame containing at least the columns ``price``, ``load``,
            ``temperature``, ``irradiance`` and ``wind_speed``. Its index
            labels must support ``+ 1``.
        start: First row position (inclusive), as for ``DataFrame.iloc``.
        end: Last row position (exclusive), as for ``DataFrame.iloc``.

    Returns:
        The formatted rows joined by newlines.

    Raises:
        KeyError: If any of the required columns is missing.
    """
    columns = ["price", "load", "temperature", "irradiance", "wind_speed"]
    # Restrict to the needed columns and use itertuples: iterrows packs each
    # row into a single Series, which upcasts mixed int/float rows to float
    # (an integer load of 100 would print as "100.0"). itertuples keeps each
    # column's own dtype.
    rows = df.iloc[start:end][columns].itertuples(index=True)
    # Every field is separated by ", " (the original dropped the space
    # between the Sl and St fields across the implicit string concatenation).
    return "\n".join(
        f"时刻{row.Index + 1}: Sp={row.price}, Sl={row.load}, "
        f"St={row.temperature}, Si={row.irradiance}, Sw={row.wind_speed}"
        for row in rows
    )
|
||
|
||
|
||
# Script body: send the CSV to the chat model one row (hour) at a time and
# print each reply.

system_content = read_txt_file('./llm.txt')
df = read_csv_file('./data.csv')

# Chat model reached through an OpenAI-compatible API at localhost:8000.
# NOTE(review): "none" looks like a placeholder key for a local server -- confirm.
llm = ChatOpenAI(
    streaming=True,
    verbose=True,
    openai_api_key="none",
    openai_api_base="http://localhost:8000/v1",
    model_name="Qwen1.5-32b-int4",
)

# ChatPromptTemplate parses the text of ("role", str) messages as an f-string
# template. Any literal "{" / "}" in llm.txt (e.g. a JSON example) would
# otherwise be read as a template variable and make invoke() fail for a
# missing input. Doubling the braces makes the system prompt render verbatim.
system_template = system_content.replace("{", "{{").replace("}", "}}")

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_template),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}"),
    ]
)

runnable = prompt | llm

# Session id -> chat history; populated lazily by get_session_history().
store = {}

with_message_history = RunnableWithMessageHistory(
    runnable,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history",
)

# Every row uses the same session id, so each hour's request and reply are
# added to one shared history: the model sees all previous hours, and the
# prompt grows with every row.
# NOTE(review): for long CSVs this may exceed the model's context window.
num_hours = len(df)
for i in range(num_hours):
    csv_data_chunk = format_csv_data(df, i, i + 1)

    result = with_message_history.invoke(
        {"input": f"数据如下:\n{csv_data_chunk}\n只返回json格式的五个决策数据:{{[x1 x2 x3 x4 x5]}}"},
        config={"configurable": {"session_id": "cxd"}},
    )

    print(f"Hour {i + 1}:\n{result.content}\n")
|