From 56d3b34602c2978e19e9c3f77ca28270d20efaa7 Mon Sep 17 00:00:00 2001 From: default Date: Mon, 10 Mar 2025 07:52:23 +0000 Subject: [PATCH] modify ppo --- PPO.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/PPO.py b/PPO.py index 682e78f..9bba0a7 100644 --- a/PPO.py +++ b/PPO.py @@ -7,8 +7,9 @@ import torch import torch.nn as nn import torch.nn.functional as F from environment import ESSEnv -from tools import get_episode_return, test_one_episode, optimization_base_result - +from tools import get_episode_return, test_one_episode +# from tools import optimization_base_result +torch.random.manual_seed(42) class ActorPPO(nn.Module): def __init__(self, mid_dim, state_dim, action_dim): @@ -304,7 +305,7 @@ if __name__ == '__main__': '''init buffer''' buffer = list() '''init training parameters''' - args.train = False + args.train = True args.save_network = False # args.test_network = False # args.save_test_data = False @@ -354,10 +355,11 @@ if __name__ == '__main__': '''compare with gurobi data and results''' if args.compare_with_gurobi: - month = record['init_info'][0][0] - day = record['init_info'][0][1] - initial_soc = record['init_info'][0][3] - base_result = optimization_base_result(env, month, day, initial_soc) + pass + # month = record['init_info'][0][0] + # day = record['init_info'][0][1] + # initial_soc = record['init_info'][0][3] + # base_result = optimization_base_result(env, month, day, initial_soc) if args.plot_on: from plotDRL import PlotArgs, make_dir, plot_evaluation_information, plot_optimization_result @@ -365,10 +367,10 @@ if __name__ == '__main__': plot_args.feature_change = '10' args.cwd = agent_name plot_dir = make_dir(args.cwd, plot_args.feature_change) - plot_optimization_result(base_result, plot_dir) + # plot_optimization_result(base_result, plot_dir) plot_evaluation_information(args.cwd + '/' + 'test_10.pkl', plot_dir) '''compare the different cost get from gurobi and PPO''' - ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost']) + # ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost']) print('rl_cost:', sum(eval_data['operation_cost'])) - print('gurobi_cost:', sum(base_result['step_cost'])) - print('ration:', ration) + # print('gurobi_cost:', sum(base_result['step_cost'])) + # print('ration:', ration)