modify ppo
parent 83b1abbf76, commit 56d3b34602

PPO.py (24 lines changed)
@@ -7,8 +7,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from environment import ESSEnv
-from tools import get_episode_return, test_one_episode, optimization_base_result
+from tools import get_episode_return, test_one_episode
+# from tools import optimization_base_result
 torch.random.manual_seed(42)

 class ActorPPO(nn.Module):
     def __init__(self, mid_dim, state_dim, action_dim):
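The Gurobi-backed helper is disabled above by commenting out its import. An optional-import guard is one way to achieve the same effect without hand-editing the line each time; this is only a sketch of that pattern, not code from the commit, and the HAS_GUROBI_BASELINE flag is invented here for illustration:

try:
    # only needed when comparing against the Gurobi baseline
    from tools import optimization_base_result
    HAS_GUROBI_BASELINE = True
except ImportError:
    # e.g. gurobipy not installed: run PPO training/evaluation without the baseline
    optimization_base_result = None
    HAS_GUROBI_BASELINE = False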
@@ -304,7 +305,7 @@ if __name__ == '__main__':
     '''init buffer'''
     buffer = list()
     '''init training parameters'''
-    args.train = False
+    args.train = True
     args.save_network = False
     # args.test_network = False
     # args.save_test_data = False
@@ -354,10 +355,11 @@ if __name__ == '__main__':

     '''compare with gurobi data and results'''
     if args.compare_with_gurobi:
-        month = record['init_info'][0][0]
-        day = record['init_info'][0][1]
-        initial_soc = record['init_info'][0][3]
-        base_result = optimization_base_result(env, month, day, initial_soc)
+        pass
+        # month = record['init_info'][0][0]
+        # day = record['init_info'][0][1]
+        # initial_soc = record['init_info'][0][3]
+        # base_result = optimization_base_result(env, month, day, initial_soc)
     if args.plot_on:
         from plotDRL import PlotArgs, make_dir, plot_evaluation_information, plot_optimization_result

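With the body replaced by pass, the comparison block now does nothing. If the baseline is brought back, a guarded version that keeps the script runnable when the helper is unavailable could look like the sketch below; it reuses only names visible in this diff (record, env, optimization_base_result) and assumes record['init_info'][0] keeps the month / day / initial-SOC layout shown above:

base_result = None
if args.compare_with_gurobi and optimization_base_result is not None:
    # unpack the evaluation episode's starting conditions
    month = record['init_info'][0][0]
    day = record['init_info'][0][1]
    initial_soc = record['init_info'][0][3]
    # solve the same episode with the Gurobi-based optimizer for comparison
    base_result = optimization_base_result(env, month, day, initial_soc)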
@@ -365,10 +367,10 @@ if __name__ == '__main__':
         plot_args.feature_change = '10'
         args.cwd = agent_name
         plot_dir = make_dir(args.cwd, plot_args.feature_change)
-        plot_optimization_result(base_result, plot_dir)
+        # plot_optimization_result(base_result, plot_dir)
         plot_evaluation_information(args.cwd + '/' + 'test_10.pkl', plot_dir)
     '''compare the different cost get from gurobi and PPO'''
-    ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
+    # ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
     print('rl_cost:', sum(eval_data['operation_cost']))
-    print('gurobi_cost:', sum(base_result['step_cost']))
-    print('ration:', ration)
+    # print('gurobi_cost:', sum(base_result['step_cost']))
+    # print('ration:', ration)
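The cost comparison is likewise commented out rather than guarded. A sketch of the same computation, gated on the baseline actually existing, and assuming eval_data['operation_cost'] and base_result['step_cost'] keep the meanings used above:

rl_cost = sum(eval_data['operation_cost'])
print('rl_cost:', rl_cost)
if base_result is not None:
    # only meaningful when the Gurobi baseline was solved for the same episode
    gurobi_cost = sum(base_result['step_cost'])
    print('gurobi_cost:', gurobi_cost)
    print('ratio (rl / gurobi):', rl_cost / gurobi_cost)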