From 08d77eb79eac905135793f3a73708485ba24910c Mon Sep 17 00:00:00 2001 From: chenxiaodong Date: Mon, 17 Feb 2025 09:29:08 +0800 Subject: [PATCH] first infer --- inference.py | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++ models/env.py | 24 +++++++--- 2 files changed, 139 insertions(+), 7 deletions(-) diff --git a/inference.py b/inference.py index e69de29..aaa5ed0 100644 --- a/inference.py +++ b/inference.py @@ -0,0 +1,122 @@ +import queue +import threading +import time + +import torch + +from train import * + + +def test_one_step(env, act, device, data, action_path): + env.rec_data = data + state = env.reset() + s_tensor = torch.as_tensor((state,), device=device) + a_tensor = act(s_tensor) + action = a_tensor.detach().cpu().numpy()[0] + state, next_state, reward, done = env.step(action) + print(f'The action of {env.current_time} is {action}') + + with open(action_path, 'a') as af: + af.write(f'{env.current_time},{action}\n') + return reward, env.real_unbalance + + +def run_service_test(env, agent, data): + service_result_path = 'data/service_result.csv' + action_path = 'data/service_actions.csv' + + if not os.path.exists(service_result_path): + with open(service_result_path, 'w') as f: + f.write('reward,unbalance\n') + + if not os.path.exists(action_path): + with open(action_path, 'w') as af: + af.write('time,action\n') + + service_rewards = [] + service_unbalances = [] + + service_reward, service_unbalance = test_one_step(env, agent.act, agent.device, data, action_path) + service_rewards.append(service_reward) + service_unbalances.append(service_unbalance) + + if service_rewards: + avg_reward = sum(service_rewards) / len(service_rewards) + avg_unbalance = sum(service_unbalances) / len(service_unbalances) + + with open(service_result_path, 'a') as f: + f.write(f'{avg_reward},{avg_unbalance}\n') + + +# 接听端 +def listener_thread(env, agent, data_queue): + while True: + time.sleep(0.1) # 等待 + if not data_queue.empty(): + new_data = data_queue.get() + print(f"Data received: {new_data}") + run_service_test(env, agent, new_data) + data_queue.task_done() + + +# 发送端 +def sender_thread(data_queue): + while True: + try: + user_input = input("请输入price, temper, solar, load, heat, people(用逗号分隔): ") + + # 将输入字符串分割并转换为浮点数列表 + input_data = list(map(float, user_input.split(','))) + + # 检查输入是否包含六个数值 + if len(input_data) != 6: + print("输入数据格式不正确,请输入六个数值。") + continue + + # 将数据放入队列 + print(f"Sending data: {input_data}") + data_queue.put(input_data) + + except ValueError: + print("输入数据格式不正确,请输入数值。") + + +def main(): + args = Arguments() + args.visible_gpu = '0' + for seed in args.random_seed_list: + args.random_seed = seed + args.agent = AgentPPO() + args.agent.cri_target = True + agent_name = f'{args.agent.__class__.__name__}' + args.env = WgzGym() + args.init_before_training() + + agent = args.agent + env = args.env + env.TRAIN = False + agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate) + + if args.service: + args.cwd = agent_name + service_act_save_path = f'{args.cwd}/actor.pth' + agent.act.load_state_dict(torch.load(service_act_save_path)) + + # 创建一个队列用于线程间通信 + data_queue = queue.Queue() + + listener = threading.Thread(target=listener_thread, args=(env, agent, data_queue)) + listener.daemon = True + listener.start() + + sender = threading.Thread(target=sender_thread, args=(data_queue,)) + sender.daemon = True + sender.start() + + # 主线程保持运行,等待数据传递 + while True: + time.sleep(10) + + +if __name__ == "__main__": + main() diff --git a/models/env.py b/models/env.py index 5530686..0258a3f 100644 --- a/models/env.py +++ b/models/env.py @@ -9,6 +9,7 @@ from parameters import * class WgzGym(gym.Env): def __init__(self, **kwargs): super(WgzGym, self).__init__() + self.rec_data = None self.unbalance = None self.reward = None self.current_output = None @@ -58,12 +59,20 @@ class WgzGym(gym.Env): grid_ex = self.grid time_step = self.current_time - price = self.data_manager.get_price_data(self.month, self.day, self.current_time) - temper = self.data_manager.get_temperature_data(self.month, self.day, self.current_time) - solar = self.data_manager.get_solar_data(self.month, self.day, self.current_time) - load = self.data_manager.get_load_data(self.month, self.day, self.current_time) - heat = self.data_manager.get_heat_data(self.month, self.day, self.current_time) - people = self.data_manager.get_people_data(self.month, self.day, self.current_time) + if self.TRAIN: + price = self.data_manager.get_price_data(self.month, self.day, self.current_time) + temper = self.data_manager.get_temperature_data(self.month, self.day, self.current_time) + solar = self.data_manager.get_solar_data(self.month, self.day, self.current_time) + load = self.data_manager.get_load_data(self.month, self.day, self.current_time) + heat = self.data_manager.get_heat_data(self.month, self.day, self.current_time) + people = self.data_manager.get_people_data(self.month, self.day, self.current_time) + else: + price = self.rec_data[0] + temper = self.rec_data[1] + solar = self.rec_data[2] + load = self.rec_data[3] + heat = self.rec_data[4] + people = self.rec_data[5] obs = np.concatenate((np.float32(time_step), np.float32(price), np.float32(temper), np.float32(solar), np.float32(load), np.float32(heat), @@ -78,10 +87,11 @@ class WgzGym(gym.Env): self.HST.step(action[1]) self.grid.step(action[2]) price = current_obs[1] - temper = current_obs[2] + temper = current_obs[2] # 用途待补充 solar = current_obs[3] load = current_obs[4] heat = current_obs[5] + people = current_obs[6] # 用途待补充 power_gap = solar + self.HST.get_power() - self.EC.current_power - load heat_gap = self.HST.get_heat() + self.EC.get_heat() - heat