building-agents/net.py

137 lines
5.9 KiB
Python

import numpy as np
import torch
import torch.nn as nn
class Actor(nn.Module):
    """Deterministic policy network mapping a state to an action in [-1, 1].

    Args:
        mid_dim: width of every hidden layer.
        state_dim: size of the state vector fed to the network.
        action_dim: number of action components produced.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        # Three hidden layers (ReLU, ReLU, Hardswish); the final Linear emits
        # the raw, pre-tanh action.
        layers = [
            nn.Linear(state_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
            nn.Linear(mid_dim, action_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the deterministic action, squashed into [-1, 1] by tanh."""
        return torch.tanh(self.net(state))

    def get_action(self, state, action_std):
        """Return an exploratory action: tanh output plus clipped Gaussian noise.

        The noise has standard deviation ``action_std`` and is clipped to
        [-0.5, 0.5] before being added; the sum is clipped back to [-1, 1].
        """
        deterministic = torch.tanh(self.net(state))
        jitter = torch.randn_like(deterministic).mul(action_std).clamp(-0.5, 0.5)
        return torch.clamp(deterministic + jitter, -1.0, 1.0)
class ActorSAC(nn.Module):
    """Stochastic tanh-squashed Gaussian policy (as used by SAC).

    A shared trunk (``net_state``) feeds two heads: ``net_a_avg`` produces the
    pre-tanh mean and ``net_a_std`` produces the pre-tanh log standard
    deviation of the action distribution.

    Args:
        mid_dim: width of every hidden layer.
        state_dim: size of the state vector fed to the network.
        action_dim: number of action components produced.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net_state = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                       nn.Linear(mid_dim, mid_dim), nn.ReLU(), )
        self.net_a_avg = nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                       nn.Linear(mid_dim, action_dim))  # mean of the pre-tanh action
        self.net_a_std = nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                       nn.Linear(mid_dim, action_dim))  # log_std of the pre-tanh action
        # Constant term log(sqrt(2*pi)) of the Gaussian log-density.
        self.log_sqrt_2pi = np.log(np.sqrt(2 * np.pi))

    def forward(self, state):
        """Return the deterministic action tanh(mean), in [-1, 1]."""
        tmp = self.net_state(state)
        return self.net_a_avg(tmp).tanh()

    def get_action(self, state):
        """Sample an exploratory action tanh(u) with u ~ N(mean, std).

        This is a plain sample via ``torch.normal`` (not written as
        mean + std * eps), so it is meant for acting; use
        ``get_action_logprob`` when gradients through the sample are needed.
        """
        t_tmp = self.net_state(state)
        a_avg = self.net_a_avg(t_tmp)  # pre-tanh mean; tanh is applied after sampling
        a_std = self.net_a_std(t_tmp).clamp(-20, 2).exp()  # log_std kept in [-20, 2]
        return torch.normal(a_avg, a_std).tanh()

    def get_action_logprob(self, state):
        """Sample a reparameterized action together with its negated log-density.

        Returns:
            a_tan: ``tanh(mean + std * noise)`` with ``noise ~ N(0, 1)``.
            log_prob: sum over dim 1 (kept as a size-1 dim) of
                ``log_std + log(sqrt(2*pi)) + noise**2 / 2 + log(1 - a_tan**2)``.
                By the tanh change of variables this equals MINUS the
                log-density of ``a_tan`` (a one-sample entropy estimate),
                despite the name. NOTE(review): callers presumably rely on this
                sign convention — confirm before changing it.
        """
        t_tmp = self.net_state(state)
        a_avg = self.net_a_avg(t_tmp)  # pre-tanh mean; squashed below
        a_std_log = self.net_a_std(t_tmp).clamp(-20, 2)  # keep std within [e^-20, e^2]
        a_std = a_std_log.exp()

        # Reparameterization: gradients reach the parameters through a_avg and
        # a_std only. The noise is a constant sample, so it must not be an
        # autograd leaf (requires_grad=True would compute and store a gradient
        # for it on every backward pass, for nothing).
        noise = torch.randn_like(a_avg)
        a_tan = (a_avg + a_std * noise).tanh()

        # -log N(u; mean, std) per dimension, using (u - mean) / std == noise.
        log_prob = a_std_log + self.log_sqrt_2pi + noise.pow(2) * 0.5
        # tanh change-of-variables correction; the extra 1e-6 keeps log()
        # finite when |a_tan| saturates at 1.
        log_prob = log_prob + (-a_tan.pow(2) + 1.000001).log()
        return a_tan, log_prob.sum(1, keepdim=True)
class ActorPPO(nn.Module):
    """Gaussian policy for PPO with a state-independent, trainable log-std.

    ``net`` outputs the (unsquashed) action mean. ``a_std_log`` holds one log
    standard deviation per action dimension, shape (1, action_dim),
    initialised to -0.5.

    Args:
        mid_dim: width of every hidden layer.
        state_dim: size of the state vector fed to the network.
        action_dim: number of action components produced.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        hidden = [nn.Linear(state_dim, mid_dim), nn.ReLU(),
                  nn.Linear(mid_dim, mid_dim), nn.ReLU(),
                  nn.Linear(mid_dim, mid_dim), nn.Hardswish()]
        self.net = nn.Sequential(*hidden, nn.Linear(mid_dim, action_dim))
        # Trainable log standard deviation, shared across all states.
        self.a_std_log = nn.Parameter(torch.ones((1, action_dim)) * -0.5, requires_grad=True)
        # Constant term log(sqrt(2*pi)) of the Gaussian log-density.
        self.sqrt_2pi_log = np.log(np.sqrt(2 * np.pi))

    def forward(self, state):
        """Return the deterministic action tanh(mean), bounded to [-1, 1]."""
        return torch.tanh(self.net(state))

    def get_action(self, state):
        """Sample an unsquashed action ``mean + std * noise``.

        Returns ``(action, noise)``. The noise is returned so the old
        log-probability can later be computed by ``get_old_logprob``.
        No tanh is applied here.
        """
        mean = self.net(state)
        eps = torch.randn_like(mean)
        return mean + eps * self.a_std_log.exp(), eps

    def get_logprob_entropy(self, state, action):
        """Return the log-probability of ``action`` and an entropy-related term.

        The log-probability is summed over dim 1 (one value per row). The
        second value is ``mean(exp(logprob) * logprob)``.
        """
        mean = self.net(state)
        z = (mean - action) / self.a_std_log.exp()
        logprob = -(self.a_std_log + self.sqrt_2pi_log + z.pow(2) * 0.5).sum(1)
        # NOTE(review): exp(lp) * lp is the integrand of the negative entropy of
        # a discrete distribution; for this Gaussian it is not an entropy
        # estimate (the analytic entropy depends only on a_std_log). Presumably
        # callers are tuned around this value — confirm before changing it.
        return logprob, (logprob.exp() * logprob).mean()

    def get_old_logprob(self, _action, noise):
        """Return the log-probability of an action from its sampling noise.

        Because ``action == mean + std * noise`` (see ``get_action``), the
        standardized residual is exactly ``noise``; ``_action`` is unused.
        """
        half_sq = noise.pow(2) * 0.5
        return -(self.a_std_log + self.sqrt_2pi_log + half_sq).sum(1)
