This commit is contained in:
chenxiaodong 2024-06-20 09:19:22 +08:00
parent 69fe33deec
commit 21be0966dc
1 changed files with 3 additions and 3 deletions

6
PPO.py
View File

@ -106,14 +106,14 @@ class AgentPPO:
self.get_reward_sum = None # self.get_reward_sum_gae if if_use_gae else self.get_reward_sum_raw self.get_reward_sum = None # self.get_reward_sum_gae if if_use_gae else self.get_reward_sum_raw
self.trajectory_list = None self.trajectory_list = None
def init(self, net_dim, state_dim, action_dim, learning_rate=1e-4, if_use_gae=False, gpu_id=0): def init(self, net_dim, state_dim, action_dim, learning_rate=1e-4, if_use_gae=False, gpu_id=0, layer_norm=False):
self.device = torch.device(f"cuda:{gpu_id}" if (torch.cuda.is_available() and (gpu_id >= 0)) else "cpu") self.device = torch.device(f"cuda:{gpu_id}" if (torch.cuda.is_available() and (gpu_id >= 0)) else "cpu")
self.trajectory_list = list() self.trajectory_list = list()
# choose whether to use gae or not # choose whether to use gae or not
self.get_reward_sum = self.get_reward_sum_gae if if_use_gae else self.get_reward_sum_raw self.get_reward_sum = self.get_reward_sum_gae if if_use_gae else self.get_reward_sum_raw
self.cri = self.ClassCri(net_dim, state_dim, action_dim).to(self.device) self.cri = self.ClassCri(net_dim, state_dim, action_dim, layer_norm).to(self.device)
self.act = self.ClassAct(net_dim, state_dim, action_dim).to(self.device) if self.ClassAct else self.cri self.act = self.ClassAct(net_dim, state_dim, action_dim, layer_norm).to(self.device) if self.ClassAct else self.cri
self.cri_target = deepcopy(self.cri) if self.if_use_cri_target else self.cri self.cri_target = deepcopy(self.cri) if self.if_use_cri_target else self.cri
self.act_target = deepcopy(self.act) if self.if_use_act_target else self.act self.act_target = deepcopy(self.act) if self.if_use_act_target else self.act