137 lines
5.9 KiB
Python
137 lines
5.9 KiB
Python
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Actor(nn.Module):
    """Deterministic policy network whose output is bounded to [-1, 1] by tanh.

    ``net`` maps a state of size ``state_dim`` through three hidden layers of
    width ``mid_dim`` to ``action_dim`` pre-tanh values.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        # The layer order fixes the ``net.<index>`` state_dict keys; keep it
        # stable so existing checkpoints still load.
        hidden = [
            nn.Linear(state_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
        ]
        head = nn.Linear(mid_dim, action_dim)
        self.net = nn.Sequential(*hidden, head)

    def forward(self, state):
        """Return the deterministic action ``tanh(net(state))``."""
        return torch.tanh(self.net(state))

    def get_action(self, state, action_std):
        """Return the policy action perturbed by clipped Gaussian noise.

        The noise is ``randn * action_std`` clipped to [-0.5, 0.5]. The noisy
        action is then clipped back into the valid range [-1, 1].
        """
        mean_action = torch.tanh(self.net(state))
        raw_noise = torch.randn_like(mean_action) * action_std
        clipped_noise = torch.clamp(raw_noise, -0.5, 0.5)
        return torch.clamp(mean_action + clipped_noise, -1.0, 1.0)
|
|
|
|
|
|
class ActorSAC(nn.Module):
    """Stochastic SAC policy: a tanh-squashed Gaussian over actions.

    A shared trunk ``net_state`` feeds two heads:
    ``net_a_avg`` predicts the pre-tanh mean of the action and
    ``net_a_std`` predicts the pre-tanh log standard deviation.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net_state = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                       nn.Linear(mid_dim, mid_dim), nn.ReLU(), )
        self.net_a_avg = nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                       nn.Linear(mid_dim, action_dim))  # pre-tanh mean of action
        self.net_a_std = nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                       nn.Linear(mid_dim, action_dim))  # pre-tanh log_std of action
        # log(sqrt(2*pi)): constant term of the Gaussian log-density.
        self.log_sqrt_2pi = np.log(np.sqrt(2 * np.pi))

    def forward(self, state):
        """Return the deterministic action ``tanh(mean)`` (no sampling)."""
        tmp = self.net_state(state)
        return self.net_a_avg(tmp).tanh()

    def get_action(self, state):
        """Sample an exploratory action ``tanh(N(mean, std**2))``.

        The predicted log std is clamped to [-20, 2] before exponentiation.
        """
        t_tmp = self.net_state(state)
        a_avg = self.net_a_avg(t_tmp)  # NOTICE! pre-tanh mean; tanh is applied after sampling
        a_std = self.net_a_std(t_tmp).clamp(-20, 2).exp()
        return torch.normal(a_avg, a_std).tanh()

    def get_action_logprob(self, state):
        """Sample a reparameterized action and its NEGATIVE log-probability.

        Returns:
            a_tan: ``tanh(mean + std * noise)`` with ``noise ~ N(0, 1)``.
                Gradients reach the network through ``mean`` and ``std``
                (reparameterization trick).
            log_prob: despite its name, ``-log pi(a_tan | state)`` of the
                tanh-squashed Gaussian, summed over dim 1 with ``keepdim=True``.
                It includes a 1e-6 offset inside the tanh Jacobian term.
                The name and sign are kept for backward compatibility; presumably
                the SAC agent relies on this sign. Verify before changing it.
        """
        t_tmp = self.net_state(state)
        a_avg = self.net_a_avg(t_tmp)  # NOTICE! pre-tanh mean; tanh is applied below
        a_std_log = self.net_a_std(t_tmp).clamp(-20, 2)
        a_std = a_std_log.exp()

        # The noise is a fixed sample of the reparameterization. It needs no
        # gradient of its own; parameter gradients flow through a_avg and a_std.
        noise = torch.randn_like(a_avg)
        a_tan = (a_avg + a_std * noise).tanh()

        # -log N(u; mean, std) per dimension, where u = mean + std * noise.
        log_prob = a_std_log + self.log_sqrt_2pi + noise.pow(2) * 0.5
        # Change of variables for a = tanh(u): -log pi(a) = -log N(u) + log(1 - a^2).
        # The 1e-6 offset keeps the log finite when tanh saturates at +/-1.
        log_prob = log_prob + (-a_tan.pow(2) + 1.000001).log()
        return a_tan, log_prob.sum(1, keepdim=True)
