This commit is contained in:
chenxiaodong 2024-07-30 13:31:13 +08:00
parent 16eccf95b6
commit 7c186de43d
1 changed files with 18 additions and 18 deletions

36
PPO.py
View File

@ -203,20 +203,6 @@ class AgentPPO:
buf_advantage = buf_r_sum - (buf_mask * buf_value[:, 0])
return buf_r_sum, buf_advantage
def get_reward_sum_gae(self, buf_len, ten_reward, ten_mask, ten_value) -> (torch.Tensor, torch.Tensor):
buf_r_sum = torch.empty(buf_len, dtype=torch.float32, device=self.device) # old policy value
buf_advantage = torch.empty(buf_len, dtype=torch.float32, device=self.device) # advantage value
pre_r_sum = 0
pre_advantage = 0 # advantage value of previous step
for i in range(buf_len - 1, -1, -1):
buf_r_sum[i] = ten_reward[i] + ten_mask[i] * gamma * pre_r_sum
pre_r_sum = buf_r_sum[i]
delta = ten_reward[i] + ten_mask[i] * gamma * ten_value[i + 1] - ten_value[i]
buf_advantage[i] = delta + ten_mask[i] * gamma * self.lambda_gae_adv * pre_advantage
pre_advantage = buf_advantage[i]
return buf_r_sum, buf_advantage
# def get_reward_sum_gae(self, buf_len, ten_reward, ten_mask, ten_value) -> (torch.Tensor, torch.Tensor):
# buf_r_sum = torch.empty(buf_len, dtype=torch.float32, device=self.device) # old policy value
# buf_advantage = torch.empty(buf_len, dtype=torch.float32, device=self.device) # advantage value
@ -224,12 +210,26 @@ class AgentPPO:
# pre_r_sum = 0
# pre_advantage = 0 # advantage value of previous step
# for i in range(buf_len - 1, -1, -1):
# buf_r_sum[i] = ten_reward[i] + ten_mask[i] * pre_r_sum
# buf_r_sum[i] = ten_reward[i] + ten_mask[i] * gamma * pre_r_sum
# pre_r_sum = buf_r_sum[i]
# buf_advantage[i] = ten_reward[i] + ten_mask[i] * (pre_advantage - ten_value[i]) # fix a bug here
# pre_advantage = ten_value[i] + buf_advantage[i] * self.lambda_gae_adv
# delta = ten_reward[i] + ten_mask[i] * gamma * ten_value[i + 1] - ten_value[i]
# buf_advantage[i] = delta + ten_mask[i] * gamma * self.lambda_gae_adv * pre_advantage
# pre_advantage = buf_advantage[i]
# return buf_r_sum, buf_advantage
def get_reward_sum_gae(self, buf_len, ten_reward, ten_mask, ten_value) -> (torch.Tensor, torch.Tensor):
buf_r_sum = torch.empty(buf_len, dtype=torch.float32, device=self.device) # old policy value
buf_advantage = torch.empty(buf_len, dtype=torch.float32, device=self.device) # advantage value
pre_r_sum = 0.0
pre_advantage = 0.0 # advantage value of previous step
for i in range(buf_len - 1, -1, -1):
buf_r_sum[i] = ten_reward[i] + ten_mask[i] * pre_r_sum
pre_r_sum = buf_r_sum[i]
buf_advantage[i] = ten_reward[i] + ten_mask[i] * (pre_advantage - ten_value[i]) # fix a bug here
pre_advantage = ten_value[i] + buf_advantage[i] * self.lambda_gae_adv
return buf_r_sum, buf_advantage
@staticmethod
def optim_update(optimizer, objective):
optimizer.zero_grad()
@ -399,7 +399,7 @@ if __name__ == '__main__':
from plotDRL import PlotArgs, make_dir, plot_evaluation_information, plot_optimization_result
plot_args = PlotArgs()
plot_args.feature_change = 'gae_solar'
plot_args.feature_change = 'gae'
args.cwd = agent_name
plot_dir = make_dir(args.cwd, plot_args.feature_change)
plot_optimization_result(base_result, plot_dir)