import numpy as np
import torch
import torch.nn as nn


class ActorPPO(nn.Module):
    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                 nn.Linear(mid_dim, action_dim))

        # log of the standard deviation (std) of the action distribution; a trainable parameter
        self.a_logstd = nn.Parameter(torch.zeros((1, action_dim)) - 0.5, requires_grad=True)
        self.sqrt_2pi_log = np.log(np.sqrt(2 * np.pi))  # constant term log(sqrt(2*pi)) of the Gaussian log-density

    def forward(self, state):
        return self.net(state).tanh()  # tanh() limits the action output to the range (-1, 1)

    def get_action(self, state):
        a_avg = self.forward(state)  # mean of the action distribution (forward() already applies tanh)
        a_std = self.a_logstd.exp()

        noise = torch.randn_like(a_avg)
        action = a_avg + noise * a_std  # reparameterized sample: mean + noise * std
        return action, noise

    def get_logprob_entropy(self, state, action):
        a_avg = self.forward(state)
        a_std = self.a_logstd.exp()
        delta = ((a_avg - action) / a_std).pow(2) * 0.5
        # diagonal Gaussian log-density: -sum_i (log std_i + log sqrt(2*pi) + (a_i - mu_i)^2 / (2 * std_i^2))
        logprob = -(self.a_logstd + self.sqrt_2pi_log + delta).sum(1)  # new_logprob

        dist_entropy = (logprob.exp() * logprob).mean()  # policy entropy
        return logprob, dist_entropy

    def get_old_logprob(self, _action, noise):  # noise = (action - a_avg) / a_std, stored when the action was sampled
        delta = noise.pow(2) * 0.5
        return -(self.a_logstd + self.sqrt_2pi_log + delta).sum(1)  # old_logprob
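
# How these pieces typically fit into a PPO update (assuming the standard clipped
# surrogate objective): get_old_logprob converts the noise stored at rollout time
# into log pi_old(a|s), get_logprob_entropy re-evaluates log pi_new(a|s) under the
# current parameters, and the importance ratio is
#     ratio = (new_logprob - old_logprob).exp()
# which is then clipped and weighted by the advantage estimated with CriticAdv below.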


class CriticAdv(nn.Module):
    def __init__(self, mid_dim, state_dim, _action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.ReLU(),
                                 nn.Linear(mid_dim, mid_dim), nn.Hardswish(),
                                 nn.Linear(mid_dim, 1))

    def forward(self, state):
        return self.net(state)  # state-value estimate, used to compute the advantage
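

if __name__ == "__main__":
    # Minimal smoke test, illustrative only: the dimensions below are arbitrary choices,
    # not values prescribed by this file.
    mid_dim, state_dim, action_dim = 64, 8, 2
    actor = ActorPPO(mid_dim, state_dim, action_dim)
    critic = CriticAdv(mid_dim, state_dim, action_dim)

    state = torch.randn(4, state_dim)         # a batch of 4 states
    action, noise = actor.get_action(state)   # stochastic action and the noise that produced it
    value = critic(state)                     # state-value estimates, shape (4, 1)

    old_logprob = actor.get_old_logprob(action, noise)
    new_logprob, entropy = actor.get_logprob_entropy(state, action)

    # Immediately after sampling, the two log-prob paths agree (up to floating-point error).
    print(action.shape, value.shape)
    print(torch.allclose(old_logprob, new_logprob, atol=1e-5))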