first infer
This commit is contained in:
parent
fb89548c37
commit
08d77eb79e
122
inference.py
122
inference.py
|
@ -0,0 +1,122 @@
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from train import *
|
||||||
|
|
||||||
|
|
||||||
|
def test_one_step(env, act, device, data, action_path):
|
||||||
|
env.rec_data = data
|
||||||
|
state = env.reset()
|
||||||
|
s_tensor = torch.as_tensor((state,), device=device)
|
||||||
|
a_tensor = act(s_tensor)
|
||||||
|
action = a_tensor.detach().cpu().numpy()[0]
|
||||||
|
state, next_state, reward, done = env.step(action)
|
||||||
|
print(f'The action of {env.current_time} is {action}')
|
||||||
|
|
||||||
|
with open(action_path, 'a') as af:
|
||||||
|
af.write(f'{env.current_time},{action}\n')
|
||||||
|
return reward, env.real_unbalance
|
||||||
|
|
||||||
|
|
||||||
|
def run_service_test(env, agent, data):
|
||||||
|
service_result_path = 'data/service_result.csv'
|
||||||
|
action_path = 'data/service_actions.csv'
|
||||||
|
|
||||||
|
if not os.path.exists(service_result_path):
|
||||||
|
with open(service_result_path, 'w') as f:
|
||||||
|
f.write('reward,unbalance\n')
|
||||||
|
|
||||||
|
if not os.path.exists(action_path):
|
||||||
|
with open(action_path, 'w') as af:
|
||||||
|
af.write('time,action\n')
|
||||||
|
|
||||||
|
service_rewards = []
|
||||||
|
service_unbalances = []
|
||||||
|
|
||||||
|
service_reward, service_unbalance = test_one_step(env, agent.act, agent.device, data, action_path)
|
||||||
|
service_rewards.append(service_reward)
|
||||||
|
service_unbalances.append(service_unbalance)
|
||||||
|
|
||||||
|
if service_rewards:
|
||||||
|
avg_reward = sum(service_rewards) / len(service_rewards)
|
||||||
|
avg_unbalance = sum(service_unbalances) / len(service_unbalances)
|
||||||
|
|
||||||
|
with open(service_result_path, 'a') as f:
|
||||||
|
f.write(f'{avg_reward},{avg_unbalance}\n')
|
||||||
|
|
||||||
|
|
||||||
|
# 接听端
|
||||||
|
def listener_thread(env, agent, data_queue):
|
||||||
|
while True:
|
||||||
|
time.sleep(0.1) # 等待
|
||||||
|
if not data_queue.empty():
|
||||||
|
new_data = data_queue.get()
|
||||||
|
print(f"Data received: {new_data}")
|
||||||
|
run_service_test(env, agent, new_data)
|
||||||
|
data_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
|
# 发送端
|
||||||
|
def sender_thread(data_queue):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
user_input = input("请输入price, temper, solar, load, heat, people(用逗号分隔): ")
|
||||||
|
|
||||||
|
# 将输入字符串分割并转换为浮点数列表
|
||||||
|
input_data = list(map(float, user_input.split(',')))
|
||||||
|
|
||||||
|
# 检查输入是否包含六个数值
|
||||||
|
if len(input_data) != 6:
|
||||||
|
print("输入数据格式不正确,请输入六个数值。")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 将数据放入队列
|
||||||
|
print(f"Sending data: {input_data}")
|
||||||
|
data_queue.put(input_data)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
print("输入数据格式不正确,请输入数值。")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = Arguments()
|
||||||
|
args.visible_gpu = '0'
|
||||||
|
for seed in args.random_seed_list:
|
||||||
|
args.random_seed = seed
|
||||||
|
args.agent = AgentPPO()
|
||||||
|
args.agent.cri_target = True
|
||||||
|
agent_name = f'{args.agent.__class__.__name__}'
|
||||||
|
args.env = WgzGym()
|
||||||
|
args.init_before_training()
|
||||||
|
|
||||||
|
agent = args.agent
|
||||||
|
env = args.env
|
||||||
|
env.TRAIN = False
|
||||||
|
agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)
|
||||||
|
|
||||||
|
if args.service:
|
||||||
|
args.cwd = agent_name
|
||||||
|
service_act_save_path = f'{args.cwd}/actor.pth'
|
||||||
|
agent.act.load_state_dict(torch.load(service_act_save_path))
|
||||||
|
|
||||||
|
# 创建一个队列用于线程间通信
|
||||||
|
data_queue = queue.Queue()
|
||||||
|
|
||||||
|
listener = threading.Thread(target=listener_thread, args=(env, agent, data_queue))
|
||||||
|
listener.daemon = True
|
||||||
|
listener.start()
|
||||||
|
|
||||||
|
sender = threading.Thread(target=sender_thread, args=(data_queue,))
|
||||||
|
sender.daemon = True
|
||||||
|
sender.start()
|
||||||
|
|
||||||
|
# 主线程保持运行,等待数据传递
|
||||||
|
while True:
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -9,6 +9,7 @@ from parameters import *
|
||||||
class WgzGym(gym.Env):
|
class WgzGym(gym.Env):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super(WgzGym, self).__init__()
|
super(WgzGym, self).__init__()
|
||||||
|
self.rec_data = None
|
||||||
self.unbalance = None
|
self.unbalance = None
|
||||||
self.reward = None
|
self.reward = None
|
||||||
self.current_output = None
|
self.current_output = None
|
||||||
|
@ -58,12 +59,20 @@ class WgzGym(gym.Env):
|
||||||
grid_ex = self.grid
|
grid_ex = self.grid
|
||||||
time_step = self.current_time
|
time_step = self.current_time
|
||||||
|
|
||||||
price = self.data_manager.get_price_data(self.month, self.day, self.current_time)
|
if self.TRAIN:
|
||||||
temper = self.data_manager.get_temperature_data(self.month, self.day, self.current_time)
|
price = self.data_manager.get_price_data(self.month, self.day, self.current_time)
|
||||||
solar = self.data_manager.get_solar_data(self.month, self.day, self.current_time)
|
temper = self.data_manager.get_temperature_data(self.month, self.day, self.current_time)
|
||||||
load = self.data_manager.get_load_data(self.month, self.day, self.current_time)
|
solar = self.data_manager.get_solar_data(self.month, self.day, self.current_time)
|
||||||
heat = self.data_manager.get_heat_data(self.month, self.day, self.current_time)
|
load = self.data_manager.get_load_data(self.month, self.day, self.current_time)
|
||||||
people = self.data_manager.get_people_data(self.month, self.day, self.current_time)
|
heat = self.data_manager.get_heat_data(self.month, self.day, self.current_time)
|
||||||
|
people = self.data_manager.get_people_data(self.month, self.day, self.current_time)
|
||||||
|
else:
|
||||||
|
price = self.rec_data[0]
|
||||||
|
temper = self.rec_data[1]
|
||||||
|
solar = self.rec_data[2]
|
||||||
|
load = self.rec_data[3]
|
||||||
|
heat = self.rec_data[4]
|
||||||
|
people = self.rec_data[5]
|
||||||
|
|
||||||
obs = np.concatenate((np.float32(time_step), np.float32(price), np.float32(temper),
|
obs = np.concatenate((np.float32(time_step), np.float32(price), np.float32(temper),
|
||||||
np.float32(solar), np.float32(load), np.float32(heat),
|
np.float32(solar), np.float32(load), np.float32(heat),
|
||||||
|
@ -78,10 +87,11 @@ class WgzGym(gym.Env):
|
||||||
self.HST.step(action[1])
|
self.HST.step(action[1])
|
||||||
self.grid.step(action[2])
|
self.grid.step(action[2])
|
||||||
price = current_obs[1]
|
price = current_obs[1]
|
||||||
temper = current_obs[2]
|
temper = current_obs[2] # 用途待补充
|
||||||
solar = current_obs[3]
|
solar = current_obs[3]
|
||||||
load = current_obs[4]
|
load = current_obs[4]
|
||||||
heat = current_obs[5]
|
heat = current_obs[5]
|
||||||
|
people = current_obs[6] # 用途待补充
|
||||||
|
|
||||||
power_gap = solar + self.HST.get_power() - self.EC.current_power - load
|
power_gap = solar + self.HST.get_power() - self.EC.current_power - load
|
||||||
heat_gap = self.HST.get_heat() + self.EC.get_heat() - heat
|
heat_gap = self.HST.get_heat() + self.EC.get_heat() - heat
|
||||||
|
|
Loading…
Reference in New Issue