refactor(train)

commit a3ad4182f1
parent d0ea09d663

3 binary files changed (contents not shown).
(new CSV file; file name not shown)
@@ -0,0 +1,5 @@
+time,action
+1,[-0.85844654 -0.913628 ]
+1,[-0.97137856 -0.9997079 ]
+1,[-0.97137856 -0.9997079 ]
+1,[-0.97137856 -0.9997079 ]
(new CSV file; file name not shown)
@@ -0,0 +1,4 @@
+reward,unbalance
+-0.09503999999999999,0.03
+-0.09503999999999999,0.03
+-0.09503999999999999,0.03
inference.py (15 changes)
@@ -1,7 +1,6 @@
 import queue
 import threading
 import time

 import torch
-
 from train import *
@@ -18,7 +17,7 @@ def test_one_step(env, act, device, data, action_path):

     with open(action_path, 'a') as af:
         af.write(f'{env.current_time},{action}\n')
-    return reward, env.real_unbalance
+    return reward, env.unbalance


 def run_service_test(env, agent, data):
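Note: the af.write line above is what produces the time,action CSV added by this commit. A minimal standalone sketch of that logging pattern; the file name and the numpy action format are assumptions, not repo specifics:

    import numpy as np

    def log_action(action_path, current_time, action):
        # Appending str(ndarray) yields rows like "1,[-0.85844654 -0.913628 ]",
        # matching the CSV added in this commit.
        with open(action_path, 'a') as af:
            af.write(f'{current_time},{action}\n')

    log_action('action.csv', 1, np.array([-0.85844654, -0.913628], dtype=np.float32))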
@@ -63,14 +62,15 @@ def listener_thread(env, agent, data_queue):
 def sender_thread(data_queue):
     while True:
         try:
-            user_input = input("Please enter price, temper, solar, load, heat, people (comma-separated): ")
+            time.sleep(0.5)
+            user_input = input("Please enter the current step's price, temper, solar, load, heat, people (comma-separated): \n")

             # split the input string into a list of floats
             input_data = list(map(float, user_input.split(',')))

             # check that the input contains six values
             if len(input_data) != 6:
-                print("Input data format is incorrect; please enter six numeric values.")
+                print("Input format is incorrect; please enter six numeric values.")
                 continue

             # put the data on the queue
@@ -78,7 +78,7 @@ def sender_thread(data_queue):
             data_queue.put(input_data)

         except ValueError:
-            print("Input data format is incorrect; please enter numeric values.")
+            print("Input format is incorrect; please enter numeric values.")


 def main():
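Note: for context, a minimal sketch of the producer/consumer wiring that sender_thread and listener_thread imply: rows of six floats read from stdin flow to the inference loop through a thread-safe queue. The listener body here is a placeholder, not the repo's:

    import queue
    import threading

    def sender(data_queue):
        while True:
            try:
                raw = input("price, temper, solar, load, heat, people (comma-separated): ")
                values = list(map(float, raw.split(',')))
                if len(values) != 6:
                    print("Please enter exactly six numeric values.")
                    continue
                data_queue.put(values)  # hand the state row to the consumer
            except ValueError:
                print("Please enter numeric values.")

    def listener(data_queue):
        while True:
            state = data_queue.get()  # blocks until the sender provides a row
            print(f"received state: {state}")

    q = queue.Queue()
    threading.Thread(target=listener, args=(q,), daemon=True).start()
    sender(q)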
@@ -88,7 +88,6 @@ def main():
     args.random_seed = seed
     args.agent = AgentPPO()
     args.agent.cri_target = True
-    agent_name = f'{args.agent.__class__.__name__}'
     args.env = WgzGym()
     args.init_before_training()

@@ -97,9 +96,7 @@ def main():
     env.TRAIN = False
     agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)

-    if args.service:
-        args.cwd = agent_name
-        act_save_path = f'{args.cwd}/actor.pth'
+    act_save_path = './data/actor.pth'
     agent.act.load_state_dict(torch.load(act_save_path))

     # create a queue for inter-thread communication
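Note: the effect of the change above is that inference always loads weights from the fixed ./data/actor.pth instead of a per-agent working directory. A sketch of that load step, assuming the file exists (produced by train.py); the network shape is an assumption:

    import torch
    import torch.nn as nn

    actor = nn.Linear(4, 2)  # stand-in for agent.act; real dims come from the env
    state_dict = torch.load('./data/actor.pth', map_location='cpu')
    actor.load_state_dict(state_dict)
    actor.eval()  # inference only; no gradient updates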
(WgzGym environment file; file name not shown)
@@ -57,7 +57,7 @@ class WgzGym(gym.Env):
     def _build_state(self):
         hst_soc = self.HST.current_soc
         ec_out = self.EC.get_hydrogen()
-        grid_ex = self.grid.trade_energy
+        # grid_ex = self.grid.trade_energy
         time_step = self.current_time

         if self.TRAIN:
@@ -121,9 +121,9 @@ class WgzGym(gym.Env):
         economic_cost = hst_cost + ec_cost + solar_cost - sell_benefit + buy_cost
         demand_cost = self.heat_a * heat_penalty + self.power_a * power_penalty
         eco_benifit = self.EC.less_carbon() - self.grid.get_carbon(power_gap)
-        reward = - self.a * demand_cost - self.b * economic_cost + self.c * eco_benifit
+        reward = (- self.a * demand_cost - self.b * economic_cost + self.c * eco_benifit) / 1e3

-        self.unbalance = power_gap + heat_gap
+        self.unbalance = (power_gap + heat_gap) / 1e3
         final_step_outputs = [self.HST.current_soc, self.HST.get_power(), self.EC.current_power]
         self.current_time += 1
         finish = (self.current_time == self.episode_length)
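Note: both reward and unbalance are rescaled by 1e3 here, which keeps the PPO targets in a small numeric range. A quick illustration with made-up magnitudes; the coefficients and cost terms are assumptions, not repo values:

    a, b, c = 1.0, 1.0, 1.0
    demand_cost, economic_cost, eco_benifit = 2500.0, 1800.0, 300.0

    reward = (- a * demand_cost - b * economic_cost + c * eco_benifit) / 1e3
    unbalance = (120.0 + 80.0) / 1e3  # power_gap + heat_gap

    print(reward, unbalance)  # -4.0 0.2, instead of -4000 and 200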
(Grid module; file name not shown)
@@ -83,15 +83,15 @@ class Grid:
         self.carbon_increace = 0.9
         # self.trade_energy = None

-    # def step(self, action_grid, ec_power_max):
-    #     self.trade_energy = (action_grid + 1) / 2 * ec_power_max  # de-normalize
-
     def get_cost(self, price, trade_energy):
         return price * trade_energy * self.delta

     def get_carbon(self, trade_energy):
         return trade_energy * self.carbon_increace

+    # def step(self, action_grid, ec_power_max):
+    #     self.trade_energy = (action_grid + 1) / 2 * ec_power_max  # de-normalize
+
     def retrieve_past_price(self):
         result = []
         # start/end indices for the past 24 hours of prices
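Note: a self-contained sketch of the Grid pricing/carbon API shown in this hunk. delta is assumed to be the step length (it is not shown in this hunk), and the trade_energy de-normalization mirrors the commented-out step():

    class Grid:
        def __init__(self):
            self.delta = 1.0            # assumed time-step length (not in this hunk)
            self.carbon_increace = 0.9  # spelling kept from the source

        def get_cost(self, price, trade_energy):
            return price * trade_energy * self.delta

        def get_carbon(self, trade_energy):
            return trade_energy * self.carbon_increace

    grid = Grid()
    trade_energy = (0.2 + 1) / 2 * 150.0  # de-normalize an action from [-1, 1]
    print(grid.get_cost(0.5, trade_energy))   # 45.0
    print(grid.get_carbon(trade_energy))      # 81.0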
train.py (21 changes)
@@ -171,8 +171,6 @@ class Arguments:
     def __init__(self, agent=None, env=None):
         self.agent = agent
         self.env = env
-        self.cwd = None  # current working directory; None means it is set automatically
-        self.if_remove = False  # remove the cwd folder? (True, False, None: ask)
         self.visible_gpu = '0'  # os.environ['CUDA_VISIBLE_DEVICES'] = '0, 2,'
         self.num_threads = 32  # number of CPU threads for evaluating the model
@@ -194,14 +192,8 @@ class Arguments:
         self.random_seed_list = [1234]
         self.train = True
         self.save_network = True
         self.test_network = True
         self.save_test_data = True

     def init_before_training(self):
-        if self.cwd is None:
-            agent_name = self.agent.__class__.__name__
-            self.cwd = f'./{agent_name}'
-
         np.random.seed(self.random_seed)
         torch.manual_seed(self.random_seed)
         torch.set_num_threads(self.num_threads)
@@ -217,7 +209,6 @@ if __name__ == '__main__':
     for seed in args.random_seed_list:
         args.random_seed = seed
         args.agent = AgentPPO()
-        agent_name = f'{args.agent.__class__.__name__}'
         args.agent.cri_target = True
         args.env = WgzGym()
         args.init_before_training()
@@ -226,9 +217,9 @@ if __name__ == '__main__':
     env = args.env
     agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)
     gamma = args.gamma
-    batch_size = args.batch_size  # data used to update net
-    target_step = args.target_step  # steps of one episode should stop
-    repeat_times = args.repeat_times  # times should update for one batch size data
+    batch_size = args.batch_size
+    target_step = args.target_step
+    repeat_times = args.repeat_times
     soft_update_tau = args.soft_update_tau
     num_episode = args.num_episode
     agent.state = env.reset()
@@ -254,9 +245,9 @@ if __name__ == '__main__':
             reward_record['unbalance'].append(episode_unbalance)
             print(f'episode: {i_episode}, reward: {episode_reward}, unbalance: {episode_unbalance}')

-    act_save_path = f'{args.cwd}/actor.pth'
-    loss_record_path = f'{args.cwd}/loss.pkl'
-    reward_record_path = f'{args.cwd}/reward.pkl'
+    act_save_path = './data/actor.pth'
+    loss_record_path = './data/loss.pkl'
+    reward_record_path = './data/reward.pkl'

     if args.save_network:
         with open(loss_record_path, 'wb') as tf:
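Note: with this change, training output lands in a fixed ./data directory that inference.py reads from. A minimal sketch of the save step; the record structure is an assumption, and the makedirs guard is added here in case the directory is missing:

    import os
    import pickle
    import torch
    import torch.nn as nn

    os.makedirs('./data', exist_ok=True)  # guard in case ./data does not exist yet

    actor = nn.Linear(4, 2)  # stand-in for agent.act
    torch.save(actor.state_dict(), './data/actor.pth')

    reward_record = {'mean_episode_reward': [], 'unbalance': []}  # assumed keys
    with open('./data/reward.pkl', 'wb') as tf:
        pickle.dump(reward_record, tf)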