import numpy as np
import torch
import torch.nn as nn


class Actor(nn.Module):
    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                 nn.Linear(mid_dim, action_dim))

    def forward(self, state):
        return self.net(state).tanh()  # action.tanh()

    def get_action(self, state, action_std):
        action = self.net(state).tanh()
        noise = (torch.randn_like(action) * action_std).clamp(-0.5, 0.5)
        return (action + noise).clamp(-1.0, 1.0)


class ActorSAC(nn.Module):
    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net_state = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                       nn.Linear(mid_dim, mid_dim), nn.ReLU())
        self.net_a_avg = nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                       nn.Linear(mid_dim, action_dim))  # the average of action
        self.net_a_std = nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                       nn.Linear(mid_dim, action_dim))  # the log_std of action
        self.log_sqrt_2pi = np.log(np.sqrt(2 * np.pi))

    def forward(self, state):
        tmp = self.net_state(state)
        return self.net_a_avg(tmp).tanh()  # action

    def get_action(self, state):
        t_tmp = self.net_state(state)
        a_avg = self.net_a_avg(t_tmp)  # NOTICE! it is a_avg without .tanh()
        a_std = self.net_a_std(t_tmp).clamp(-20, 2).exp()
        return torch.normal(a_avg, a_std).tanh()  # re-parameterize

    def get_action_logprob(self, state):
        t_tmp = self.net_state(state)
        a_avg = self.net_a_avg(t_tmp)  # NOTICE! it is a_avg without .tanh()
        a_std_log = self.net_a_std(t_tmp).clamp(-20, 2)
        a_std = a_std_log.exp()

        noise = torch.randn_like(a_avg, requires_grad=True)
        a_tan = (a_avg + a_std * noise).tanh()  # action.tanh()

        log_prob = a_std_log + self.log_sqrt_2pi + noise.pow(2) * 0.5  # Gaussian term
        log_prob = log_prob + (-a_tan.pow(2) + 1.000001).log()  # fix log_prob using the derivative of action.tanh()
        return a_tan, log_prob.sum(1, keepdim=True)


class ActorPPO(nn.Module):
    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                 nn.Linear(mid_dim, action_dim))

        # the logarithm (log) of the standard deviation (std) of the action; it is a trainable parameter
        self.a_std_log = nn.Parameter(torch.zeros((1, action_dim)) - 0.5, requires_grad=True)
        self.sqrt_2pi_log = np.log(np.sqrt(2 * np.pi))

    def forward(self, state):
        return self.net(state).tanh()  # action.tanh() limits the action output to (-1, 1)

    def get_action(self, state):
        a_avg = self.net(state)  # mean
        a_std = self.a_std_log.exp()  # standard deviation
        noise = torch.randn_like(a_avg)
        action = a_avg + noise * a_std
        return action, noise

    def get_logprob_entropy(self, state, action):
        a_avg = self.net(state)
        a_std = self.a_std_log.exp()

        delta = ((a_avg - action) / a_std).pow(2) * 0.5
        logprob = -(self.a_std_log + self.sqrt_2pi_log + delta).sum(1)  # new_logprob

        dist_entropy = (logprob.exp() * logprob).mean()  # policy entropy
        return logprob, dist_entropy

    def get_old_logprob(self, _action, noise):  # noise = (action - a_avg) / a_std
        delta = noise.pow(2) * 0.5
        return -(self.a_std_log + self.sqrt_2pi_log + delta).sum(1)  # old_logprob
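
# Illustrative sketch (not from the original file): a minimal check of how the PPO
# actor's sampling and log-prob methods fit together. The network sizes, the batch
# size, and the helper name `demo_actor_ppo` are arbitrary assumptions for this example.
def demo_actor_ppo():
    actor = ActorPPO(mid_dim=64, state_dim=8, action_dim=2)
    state = torch.rand(4, 8)  # batch of 4 states

    action, noise = actor.get_action(state)  # sample an action during rollout, keep the noise
    old_logprob = actor.get_old_logprob(action, noise)  # log-prob recovered from the stored noise
    new_logprob, dist_entropy = actor.get_logprob_entropy(state, action)  # log-prob recomputed from (state, action)

    # Before any gradient update the two log-probs coincide, so the PPO ratio is 1.
    assert torch.allclose(old_logprob, new_logprob, atol=1e-4)
    return old_logprob, new_logprob, dist_entropy
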
class Critic(nn.Module):
    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim + action_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                 nn.Linear(mid_dim, 1))

    def forward(self, state, action):
        return self.net(torch.cat((state, action), dim=1))  # q value


class CriticAdv(nn.Module):
    def __init__(self, mid_dim, state_dim, _action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                 nn.Linear(mid_dim, 1))

    def forward(self, state):
        return self.net(state)  # advantage value


class CriticTwin(nn.Module):  # shared parameter
    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net_sa = nn.Sequential(nn.Linear(state_dim + action_dim, mid_dim), nn.ReLU(),
                                    nn.Linear(mid_dim, mid_dim), nn.ReLU())  # concat(state, action)
        self.net_q1 = nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                    nn.Linear(mid_dim, 1))  # q1 value
        self.net_q2 = nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                    nn.Linear(mid_dim, 1))  # q2 value

    def forward(self, state, action):
        tmp = self.net_sa(torch.cat((state, action), dim=1))
        return self.net_q1(tmp)  # one Q value

    def get_q1_q2(self, state, action):
        tmp = self.net_sa(torch.cat((state, action), dim=1))
        return self.net_q1(tmp), self.net_q2(tmp)  # two Q values
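
# Illustrative sketch (not from the original file): how the deterministic Actor and
# CriticTwin are typically combined to build a clipped double-Q (TD3-style) target.
# The dimensions, the exploration std, the discount mask, and the helper name
# `demo_twin_q_target` are assumptions for this example.
def demo_twin_q_target():
    state_dim, action_dim, mid_dim = 8, 2, 64
    act_target = Actor(mid_dim, state_dim, action_dim)
    cri_target = CriticTwin(mid_dim, state_dim, action_dim)

    next_state = torch.rand(4, state_dim)
    reward = torch.rand(4, 1)
    mask = torch.ones(4, 1) * 0.99  # gamma * (1 - done)

    with torch.no_grad():
        next_action = act_target.get_action(next_state, action_std=0.2)  # noisy target policy
        next_q = torch.min(*cri_target.get_q1_q2(next_state, next_action))  # clipped double-Q
        q_label = reward + mask * next_q  # regression target for both critics
    return q_label
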