chenxiaodong 2024-06-20 09:11:09 +08:00
parent b5a1842147
commit 69fe33deec
4 changed files with 39 additions and 13 deletions

DDPG.py  (+13)

@@ -7,6 +7,19 @@ from environment import ESSEnv
from tools import *


def update_buffer(_trajectory):
    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
    ary_other = torch.as_tensor([item[1] for item in _trajectory])
    ary_other[:, 0] = ary_other[:, 0]  # ten_reward
    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma
    buffer.extend_buffer(ten_state, ary_other)
    _steps = ten_state.shape[0]
    _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
    return _steps, _r_exp


if __name__ == '__main__':
    args = Arguments()
    '''record real unbalance'''
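The `update_buffer` helper added to each script converts a collected trajectory into tensors and folds the discount into the mask: with ten_mask = (1.0 - ary_done) * gamma, a terminal step gets mask 0 and a non-terminal step gets mask gamma, so a critic target can later be formed as reward + mask * Q(next_state, next_action) without a separate done check. A minimal sketch of how the helper might be driven from a training loop follows; the `collect_trajectory` helper, the agent interface (`agent.select_action`), and the constant values are illustrative assumptions, while the `(state, (reward, done, *action))` record layout and the `ReplayBuffer(max_len, state_dim, action_dim)` signature follow this commit.

    # Illustration only -- not part of this commit.
    # Assumed: ESSEnv() takes no arguments, exposes state_dim/action_dim and a
    # gym-style reset/step; `agent` is a pre-built policy with select_action().
    env = ESSEnv()
    gamma = 0.995                                   # discount factor (assumed value)
    buffer = ReplayBuffer(max_len=2 ** 17,          # signature as shown in this commit
                          state_dim=env.state_dim,
                          action_dim=env.action_dim)

    def collect_trajectory(env, agent, horizon=24):
        """Assumed helper: roll the policy for one episode and record
        (state, (reward, done, *action)) tuples in update_buffer's expected layout."""
        trajectory, state = [], env.reset()
        for _ in range(horizon):
            action = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            trajectory.append((state, (reward, float(done), *action)))
            state = next_state
            if done:
                break
        return trajectory

    steps, r_exp = update_buffer(collect_trajectory(env, agent))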

SAC.py  (+13)

@@ -7,6 +7,19 @@ from environment import ESSEnv
from tools import *


def update_buffer(_trajectory):
    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
    ary_other = torch.as_tensor([item[1] for item in _trajectory])
    ary_other[:, 0] = ary_other[:, 0]  # ten_reward
    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma
    buffer.extend_buffer(ten_state, ary_other)
    _steps = ten_state.shape[0]
    _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
    return _steps, _r_exp


if __name__ == '__main__':
    args = Arguments()
    reward_record = {'episode': [], 'steps': [], 'mean_episode_reward': [], 'unbalance': []}

TD3.py  (+13)

@@ -7,6 +7,19 @@ from environment import ESSEnv
from tools import *


def update_buffer(_trajectory):
    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
    ary_other = torch.as_tensor([item[1] for item in _trajectory])
    ary_other[:, 0] = ary_other[:, 0]  # ten_reward
    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma
    buffer.extend_buffer(ten_state, ary_other)
    _steps = ten_state.shape[0]
    _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
    return _steps, _r_exp


if __name__ == '__main__':
    args = Arguments()
    reward_record = {'episode': [], 'steps': [], 'mean_episode_reward': [], 'unbalance': []}

(fourth changed file, name not shown)  (−13)

@@ -233,19 +233,6 @@ def get_episode_return(env, act, device):
    return episode_return, episode_unbalance


def update_buffer(_trajectory):
    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
    ary_other = torch.as_tensor([item[1] for item in _trajectory])
    ary_other[:, 0] = ary_other[:, 0]  # ten_reward
    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma  # ten_mask = (1.0 - ary_done) * gamma
    buffer.extend_buffer(ten_state, ary_other)
    _steps = ten_state.shape[0]
    _r_exp = ary_other[:, 0].mean()  # other = (reward, mask, action)
    return _steps, _r_exp


class ReplayBuffer:
    def __init__(self, max_len, state_dim, action_dim, gpu_id=0):
        self.now_len = 0
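`extend_buffer`, which the relocated `update_buffer` calls, is not part of this hunk. For orientation, a rough sketch of the kind of batched append such a replay buffer typically provides is given below; apart from `now_len` and the constructor signature, every attribute name and the overflow policy here are assumptions for illustration, not code from this repository.

    def extend_buffer(self, state, other):
        # Sketch only: append a batch of transitions into pre-allocated tensors.
        # buf_state, buf_other, next_idx and max_len are assumed attribute names.
        size = state.shape[0]
        idx = self.next_idx
        if idx + size > self.max_len:   # crude overflow handling for the sketch;
            idx = 0                     # a real buffer would wrap the overflowing tail
        self.buf_state[idx:idx + size] = state  # shape (max_len, state_dim)
        self.buf_other[idx:idx + size] = other  # shape (max_len, 2 + action_dim): reward, mask, action
        self.next_idx = idx + size
        self.now_len = max(self.now_len, self.next_idx)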