import os
import queue
import threading
import time
import traceback

import torch

# NOTE(review): `Arguments`, `AgentPPO` and `WgzGym` are presumably provided by
# this star import from train.py — confirm.
from train import *


def test_one_step(env, act, device, data, action_path):
    """Run the actor once on the state built from ``data`` and log the chosen action.

    Args:
        env: environment object; ``data`` is stored as ``env.rec_data`` before
            ``env.reset()`` is called to obtain the state.
        act: actor network, called on a batch containing one state tensor.
        device: torch device the state tensor is created on.
        data: one input record (the list of floats produced by ``sender_thread``).
        action_path: CSV file a ``time,action`` row is appended to.

    Returns:
        ``(reward, env.unbalance)`` after ``env.step`` has applied the action.
    """
    env.rec_data = data
    state = env.reset()
    # Wrap the single state in a tuple to give it a batch dimension of 1.
    s_tensor = torch.as_tensor((state,), device=device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        a_tensor = act(s_tensor)
    action = a_tensor.cpu().numpy()[0]
    _state, _next_state, reward, _done = env.step(action)
    print(f'The action of {env.current_time} is {action}')
    with open(action_path, 'a') as af:
        af.write(f'{env.current_time},{action}\n')
    return reward, env.unbalance


def _ensure_csv(path, header):
    """Create ``path`` (and its parent directory) containing ``header`` if it does not exist."""
    if os.path.exists(path):
        return
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w') as f:
        f.write(header)


def run_service_test(env, agent, data,
                     service_result_path='data/service_result.csv',
                     action_path='data/service_actions.csv'):
    """Serve one input record: run one actor step and append the result rows.

    Args:
        env: environment passed through to ``test_one_step``.
        agent: object exposing ``act`` (the actor) and ``device``.
        data: one input record.
        service_result_path: CSV receiving ``reward,unbalance`` rows.
        action_path: CSV receiving ``time,action`` rows.
    """
    _ensure_csv(service_result_path, 'reward,unbalance\n')
    _ensure_csv(action_path, 'time,action\n')

    service_reward, service_unbalance = test_one_step(
        env, agent.act, agent.device, data, action_path)
    service_rewards = [service_reward]
    service_unbalances = [service_unbalance]

    # Averaging over a single step keeps the written values formatted exactly as
    # before (sum(...) / len(...) always yields a true-division result).
    avg_reward = sum(service_rewards) / len(service_rewards)
    avg_unbalance = sum(service_unbalances) / len(service_unbalances)
    with open(service_result_path, 'a') as f:
        f.write(f'{avg_reward},{avg_unbalance}\n')


# Consumer side
def listener_thread(env, agent, data_queue):
    """Consumer: process every queued record with ``run_service_test``, forever."""
    while True:
        new_data = data_queue.get()  # blocks until sender_thread enqueues a record
        try:
            print(f"Data received: {new_data}")
            run_service_test(env, agent, new_data)
        except Exception:
            # Keep serving: one failing record must not kill the only consumer thread.
            traceback.print_exc()
        finally:
            # Always mark the record done so data_queue.join() cannot hang.
            data_queue.task_done()


# Producer side
def sender_thread(data_queue):
    """Producer: read comma-separated records from stdin and enqueue valid ones.

    A valid record is exactly six comma-separated numbers. Returns when stdin is closed.
    """
    while True:
        # Give the listener a moment to print its output before the next prompt.
        time.sleep(0.5)
        try:
            user_input = input("请输入当前时刻的price, temper, solar, load, heat, people(用逗号分隔): \n")
        except EOFError:
            # stdin closed: stop producing so main() can shut down.
            print("stdin closed; no more input will be read.")
            return
        # Split the input string and convert each field to float.
        try:
            input_data = list(map(float, user_input.split(',')))
        except ValueError:
            print("输入格式不正确,请输入数值。")
            continue
        # The record must contain exactly six values.
        if len(input_data) != 6:
            print("输入格式不正确,请输入六个数值。")
            continue
        print(f"Sending data: {input_data}")
        data_queue.put(input_data)


def main():
    """Load the trained PPO actor and serve stdin records until stdin is closed."""
    args = Arguments()
    args.visible_gpu = '0'
    if not args.random_seed_list:
        return
    # The service runs until input ends, so only the first configured seed is used.
    args.random_seed = args.random_seed_list[0]
    args.agent = AgentPPO()
    args.agent.cri_target = True
    args.env = WgzGym()
    args.init_before_training()

    agent = args.agent
    env = args.env
    env.TRAIN = False
    agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)
    act_save_path = './data/actor.pth'
    # map_location lets a checkpoint saved on another device (e.g. GPU) load here.
    agent.act.load_state_dict(torch.load(act_save_path, map_location=agent.device))

    # Queue used to pass records from the producer thread to the consumer thread.
    data_queue = queue.Queue()
    listener = threading.Thread(target=listener_thread, args=(env, agent, data_queue), daemon=True)
    listener.start()
    sender = threading.Thread(target=sender_thread, args=(data_queue,), daemon=True)
    sender.start()

    # Keep the main thread alive (short sleeps keep Ctrl-C responsive) until the
    # producer stops, then let the consumer finish any records still queued.
    while sender.is_alive():
        time.sleep(1)
    data_queue.join()


if __name__ == "__main__":
    main()