wgz_decision/models/tools.py

import torch


def test_one_episode(env, act, device):
    """Collect evaluation information and record the unbalance after each action."""
    record_system_info = []  # same layout as the observation
    record_init_info = []  # includes month, day, time
    env.TRAIN = False
    state = env.reset()
    record_init_info.append([env.month, env.day, env.current_time])
    print(f'current testing month is {env.month}, day is {env.day}')
    for i in range(24):
        s_tensor = torch.as_tensor((state,), device=device)
        a_tensor = act(s_tensor)
        action = a_tensor.detach().cpu().numpy()[0]
        state, next_state, reward, done = env.step(action)
        record_system_info.append([state[1], state[2], env.HST.current_soc(), env.HST.get_power(),
                                   env.EC.current_power, action, reward])
        state = next_state
    # overwrite the last step's HST soc, HST power and EC power with the final step outputs
    record_system_info[-1][2:5] = [env.final_step_outputs[0], env.final_step_outputs[1],
                                   env.final_step_outputs[2]]
    record = {'init_info': record_init_info, 'system_info': record_system_info}
    return record
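

# Illustrative helper (an addition, not part of the original file): one way to
# flatten the dict returned by test_one_episode into a pandas DataFrame for
# plotting or CSV export. The column names are assumptions read off the append
# order of record_system_info; adjust them to the environment's actual
# observation layout. pandas is imported locally so this sketch does not add a
# hard dependency for callers that never use it.
def record_to_dataframe(record):
    """Hypothetical convenience wrapper around the output of test_one_episode."""
    import pandas as pd
    columns = ['obs_1', 'obs_2', 'hst_soc', 'hst_power', 'ec_power', 'action', 'reward']
    return pd.DataFrame(record['system_info'], columns=columns)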


def get_episode_return(env, act, device):
    episode_reward = 0.0  # sum of rewards in an episode
    episode_unbalance = 0.0
    state = env.reset()
    for i in range(24):
        s_tensor = torch.as_tensor((state,), device=device)
        a_tensor = act(s_tensor)
        action = a_tensor.detach().cpu().numpy()[0]  # detach() is redundant when the caller wraps this in torch.no_grad()
        state, next_state, reward, done = env.step(action)
        state = next_state
        episode_reward += reward
        episode_unbalance += env.unbalance
        if done:
            break
    return episode_reward, episode_unbalance
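

# Illustrative sketch (an addition, not in the original file): how a caller
# might invoke get_episode_return under torch.no_grad(), which is what the
# comment inside that function assumes. `agent` and its `act` attribute are
# hypothetical placeholders for the trained policy used elsewhere in the repo.
def evaluate_agent(env, agent, device, num_episodes=4):
    """Average episode reward and unbalance over a few evaluation runs (hypothetical helper)."""
    rewards, unbalances = [], []
    with torch.no_grad():  # no gradients are needed during evaluation
        for _ in range(num_episodes):
            episode_reward, episode_unbalance = get_episode_return(env, agent.act, device)
            rewards.append(episode_reward)
            unbalances.append(episode_unbalance)
    return sum(rewards) / num_episodes, sum(unbalances) / num_episodes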