modify ppo
This commit is contained in:
parent
83b1abbf76
commit
56d3b34602
24
PPO.py
24
PPO.py
|
@ -7,8 +7,9 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from environment import ESSEnv
|
||||
from tools import get_episode_return, test_one_episode, optimization_base_result
|
||||
|
||||
from tools import get_episode_return, test_one_episode
|
||||
# from tools import optimization_base_result
|
||||
torch.random.manual_seed(42)
|
||||
|
||||
class ActorPPO(nn.Module):
|
||||
def __init__(self, mid_dim, state_dim, action_dim):
|
||||
|
@ -304,7 +305,7 @@ if __name__ == '__main__':
|
|||
'''init buffer'''
|
||||
buffer = list()
|
||||
'''init training parameters'''
|
||||
args.train = False
|
||||
args.train = True
|
||||
args.save_network = False
|
||||
# args.test_network = False
|
||||
# args.save_test_data = False
|
||||
|
@ -354,10 +355,11 @@ if __name__ == '__main__':
|
|||
|
||||
'''compare with gurobi data and results'''
|
||||
if args.compare_with_gurobi:
|
||||
month = record['init_info'][0][0]
|
||||
day = record['init_info'][0][1]
|
||||
initial_soc = record['init_info'][0][3]
|
||||
base_result = optimization_base_result(env, month, day, initial_soc)
|
||||
pass
|
||||
# month = record['init_info'][0][0]
|
||||
# day = record['init_info'][0][1]
|
||||
# initial_soc = record['init_info'][0][3]
|
||||
# base_result = optimization_base_result(env, month, day, initial_soc)
|
||||
if args.plot_on:
|
||||
from plotDRL import PlotArgs, make_dir, plot_evaluation_information, plot_optimization_result
|
||||
|
||||
|
@ -365,10 +367,10 @@ if __name__ == '__main__':
|
|||
plot_args.feature_change = '10'
|
||||
args.cwd = agent_name
|
||||
plot_dir = make_dir(args.cwd, plot_args.feature_change)
|
||||
plot_optimization_result(base_result, plot_dir)
|
||||
# plot_optimization_result(base_result, plot_dir)
|
||||
plot_evaluation_information(args.cwd + '/' + 'test_10.pkl', plot_dir)
|
||||
'''compare the different cost get from gurobi and PPO'''
|
||||
ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
|
||||
# ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
|
||||
print('rl_cost:', sum(eval_data['operation_cost']))
|
||||
print('gurobi_cost:', sum(base_result['step_cost']))
|
||||
print('ration:', ration)
|
||||
# print('gurobi_cost:', sum(base_result['step_cost']))
|
||||
# print('ration:', ration)
|
||||
|
|
Loading…
Reference in New Issue