|
|
|
|
|
|
class ActorPPO(nn.Module):
    """Gaussian policy for PPO with a state-independent, learnable log std.

    ``net`` produces the action mean, and tanh is applied only in ``forward``.
    ``a_std_log`` is one trainable row of log standard deviations, shape
    ``(1, action_dim)``, shared across all states.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        layers = [
            nn.Linear(state_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
            nn.Linear(mid_dim, action_dim),
        ]
        self.net = nn.Sequential(*layers)

        # Trainable log std, initialised to -0.5 for every action dimension.
        init_log_std = torch.zeros((1, action_dim)) - 0.5
        self.a_std_log = nn.Parameter(init_log_std, requires_grad=True)
        # log(sqrt(2*pi)): constant term of the Gaussian log-density.
        self.sqrt_2pi_log = np.log(np.sqrt(2 * np.pi))

    def forward(self, state):
        """Return the deterministic action ``tanh(mean)``, bounded to [-1, 1]."""
        return torch.tanh(self.net(state))

    def get_action(self, state):
        """Sample an unsquashed action ``mean + noise * std``.

        Returns:
            ``(action, noise)``. ``noise`` is the standard-normal draw that
            ``get_old_logprob`` consumes. ``action`` is not passed through tanh.
        """
        mean = self.net(state)
        std = torch.exp(self.a_std_log)
        noise = torch.randn_like(mean)
        return mean + noise * std, noise

    def get_logprob_entropy(self, state, action):
        """Return the Gaussian log-density of ``action`` and an entropy term.

        ``logprob`` is summed over dim 1. ``dist_entropy`` is the batch mean of
        ``exp(logprob) * logprob``. That has the sign of a negative entropy, so
        presumably the caller minimizes it to encourage exploration (confirm).
        """
        mean = self.net(state)
        std = torch.exp(self.a_std_log)

        half_sq_z = ((mean - action) / std).pow(2) * 0.5
        logprob = -(self.a_std_log + self.sqrt_2pi_log + half_sq_z).sum(1)

        dist_entropy = (torch.exp(logprob) * logprob).mean()
        return logprob, dist_entropy

    def get_old_logprob(self, _action, noise):
        """Return the log-density of a sampled action from its standard-normal ``noise``.

        Because ``action = mean + noise * std``, the standardized residual is
        exactly ``noise``, so ``_action`` itself is not needed.
        """
        half_sq_noise = noise.pow(2) * 0.5
        return -(self.a_std_log + self.sqrt_2pi_log + half_sq_noise).sum(1)
|
|
|
|
|
|
class Critic(nn.Module):
    """Q-network that scores a (state, action) pair with a single scalar."""

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        in_dim = state_dim + action_dim  # state and action are concatenated on input
        self.net = nn.Sequential(
            nn.Linear(in_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
            nn.Linear(mid_dim, 1),
        )

    def forward(self, state, action):
        """Return the Q value of ``state`` and ``action`` concatenated along dim 1."""
        state_action = torch.cat((state, action), dim=1)
        return self.net(state_action)
|
|
|
|
|
|
class CriticAdv(nn.Module):
    """State-value network returning one scalar per state.

    ``_action_dim`` is accepted only so the constructor signature matches the
    other critics; it is not used.
    """

    def __init__(self, mid_dim, state_dim, _action_dim):
        super().__init__()
        layers = [
            nn.Linear(state_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
            nn.Linear(mid_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the network's scalar value estimate for ``state``."""
        value = self.net(state)
        return value
|
|
|
|
|
|
class CriticTwin(nn.Module):  # shared parameter
    """Twin Q-network: two Q heads on top of one shared (state, action) encoder."""

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()

        def make_q_head():
            # One Q head: hidden Hardswish layer followed by a scalar output.
            return nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                 nn.Linear(mid_dim, 1))

        # Shared trunk over concat(state, action).
        self.net_sa = nn.Sequential(
            nn.Linear(state_dim + action_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.ReLU(),
        )
        # Heads are created q1 first, then q2, so seeded initialization is stable.
        self.net_q1 = make_q_head()
        self.net_q2 = make_q_head()

    def forward(self, state, action):
        """Return only the first head's Q value."""
        encoded = self.net_sa(torch.cat((state, action), dim=1))
        return self.net_q1(encoded)

    def get_q1_q2(self, state, action):
        """Return ``(q1, q2)``, both computed from the same shared encoding."""
        encoded = self.net_sa(torch.cat((state, action), dim=1))
        q1 = self.net_q1(encoded)
        q2 = self.net_q2(encoded)
        return q1, q2
|