wgz_decision/inference.py

123 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import queue
import threading
import time
import traceback

import torch

from train import *
def test_one_step(env, act, device, data, action_path):
env.rec_data = data
state = env.reset()
s_tensor = torch.as_tensor((state,), device=device)
a_tensor = act(s_tensor)
action = a_tensor.detach().cpu().numpy()[0]
state, next_state, reward, done = env.step(action)
print(f'The action of {env.current_time} is {action}')
with open(action_path, 'a') as af:
af.write(f'{env.current_time},{action}\n')
return reward, env.real_unbalance
def run_service_test(env, agent, data):
    """Serve one request: run a single inference step and log its metrics.

    The chosen action is appended to ``data/service_actions.csv`` (by
    ``test_one_step``), and the resulting reward/unbalance is appended to
    ``data/service_result.csv``. Each file gets a CSV header when it is
    first created.

    Args:
        env: Environment instance used for the step.
        agent: Agent whose ``act`` network and ``device`` are used for inference.
        data: Request record forwarded to ``test_one_step``.
    """
    service_result_path = 'data/service_result.csv'
    action_path = 'data/service_actions.csv'
    # Make sure the output directory exists; otherwise the very first request
    # fails with FileNotFoundError when run from a fresh checkout.
    os.makedirs(os.path.dirname(service_result_path), exist_ok=True)
    for path, header in ((service_result_path, 'reward,unbalance\n'),
                         (action_path, 'time,action\n')):
        if not os.path.exists(path):
            with open(path, 'w', encoding='utf-8') as f:
                f.write(header)

    service_reward, service_unbalance = test_one_step(
        env, agent.act, agent.device, data, action_path)
    # Metrics are collected per request so several steps could be averaged.
    # A request currently runs exactly one step, so the mean equals that step's
    # value; dividing by the count keeps the logged number format unchanged.
    service_rewards = [service_reward]
    service_unbalances = [service_unbalance]
    avg_reward = sum(service_rewards) / len(service_rewards)
    avg_unbalance = sum(service_unbalances) / len(service_unbalances)
    with open(service_result_path, 'a', encoding='utf-8') as f:
        f.write(f'{avg_reward},{avg_unbalance}\n')
# Listener side: consumes queued requests and serves them.
def listener_thread(env, agent, data_queue):
    """Consume requests from ``data_queue`` forever and serve each one.

    Blocks on ``data_queue.get()`` instead of polling. If serving one request
    raises, the traceback is printed and the loop keeps running, so a single
    bad request cannot silently stop the service. ``task_done()`` is always
    called, so a ``data_queue.join()`` elsewhere cannot hang on a failed item.

    Args:
        env: Environment passed through to ``run_service_test``.
        agent: Agent passed through to ``run_service_test``.
        data_queue: ``queue.Queue`` the requests are read from.
    """
    while True:
        new_data = data_queue.get()  # blocks until a request arrives
        try:
            print(f"Data received: {new_data}")
            run_service_test(env, agent, new_data)
        except Exception:  # thread top-level boundary: report and keep serving
            traceback.print_exc()
        finally:
            data_queue.task_done()
# Sender side: reads requests from the console and enqueues them.
def sender_thread(data_queue, read_line=None):
    """Read comma-separated requests and put valid ones on ``data_queue``.

    Each line must contain exactly six numbers (price, temper, solar, load,
    heat, people). Both the ASCII comma and the full-width comma U+FF0C
    (produced by Chinese input methods) are accepted as separators. A valid
    line is enqueued as a list of floats. An invalid line prints an error and
    is skipped. The function returns when the input stream reaches EOF.

    Args:
        data_queue: ``queue.Queue`` that receives the parsed requests.
        read_line: Optional callable ``prompt -> str`` used instead of the
            builtin ``input``; lets the request source be swapped out.
    """
    if read_line is None:
        read_line = input
    while True:
        try:
            user_input = read_line("请输入price, temper, solar, load, heat, people用逗号分隔: ")
        except EOFError:
            # Input closed: nothing more can arrive, so stop cleanly instead of
            # killing the thread with an unhandled exception.
            return
        try:
            # Normalize full-width commas, split, and convert every field to float.
            input_data = [float(field) for field in user_input.replace('\uff0c', ',').split(',')]
        except ValueError:
            print("输入数据格式不正确,请输入数值。")
            continue
        # Exactly six values are required.
        if len(input_data) != 6:
            print("输入数据格式不正确,请输入六个数值。")
            continue
        print(f"Sending data: {input_data}")
        data_queue.put(input_data)
def main():
    """Build the PPO agent and Wgz environment, then serve console requests.

    For each seed in ``args.random_seed_list``, a fresh agent and environment
    are created and initialised. When ``args.service`` is set, the actor
    weights are loaded from ``<AgentClassName>/actor.pth``. A listener thread
    and a console sender thread are then started, and the main thread blocks
    forever. As a result, in service mode only the first seed is ever served.
    """
    args = Arguments()
    args.visible_gpu = '0'
    for seed in args.random_seed_list:
        args.random_seed = seed
        args.agent = AgentPPO()
        args.agent.cri_target = True
        agent_name = type(args.agent).__name__
        args.env = WgzGym()
        args.init_before_training()
        agent = args.agent
        env = args.env
        # NOTE(review): presumably switches WgzGym to inference mode — confirm.
        env.TRAIN = False
        agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)
        if args.service:
            args.cwd = agent_name
            act_save_path = f'{args.cwd}/actor.pth'
            # map_location lets a checkpoint saved on a GPU load onto the device
            # the agent actually uses (e.g. a CPU-only host).
            agent.act.load_state_dict(torch.load(act_save_path, map_location=agent.device))
            # Inference only: put the actor into evaluation mode.
            agent.act.eval()
            # Queue used to pass requests between the sender and listener threads.
            data_queue = queue.Queue()
            listener = threading.Thread(target=listener_thread, args=(env, agent, data_queue))
            listener.daemon = True
            listener.start()
            sender = threading.Thread(target=sender_thread, args=(data_queue,))
            sender.daemon = True
            sender.start()
            # Keep the main thread alive; both workers are daemon threads and
            # would be terminated as soon as the main thread exits.
            while True:
                time.sleep(10)
if __name__ == '__main__':
    # Start the service only when run as a script, not when imported.
    main()