diff --git a/DDPG.py b/DDPG.py
index 0f54c58..a79258a 100644
--- a/DDPG.py
+++ b/DDPG.py
@@ -7,6 +7,19 @@ from environment import ESSEnv
 from tools import *
 
 
+def update_buffer(_trajectory):
+    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
+    ary_other = torch.as_tensor([item[1] for item in _trajectory])
+    ary_other[:, 0] = ary_other[:, 0]  # ten_reward
+    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma
+
+    buffer.extend_buffer(ten_state, ary_other)
+
+    _steps = ten_state.shape[0]
+    _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
+    return _steps, _r_exp
+
+
 if __name__ == '__main__':
     args = Arguments()
     '''record real unbalance'''
diff --git a/SAC.py b/SAC.py
index 1754cfb..080f330 100644
--- a/SAC.py
+++ b/SAC.py
@@ -7,6 +7,19 @@ from environment import ESSEnv
 from tools import *
 
 
+def update_buffer(_trajectory):
+    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
+    ary_other = torch.as_tensor([item[1] for item in _trajectory])
+    ary_other[:, 0] = ary_other[:, 0]  # ten_reward
+    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma
+
+    buffer.extend_buffer(ten_state, ary_other)
+
+    _steps = ten_state.shape[0]
+    _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
+    return _steps, _r_exp
+
+
 if __name__ == '__main__':
     args = Arguments()
     reward_record = {'episode': [], 'steps': [], 'mean_episode_reward': [], 'unbalance': []}
diff --git a/TD3.py b/TD3.py
index 109bafd..aa2f390 100644
--- a/TD3.py
+++ b/TD3.py
@@ -7,6 +7,19 @@ from environment import ESSEnv
 from tools import *
 
 
+def update_buffer(_trajectory):
+    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
+    ary_other = torch.as_tensor([item[1] for item in _trajectory])
+    ary_other[:, 0] = ary_other[:, 0]  # ten_reward
+    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma
+
+    buffer.extend_buffer(ten_state, ary_other)
+
+    _steps = ten_state.shape[0]
+    _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
+    return _steps, _r_exp
+
+
 if __name__ == '__main__':
     args = Arguments()
     reward_record = {'episode': [], 'steps': [], 'mean_episode_reward': [], 'unbalance': []}
diff --git a/tools.py b/tools.py
index 6edd607..a4e4926 100644
--- a/tools.py
+++ b/tools.py
@@ -233,19 +233,6 @@ def get_episode_return(env, act, device):
     return episode_return, episode_unbalance
 
 
-def update_buffer(_trajectory):
-    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
-    ary_other = torch.as_tensor([item[1] for item in _trajectory])
-    ary_other[:, 0] = ary_other[:, 0]  # ten_reward
-    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma
-
-    buffer.extend_buffer(ten_state, ary_other)
-
-    _steps = ten_state.shape[0]
-    _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
-    return _steps, _r_exp
-
-
 class ReplayBuffer:
     def __init__(self, max_len, state_dim, action_dim, gpu_id=0):
         self.now_len = 0
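
Note: the diff above moves update_buffer out of tools.py and duplicates it in each training script (DDPG.py, SAC.py, TD3.py), presumably because the function reads the module-level gamma and buffer that every script sets up in its __main__ block. Below is a minimal, self-contained sketch of the data flow it handles; the function body is copied from the diff, while the gamma value, the stub buffer, and the hand-built two-step trajectory are illustrative assumptions (in the real scripts buffer is a tools.ReplayBuffer and the trajectory presumably comes from the agent's exploration step). The self-assignment on the reward column is a no-op kept from the original tools.py version.

    # Sketch only: gamma, StubBuffer, and the trajectory values are assumptions.
    import torch

    gamma = 0.995  # discount factor; read by update_buffer as a module-level global


    class StubBuffer:
        """Stand-in exposing the same extend_buffer(ten_state, ten_other) call as tools.ReplayBuffer."""

        def extend_buffer(self, ten_state, ten_other):
            print('received', tuple(ten_state.shape), 'states and', tuple(ten_other.shape), 'others')


    buffer = StubBuffer()


    def update_buffer(_trajectory):
        # copied from the diff: stacks (state, other) pairs into two tensors
        ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
        ary_other = torch.as_tensor([item[1] for item in _trajectory])
        ary_other[:, 0] = ary_other[:, 0]  # ten_reward (no-op; reward stored unchanged)
        ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma

        buffer.extend_buffer(ten_state, ary_other)

        _steps = ten_state.shape[0]
        _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
        return _steps, _r_exp


    # Each trajectory item is (state, other) with other = (reward, done, *action);
    # done = 1.0 marks the terminal step of an episode.
    trajectory = [
        ([0.1, 0.2, 0.3], [1.0, 0.0, 0.5, -0.5]),
        ([0.4, 0.5, 0.6], [0.5, 1.0, -0.2, 0.8]),
    ]
    steps, r_exp = update_buffer(trajectory)
    print(steps, float(r_exp))  # 2 transitions collected, mean exploration reward 0.75

The returned _steps and _r_exp are the per-collection step count and mean reward, which the training scripts can use for logging how much experience was gathered and how well exploration is paying off.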