first infer
This commit is contained in:
parent
fb89548c37
commit
08d77eb79e
122
inference.py
122
inference.py
|
@ -0,0 +1,122 @@
|
|||
import queue
|
||||
import threading
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from train import *
|
||||
|
||||
|
||||
def test_one_step(env, act, device, data, action_path):
|
||||
env.rec_data = data
|
||||
state = env.reset()
|
||||
s_tensor = torch.as_tensor((state,), device=device)
|
||||
a_tensor = act(s_tensor)
|
||||
action = a_tensor.detach().cpu().numpy()[0]
|
||||
state, next_state, reward, done = env.step(action)
|
||||
print(f'The action of {env.current_time} is {action}')
|
||||
|
||||
with open(action_path, 'a') as af:
|
||||
af.write(f'{env.current_time},{action}\n')
|
||||
return reward, env.real_unbalance
|
||||
|
||||
|
||||
def run_service_test(env, agent, data):
|
||||
service_result_path = 'data/service_result.csv'
|
||||
action_path = 'data/service_actions.csv'
|
||||
|
||||
if not os.path.exists(service_result_path):
|
||||
with open(service_result_path, 'w') as f:
|
||||
f.write('reward,unbalance\n')
|
||||
|
||||
if not os.path.exists(action_path):
|
||||
with open(action_path, 'w') as af:
|
||||
af.write('time,action\n')
|
||||
|
||||
service_rewards = []
|
||||
service_unbalances = []
|
||||
|
||||
service_reward, service_unbalance = test_one_step(env, agent.act, agent.device, data, action_path)
|
||||
service_rewards.append(service_reward)
|
||||
service_unbalances.append(service_unbalance)
|
||||
|
||||
if service_rewards:
|
||||
avg_reward = sum(service_rewards) / len(service_rewards)
|
||||
avg_unbalance = sum(service_unbalances) / len(service_unbalances)
|
||||
|
||||
with open(service_result_path, 'a') as f:
|
||||
f.write(f'{avg_reward},{avg_unbalance}\n')
|
||||
|
||||
|
||||
# 接听端
|
||||
def listener_thread(env, agent, data_queue):
|
||||
while True:
|
||||
time.sleep(0.1) # 等待
|
||||
if not data_queue.empty():
|
||||
new_data = data_queue.get()
|
||||
print(f"Data received: {new_data}")
|
||||
run_service_test(env, agent, new_data)
|
||||
data_queue.task_done()
|
||||
|
||||
|
||||
# 发送端
|
||||
def sender_thread(data_queue):
|
||||
while True:
|
||||
try:
|
||||
user_input = input("请输入price, temper, solar, load, heat, people(用逗号分隔): ")
|
||||
|
||||
# 将输入字符串分割并转换为浮点数列表
|
||||
input_data = list(map(float, user_input.split(',')))
|
||||
|
||||
# 检查输入是否包含六个数值
|
||||
if len(input_data) != 6:
|
||||
print("输入数据格式不正确,请输入六个数值。")
|
||||
continue
|
||||
|
||||
# 将数据放入队列
|
||||
print(f"Sending data: {input_data}")
|
||||
data_queue.put(input_data)
|
||||
|
||||
except ValueError:
|
||||
print("输入数据格式不正确,请输入数值。")
|
||||
|
||||
|
||||
def main():
|
||||
args = Arguments()
|
||||
args.visible_gpu = '0'
|
||||
for seed in args.random_seed_list:
|
||||
args.random_seed = seed
|
||||
args.agent = AgentPPO()
|
||||
args.agent.cri_target = True
|
||||
agent_name = f'{args.agent.__class__.__name__}'
|
||||
args.env = WgzGym()
|
||||
args.init_before_training()
|
||||
|
||||
agent = args.agent
|
||||
env = args.env
|
||||
env.TRAIN = False
|
||||
agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)
|
||||
|
||||
if args.service:
|
||||
args.cwd = agent_name
|
||||
service_act_save_path = f'{args.cwd}/actor.pth'
|
||||
agent.act.load_state_dict(torch.load(service_act_save_path))
|
||||
|
||||
# 创建一个队列用于线程间通信
|
||||
data_queue = queue.Queue()
|
||||
|
||||
listener = threading.Thread(target=listener_thread, args=(env, agent, data_queue))
|
||||
listener.daemon = True
|
||||
listener.start()
|
||||
|
||||
sender = threading.Thread(target=sender_thread, args=(data_queue,))
|
||||
sender.daemon = True
|
||||
sender.start()
|
||||
|
||||
# 主线程保持运行,等待数据传递
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -9,6 +9,7 @@ from parameters import *
|
|||
class WgzGym(gym.Env):
|
||||
def __init__(self, **kwargs):
|
||||
super(WgzGym, self).__init__()
|
||||
self.rec_data = None
|
||||
self.unbalance = None
|
||||
self.reward = None
|
||||
self.current_output = None
|
||||
|
@ -58,12 +59,20 @@ class WgzGym(gym.Env):
|
|||
grid_ex = self.grid
|
||||
time_step = self.current_time
|
||||
|
||||
if self.TRAIN:
|
||||
price = self.data_manager.get_price_data(self.month, self.day, self.current_time)
|
||||
temper = self.data_manager.get_temperature_data(self.month, self.day, self.current_time)
|
||||
solar = self.data_manager.get_solar_data(self.month, self.day, self.current_time)
|
||||
load = self.data_manager.get_load_data(self.month, self.day, self.current_time)
|
||||
heat = self.data_manager.get_heat_data(self.month, self.day, self.current_time)
|
||||
people = self.data_manager.get_people_data(self.month, self.day, self.current_time)
|
||||
else:
|
||||
price = self.rec_data[0]
|
||||
temper = self.rec_data[1]
|
||||
solar = self.rec_data[2]
|
||||
load = self.rec_data[3]
|
||||
heat = self.rec_data[4]
|
||||
people = self.rec_data[5]
|
||||
|
||||
obs = np.concatenate((np.float32(time_step), np.float32(price), np.float32(temper),
|
||||
np.float32(solar), np.float32(load), np.float32(heat),
|
||||
|
@ -78,10 +87,11 @@ class WgzGym(gym.Env):
|
|||
self.HST.step(action[1])
|
||||
self.grid.step(action[2])
|
||||
price = current_obs[1]
|
||||
temper = current_obs[2]
|
||||
temper = current_obs[2] # 用途待补充
|
||||
solar = current_obs[3]
|
||||
load = current_obs[4]
|
||||
heat = current_obs[5]
|
||||
people = current_obs[6] # 用途待补充
|
||||
|
||||
power_gap = solar + self.HST.get_power() - self.EC.current_power - load
|
||||
heat_gap = self.HST.get_heat() + self.EC.get_heat() - heat
|
||||
|
|
Loading…
Reference in New Issue