Commit 7c186de43d ("meeting"), parent 16eccf95b6

Changed file: PPO.py (36 changed lines)
@@ -203,20 +203,6 @@ class AgentPPO:
         buf_advantage = buf_r_sum - (buf_mask * buf_value[:, 0])
         return buf_r_sum, buf_advantage
 
-    def get_reward_sum_gae(self, buf_len, ten_reward, ten_mask, ten_value) -> (torch.Tensor, torch.Tensor):
-        buf_r_sum = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # old policy value
-        buf_advantage = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # advantage value
-
-        pre_r_sum = 0
-        pre_advantage = 0  # advantage value of previous step
-        for i in range(buf_len - 1, -1, -1):
-            buf_r_sum[i] = ten_reward[i] + ten_mask[i] * gamma * pre_r_sum
-            pre_r_sum = buf_r_sum[i]
-            delta = ten_reward[i] + ten_mask[i] * gamma * ten_value[i + 1] - ten_value[i]
-            buf_advantage[i] = delta + ten_mask[i] * gamma * self.lambda_gae_adv * pre_advantage
-            pre_advantage = buf_advantage[i]
-        return buf_r_sum, buf_advantage
-
     # def get_reward_sum_gae(self, buf_len, ten_reward, ten_mask, ten_value) -> (torch.Tensor, torch.Tensor):
     #     buf_r_sum = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # old policy value
     #     buf_advantage = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # advantage value
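The method deleted in this hunk is the explicit-gamma GAE formulation: a single backward pass that accumulates the discounted return (used as the critic target) and the advantage via the TD error delta. Note that gamma is not a parameter of the method, so it presumably resolves to a name in an enclosing scope. Below is a minimal standalone sketch of the same computation, with gamma and lambda passed explicitly and assuming masks[i] = 1 - done[i] and a values tensor with one extra bootstrap entry; the name gae_reference is illustrative, not from the repository.

import torch

def gae_reference(rewards, masks, values, gamma=0.99, lam=0.95):
    # Hypothetical standalone version of the deleted method; gamma and lam are
    # explicit parameters here instead of outer-scope / attribute names.
    # rewards, masks: shape (T,); values: shape (T + 1,), where values[i + 1] is
    # the critic's bootstrap estimate for the state reached after step i.
    buf_len = rewards.shape[0]
    r_sum = torch.empty(buf_len, dtype=torch.float32)      # discounted return (critic target)
    advantage = torch.empty(buf_len, dtype=torch.float32)  # GAE advantage
    pre_r_sum = 0.0
    pre_advantage = 0.0
    for i in range(buf_len - 1, -1, -1):
        r_sum[i] = rewards[i] + masks[i] * gamma * pre_r_sum
        pre_r_sum = r_sum[i]
        # TD error, then the recursion A[i] = delta + gamma * lambda * A[i + 1],
        # cut off at episode boundaries by masks[i].
        delta = rewards[i] + masks[i] * gamma * values[i + 1] - values[i]
        advantage[i] = delta + masks[i] * gamma * lam * pre_advantage
        pre_advantage = advantage[i]
    return r_sum, advantage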
|
@@ -224,12 +210,26 @@ class AgentPPO:
     #     pre_r_sum = 0
     #     pre_advantage = 0  # advantage value of previous step
     #     for i in range(buf_len - 1, -1, -1):
-    #         buf_r_sum[i] = ten_reward[i] + ten_mask[i] * pre_r_sum
+    #         buf_r_sum[i] = ten_reward[i] + ten_mask[i] * gamma * pre_r_sum
     #         pre_r_sum = buf_r_sum[i]
-    #         buf_advantage[i] = ten_reward[i] + ten_mask[i] * (pre_advantage - ten_value[i])  # fix a bug here
-    #         pre_advantage = ten_value[i] + buf_advantage[i] * self.lambda_gae_adv
+    #         delta = ten_reward[i] + ten_mask[i] * gamma * ten_value[i + 1] - ten_value[i]
+    #         buf_advantage[i] = delta + ten_mask[i] * gamma * self.lambda_gae_adv * pre_advantage
+    #         pre_advantage = buf_advantage[i]
     #     return buf_r_sum, buf_advantage
 
+    def get_reward_sum_gae(self, buf_len, ten_reward, ten_mask, ten_value) -> (torch.Tensor, torch.Tensor):
+        buf_r_sum = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # old policy value
+        buf_advantage = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # advantage value
+
+        pre_r_sum = 0.0
+        pre_advantage = 0.0  # advantage value of previous step
+        for i in range(buf_len - 1, -1, -1):
+            buf_r_sum[i] = ten_reward[i] + ten_mask[i] * pre_r_sum
+            pre_r_sum = buf_r_sum[i]
+            buf_advantage[i] = ten_reward[i] + ten_mask[i] * (pre_advantage - ten_value[i])  # fix a bug here
+            pre_advantage = ten_value[i] + buf_advantage[i] * self.lambda_gae_adv
+        return buf_r_sum, buf_advantage
+
     @staticmethod
     def optim_update(optimizer, objective):
         optimizer.zero_grad()
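This hunk swaps the two variants: the compact recursion becomes the active get_reward_sum_gae, while the explicit-gamma version is preserved in the comment block (updated to match the code removed above). The compact form keeps gamma out of the method body, which suggests the discount is expected to be folded into ten_mask, for example ten_mask[i] = gamma * (1 - done[i]); that mask convention is an assumption here, not something the diff states. A toy run of the reinstated recursion under that assumption:

import torch

# Toy run of the reinstated recursion, assuming ten_mask[i] = gamma * (1 - done[i]).
# All names below (gamma, lam, dones, ...) are illustrative, not taken from the diff.
gamma, lam = 0.99, 0.95
rewards = torch.tensor([1.0, 0.5, 2.0, 1.5])
dones = torch.tensor([0.0, 0.0, 0.0, 1.0])
values = torch.tensor([0.8, 0.7, 1.1, 0.9])  # critic estimates for the visited states
masks = gamma * (1.0 - dones)                # discount folded into the mask

buf_len = rewards.shape[0]
r_sum = torch.empty(buf_len)
advantage = torch.empty(buf_len)
pre_r_sum = 0.0
pre_advantage = 0.0
for i in range(buf_len - 1, -1, -1):
    r_sum[i] = rewards[i] + masks[i] * pre_r_sum                        # discounted return
    pre_r_sum = r_sum[i]
    advantage[i] = rewards[i] + masks[i] * (pre_advantage - values[i])  # compact advantage recursion
    pre_advantage = values[i] + lam * advantage[i]

print(r_sum)      # targets for the critic
print(advantage)  # advantages for the policy update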
@@ -399,7 +399,7 @@ if __name__ == '__main__':
     from plotDRL import PlotArgs, make_dir, plot_evaluation_information, plot_optimization_result
 
     plot_args = PlotArgs()
-    plot_args.feature_change = 'gae_solar'
+    plot_args.feature_change = 'gae'
     args.cwd = agent_name
     plot_dir = make_dir(args.cwd, plot_args.feature_change)
     plot_optimization_result(base_result, plot_dir)