modify ppo

This commit is contained in:
default 2025-03-10 07:52:23 +00:00
parent 83b1abbf76
commit 56d3b34602
1 changed file with 13 additions and 11 deletions

24
PPO.py
View File

@ -7,8 +7,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from environment import ESSEnv
from tools import get_episode_return, test_one_episode, optimization_base_result
from tools import get_episode_return, test_one_episode
# from tools import optimization_base_result
torch.random.manual_seed(42)
class ActorPPO(nn.Module):
def __init__(self, mid_dim, state_dim, action_dim):
@ -304,7 +305,7 @@ if __name__ == '__main__':
'''init buffer'''
buffer = list()
'''init training parameters'''
args.train = False
args.train = True
args.save_network = False
# args.test_network = False
# args.save_test_data = False
@ -354,10 +355,11 @@ if __name__ == '__main__':
'''compare with gurobi data and results'''
if args.compare_with_gurobi:
month = record['init_info'][0][0]
day = record['init_info'][0][1]
initial_soc = record['init_info'][0][3]
base_result = optimization_base_result(env, month, day, initial_soc)
pass
# month = record['init_info'][0][0]
# day = record['init_info'][0][1]
# initial_soc = record['init_info'][0][3]
# base_result = optimization_base_result(env, month, day, initial_soc)
if args.plot_on:
from plotDRL import PlotArgs, make_dir, plot_evaluation_information, plot_optimization_result
@ -365,10 +367,10 @@ if __name__ == '__main__':
plot_args.feature_change = '10'
args.cwd = agent_name
plot_dir = make_dir(args.cwd, plot_args.feature_change)
plot_optimization_result(base_result, plot_dir)
# plot_optimization_result(base_result, plot_dir)
plot_evaluation_information(args.cwd + '/' + 'test_10.pkl', plot_dir)
'''compare the different cost get from gurobi and PPO'''
ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
# ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
print('rl_cost:', sum(eval_data['operation_cost']))
print('gurobi_cost:', sum(base_result['step_cost']))
print('ration:', ration)
# print('gurobi_cost:', sum(base_result['step_cost']))
# print('ration:', ration)