refactor(train)

This commit is contained in:
chenxiaodong 2025-02-18 08:49:32 +08:00
parent d0ea09d663
commit a3ad4182f1
9 changed files with 27 additions and 30 deletions

BIN
data/actor.pth Normal file

Binary file not shown.

BIN
data/loss.pkl Normal file

Binary file not shown.

BIN
data/reward.pkl Normal file

Binary file not shown.

5
data/service_actions.csv Normal file
View File

@@ -0,0 +1,5 @@
time,action
1,[-0.85844654 -0.913628 ]
1,[-0.97137856 -0.9997079 ]
1,[-0.97137856 -0.9997079 ]
1,[-0.97137856 -0.9997079 ]

4
data/service_result.csv Normal file
View File

@@ -0,0 +1,4 @@
reward,unbalance
-0.09503999999999999,0.03
-0.09503999999999999,0.03
-0.09503999999999999,0.03

View File

@@ -1,7 +1,6 @@
import queue
import threading
import time
import torch
from train import *
@@ -18,7 +17,7 @@ def test_one_step(env, act, device, data, action_path):
with open(action_path, 'a') as af:
af.write(f'{env.current_time},{action}\n')
return reward, env.real_unbalance
return reward, env.unbalance
def run_service_test(env, agent, data):
@@ -63,14 +62,15 @@ def listener_thread(env, agent, data_queue):
def sender_thread(data_queue):
while True:
try:
user_input = input("请输入price, temper, solar, load, heat, people用逗号分隔: ")
time.sleep(0.5)
user_input = input("请输入当前时刻的price, temper, solar, load, heat, people用逗号分隔: \n")
# 将输入字符串分割并转换为浮点数列表
input_data = list(map(float, user_input.split(',')))
# 检查输入是否包含六个数值
if len(input_data) != 6:
print("输入数据格式不正确,请输入六个数值。")
print("输入格式不正确,请输入六个数值。")
continue
# 将数据放入队列
@@ -78,7 +78,7 @@ def sender_thread(data_queue):
data_queue.put(input_data)
except ValueError:
print("输入数据格式不正确,请输入数值。")
print("输入格式不正确,请输入数值。")
def main():
@@ -88,7 +88,6 @@ def main():
args.random_seed = seed
args.agent = AgentPPO()
args.agent.cri_target = True
agent_name = f'{args.agent.__class__.__name__}'
args.env = WgzGym()
args.init_before_training()
@@ -97,9 +96,7 @@
env.TRAIN = False
agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)
if args.service:
args.cwd = agent_name
act_save_path = f'{args.cwd}/actor.pth'
act_save_path = './data/actor.pth'
agent.act.load_state_dict(torch.load(act_save_path))
# Create a queue for inter-thread communication

View File

@@ -57,7 +57,7 @@ class WgzGym(gym.Env):
def _build_state(self):
hst_soc = self.HST.current_soc
ec_out = self.EC.get_hydrogen()
grid_ex = self.grid.trade_energy
# grid_ex = self.grid.trade_energy
time_step = self.current_time
if self.TRAIN:
@@ -121,9 +121,9 @@
economic_cost = hst_cost + ec_cost + solar_cost - sell_benefit + buy_cost
demand_cost = self.heat_a * heat_penalty + self.power_a * power_penalty
eco_benifit = self.EC.less_carbon() - self.grid.get_carbon(power_gap)
reward = - self.a * demand_cost - self.b * economic_cost + self.c * eco_benifit
reward = (- self.a * demand_cost - self.b * economic_cost + self.c * eco_benifit) / 1e3
self.unbalance = power_gap + heat_gap
self.unbalance = (power_gap + heat_gap) / 1e3
final_step_outputs = [self.HST.current_soc, self.HST.get_power(), self.EC.current_power]
self.current_time += 1
finish = (self.current_time == self.episode_length)

View File

@@ -83,15 +83,15 @@ class Grid:
self.carbon_increace = 0.9
# self.trade_energy = None
# def step(self, action_grid, ec_power_max):
# self.trade_energy = (action_grid + 1) / 2 * ec_power_max # de-normalize
def get_cost(self, price, trade_energy):
return price * trade_energy * self.delta
def get_carbon(self, trade_energy):
return trade_energy * self.carbon_increace
# def step(self, action_grid, ec_power_max):
# self.trade_energy = (action_grid + 1) / 2 * ec_power_max # de-normalize
def retrieve_past_price(self):
result = []
# Start and end indices for the past 24 hours of prices

View File

@@ -171,8 +171,6 @@ class Arguments:
def __init__(self, agent=None, env=None):
self.agent = agent
self.env = env
self.cwd = None # current work directory. None means set automatically
self.if_remove = False # remove the cwd folder? (True, False, None:ask me)
self.visible_gpu = '0' # os.environ['CUDA_VISIBLE_DEVICES'] = '0, 2,'
self.num_threads = 32 # cpu_num for evaluate model
@@ -194,14 +192,8 @@ class Arguments:
self.random_seed_list = [1234]
self.train = True
self.save_network = True
self.test_network = True
self.save_test_data = True
def init_before_training(self):
if self.cwd is None:
agent_name = self.agent.__class__.__name__
self.cwd = f'./{agent_name}'
np.random.seed(self.random_seed)
torch.manual_seed(self.random_seed)
torch.set_num_threads(self.num_threads)
@@ -217,7 +209,6 @@ if __name__ == '__main__':
for seed in args.random_seed_list:
args.random_seed = seed
args.agent = AgentPPO()
agent_name = f'{args.agent.__class__.__name__}'
args.agent.cri_target = True
args.env = WgzGym()
args.init_before_training()
@@ -226,9 +217,9 @@
env = args.env
agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)
gamma = args.gamma
batch_size = args.batch_size # data used to update net
target_step = args.target_step # steps of one episode should stop
repeat_times = args.repeat_times # times should update for one batch size data
batch_size = args.batch_size
target_step = args.target_step
repeat_times = args.repeat_times
soft_update_tau = args.soft_update_tau
num_episode = args.num_episode
agent.state = env.reset()
@@ -254,9 +245,9 @@
reward_record['unbalance'].append(episode_unbalance)
print(f'episode: {i_episode}, reward: {episode_reward}, unbalance: {episode_unbalance}')
act_save_path = f'{args.cwd}/actor.pth'
loss_record_path = f'{args.cwd}/loss.pkl'
reward_record_path = f'{args.cwd}/reward.pkl'
act_save_path = './data/actor.pth'
loss_record_path = './data/loss.pkl'
reward_record_path = './data/reward.pkl'
if args.save_network:
with open(loss_record_path, 'wb') as tf: