Modify PPO: seed the torch RNG, enable training, and disable the Gurobi baseline comparison

This commit is contained in:
default 2025-03-10 07:52:23 +00:00
parent 83b1abbf76
commit 56d3b34602
1 changed file with 13 additions and 11 deletions

24
PPO.py
View File

@ -7,8 +7,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from environment import ESSEnv from environment import ESSEnv
from tools import get_episode_return, test_one_episode, optimization_base_result from tools import get_episode_return, test_one_episode
# from tools import optimization_base_result
torch.random.manual_seed(42)
class ActorPPO(nn.Module): class ActorPPO(nn.Module):
def __init__(self, mid_dim, state_dim, action_dim): def __init__(self, mid_dim, state_dim, action_dim):
@ -304,7 +305,7 @@ if __name__ == '__main__':
'''init buffer''' '''init buffer'''
buffer = list() buffer = list()
'''init training parameters''' '''init training parameters'''
args.train = False args.train = True
args.save_network = False args.save_network = False
# args.test_network = False # args.test_network = False
# args.save_test_data = False # args.save_test_data = False
@ -354,10 +355,11 @@ if __name__ == '__main__':
'''compare with gurobi data and results''' '''compare with gurobi data and results'''
if args.compare_with_gurobi: if args.compare_with_gurobi:
month = record['init_info'][0][0] pass
day = record['init_info'][0][1] # month = record['init_info'][0][0]
initial_soc = record['init_info'][0][3] # day = record['init_info'][0][1]
base_result = optimization_base_result(env, month, day, initial_soc) # initial_soc = record['init_info'][0][3]
# base_result = optimization_base_result(env, month, day, initial_soc)
if args.plot_on: if args.plot_on:
from plotDRL import PlotArgs, make_dir, plot_evaluation_information, plot_optimization_result from plotDRL import PlotArgs, make_dir, plot_evaluation_information, plot_optimization_result
@ -365,10 +367,10 @@ if __name__ == '__main__':
plot_args.feature_change = '10' plot_args.feature_change = '10'
args.cwd = agent_name args.cwd = agent_name
plot_dir = make_dir(args.cwd, plot_args.feature_change) plot_dir = make_dir(args.cwd, plot_args.feature_change)
plot_optimization_result(base_result, plot_dir) # plot_optimization_result(base_result, plot_dir)
plot_evaluation_information(args.cwd + '/' + 'test_10.pkl', plot_dir) plot_evaluation_information(args.cwd + '/' + 'test_10.pkl', plot_dir)
'''compare the different cost get from gurobi and PPO''' '''compare the different cost get from gurobi and PPO'''
ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost']) # ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
print('rl_cost:', sum(eval_data['operation_cost'])) print('rl_cost:', sum(eval_data['operation_cost']))
print('gurobi_cost:', sum(base_result['step_cost'])) # print('gurobi_cost:', sum(base_result['step_cost']))
print('ration:', ration) # print('ration:', ration)