first infer

This commit is contained in:
chenxiaodong 2025-02-17 09:29:08 +08:00
parent fb89548c37
commit 08d77eb79e
2 changed files with 139 additions and 7 deletions

View File

@ -0,0 +1,122 @@
import queue
import threading
import time
import torch
from train import *
def test_one_step(env, act, device, data, action_path):
env.rec_data = data
state = env.reset()
s_tensor = torch.as_tensor((state,), device=device)
a_tensor = act(s_tensor)
action = a_tensor.detach().cpu().numpy()[0]
state, next_state, reward, done = env.step(action)
print(f'The action of {env.current_time} is {action}')
with open(action_path, 'a') as af:
af.write(f'{env.current_time},{action}\n')
return reward, env.real_unbalance
def run_service_test(env, agent, data):
service_result_path = 'data/service_result.csv'
action_path = 'data/service_actions.csv'
if not os.path.exists(service_result_path):
with open(service_result_path, 'w') as f:
f.write('reward,unbalance\n')
if not os.path.exists(action_path):
with open(action_path, 'w') as af:
af.write('time,action\n')
service_rewards = []
service_unbalances = []
service_reward, service_unbalance = test_one_step(env, agent.act, agent.device, data, action_path)
service_rewards.append(service_reward)
service_unbalances.append(service_unbalance)
if service_rewards:
avg_reward = sum(service_rewards) / len(service_rewards)
avg_unbalance = sum(service_unbalances) / len(service_unbalances)
with open(service_result_path, 'a') as f:
f.write(f'{avg_reward},{avg_unbalance}\n')
# 接听端
def listener_thread(env, agent, data_queue):
while True:
time.sleep(0.1) # 等待
if not data_queue.empty():
new_data = data_queue.get()
print(f"Data received: {new_data}")
run_service_test(env, agent, new_data)
data_queue.task_done()
# 发送端
def sender_thread(data_queue):
while True:
try:
user_input = input("请输入price, temper, solar, load, heat, people用逗号分隔: ")
# 将输入字符串分割并转换为浮点数列表
input_data = list(map(float, user_input.split(',')))
# 检查输入是否包含六个数值
if len(input_data) != 6:
print("输入数据格式不正确,请输入六个数值。")
continue
# 将数据放入队列
print(f"Sending data: {input_data}")
data_queue.put(input_data)
except ValueError:
print("输入数据格式不正确,请输入数值。")
def main():
args = Arguments()
args.visible_gpu = '0'
for seed in args.random_seed_list:
args.random_seed = seed
args.agent = AgentPPO()
args.agent.cri_target = True
agent_name = f'{args.agent.__class__.__name__}'
args.env = WgzGym()
args.init_before_training()
agent = args.agent
env = args.env
env.TRAIN = False
agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)
if args.service:
args.cwd = agent_name
service_act_save_path = f'{args.cwd}/actor.pth'
agent.act.load_state_dict(torch.load(service_act_save_path))
# 创建一个队列用于线程间通信
data_queue = queue.Queue()
listener = threading.Thread(target=listener_thread, args=(env, agent, data_queue))
listener.daemon = True
listener.start()
sender = threading.Thread(target=sender_thread, args=(data_queue,))
sender.daemon = True
sender.start()
# 主线程保持运行,等待数据传递
while True:
time.sleep(10)
if __name__ == "__main__":
main()

View File

@ -9,6 +9,7 @@ from parameters import *
class WgzGym(gym.Env):
def __init__(self, **kwargs):
super(WgzGym, self).__init__()
self.rec_data = None
self.unbalance = None
self.reward = None
self.current_output = None
@ -58,12 +59,20 @@ class WgzGym(gym.Env):
grid_ex = self.grid
time_step = self.current_time
price = self.data_manager.get_price_data(self.month, self.day, self.current_time)
temper = self.data_manager.get_temperature_data(self.month, self.day, self.current_time)
solar = self.data_manager.get_solar_data(self.month, self.day, self.current_time)
load = self.data_manager.get_load_data(self.month, self.day, self.current_time)
heat = self.data_manager.get_heat_data(self.month, self.day, self.current_time)
people = self.data_manager.get_people_data(self.month, self.day, self.current_time)
if self.TRAIN:
price = self.data_manager.get_price_data(self.month, self.day, self.current_time)
temper = self.data_manager.get_temperature_data(self.month, self.day, self.current_time)
solar = self.data_manager.get_solar_data(self.month, self.day, self.current_time)
load = self.data_manager.get_load_data(self.month, self.day, self.current_time)
heat = self.data_manager.get_heat_data(self.month, self.day, self.current_time)
people = self.data_manager.get_people_data(self.month, self.day, self.current_time)
else:
price = self.rec_data[0]
temper = self.rec_data[1]
solar = self.rec_data[2]
load = self.rec_data[3]
heat = self.rec_data[4]
people = self.rec_data[5]
obs = np.concatenate((np.float32(time_step), np.float32(price), np.float32(temper),
np.float32(solar), np.float32(load), np.float32(heat),
@ -78,10 +87,11 @@ class WgzGym(gym.Env):
self.HST.step(action[1])
self.grid.step(action[2])
price = current_obs[1]
temper = current_obs[2]
temper = current_obs[2] # 用途待补充
solar = current_obs[3]
load = current_obs[4]
heat = current_obs[5]
people = current_obs[6] # 用途待补充
power_gap = solar + self.HST.get_power() - self.EC.current_power - load
heat_gap = self.HST.get_heat() + self.EC.get_heat() - heat