class Critic(nn.Module):
    """Q-network scoring a (state, action) pair with a single scalar.

    Args:
        mid_dim: width of every hidden layer.
        state_dim: size of the state vector.
        action_dim: size of the action vector.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        in_dim = state_dim + action_dim  # state and action are concatenated
        self.net = nn.Sequential(
            nn.Linear(in_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
            nn.Linear(mid_dim, 1),
        )

    def forward(self, state, action):
        """Return the Q value per row; state and action are joined along dim 1."""
        state_action = torch.cat([state, action], dim=1)
        return self.net(state_action)
class CriticAdv(nn.Module):
    """State-only critic mapping each state to a single scalar.

    ``_action_dim`` is accepted (presumably to share the constructor signature
    of the other critics) but is ignored.
    """

    def __init__(self, mid_dim, state_dim, _action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.Hardswish(),
            nn.Linear(mid_dim, 1),
        )

    def forward(self, state):
        """Return one scalar per state row.

        NOTE(review): originally labelled an "advantage value". Since only the
        state is seen, this is presumably the state value V(s) used to compute
        advantages — confirm against the agent code.
        """
        value = self.net(state)
        return value
class CriticTwin(nn.Module):  # shared parameter
    """Twin Q-network: two Q heads on top of one shared (state, action) trunk.

    Args:
        mid_dim: width of every hidden layer.
        state_dim: size of the state vector.
        action_dim: size of the action vector.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        # Shared trunk over concat(state, action).
        self.net_sa = nn.Sequential(nn.Linear(state_dim + action_dim, mid_dim), nn.ReLU(),
                                    nn.Linear(mid_dim, mid_dim), nn.ReLU())
        self.net_q1 = self._make_head(mid_dim)  # q1 value
        self.net_q2 = self._make_head(mid_dim)  # q2 value

    @staticmethod
    def _make_head(mid_dim):
        """Build one Q head: Linear -> Hardswish -> Linear(mid_dim, 1)."""
        return nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                             nn.Linear(mid_dim, 1))

    def _trunk(self, state, action):
        """Encode the (state, action) pair, joined along dim 1, with the shared layers."""
        return self.net_sa(torch.cat((state, action), dim=1))

    def forward(self, state, action):
        """Return only the first head's Q value."""
        return self.net_q1(self._trunk(state, action))

    def get_q1_q2(self, state, action):
        """Return both heads' Q values, computed from a single trunk pass."""
        features = self._trunk(state, action)
        return self.net_q1(features), self.net_q2(features)