120 lines
3.5 KiB
Python
120 lines
3.5 KiB
Python
import queue
|
||
import threading
|
||
import time
|
||
import torch
|
||
|
||
from train import *
|
||
|
||
|
||
def test_one_step(env, act, device, data, action_path):
    """Run the actor for a single step on externally supplied data.

    ``data`` is attached to the environment as ``env.rec_data`` and the
    environment is reset. The actor ``act`` then chooses an action for the
    resulting state, and that action is applied with ``env.step``. The action
    is printed and appended to ``action_path`` as ``"<current_time>,<action>"``.

    Args:
        env: Environment exposing ``rec_data``, ``reset()``, ``step(action)``,
            ``current_time`` and ``unbalance``.
        act: Actor callable mapping a batched state tensor to a batched action
            tensor.
        device: Torch device the state tensor is created on.
        data: Raw input record for this step (assigned to ``env.rec_data``).
        action_path: Path of the CSV file the action line is appended to.

    Returns:
        Tuple ``(reward, env.unbalance)`` observed after the step.
    """
    env.rec_data = data
    state = env.reset()

    # Pure inference: no_grad avoids building an autograd graph on every request.
    with torch.no_grad():
        # Wrapping the state in a 1-tuple adds a batch dimension of size 1;
        # it is removed again by the [0] below.
        s_tensor = torch.as_tensor((state,), device=device)
        a_tensor = act(s_tensor)
    action = a_tensor.detach().cpu().numpy()[0]

    _, _, reward, _ = env.step(action)
    print(f'The action of {env.current_time} is {action}')

    with open(action_path, 'a') as af:
        af.write(f'{env.current_time},{action}\n')

    return reward, env.unbalance


# Despite the historical ``test_`` prefix this is a service helper, not a unit
# test; stop pytest from collecting it (it would fail on the missing fixtures).
test_one_step.__test__ = False
|
||
|
||
|
||
def _ensure_csv_header(path, header):
    """Create ``path`` (and its parent directory) containing ``header`` if it does not exist yet."""
    import os  # explicit import instead of relying on ``from train import *``

    parent = os.path.dirname(path)
    if parent:
        # Without this, a missing ``data/`` directory makes open() raise.
        os.makedirs(parent, exist_ok=True)
    if not os.path.exists(path):
        with open(path, 'w') as f:
            f.write(header)


def run_service_test(env, agent, data):
    """Serve a single request: run one actor step on ``data`` and log the result.

    Makes sure ``data/service_result.csv`` and ``data/service_actions.csv``
    exist with their header rows. It then runs ``test_one_step`` with the
    agent's actor and device; that call appends the chosen action to the
    actions CSV. Finally it appends the step's reward and unbalance to the
    results CSV.

    Args:
        env: Environment forwarded to ``test_one_step``.
        agent: Agent providing ``act`` (the actor network) and ``device``.
        data: Raw input record forwarded to ``test_one_step``.
    """
    service_result_path = 'data/service_result.csv'
    action_path = 'data/service_actions.csv'

    _ensure_csv_header(service_result_path, 'reward,unbalance\n')
    _ensure_csv_header(action_path, 'time,action\n')

    service_reward, service_unbalance = test_one_step(env, agent.act, agent.device, data, action_path)

    # Exactly one step is served per request, so these lists are never empty.
    # The mean over them is kept so the written values match the previous output format.
    service_rewards = [service_reward]
    service_unbalances = [service_unbalance]
    avg_reward = sum(service_rewards) / len(service_rewards)
    avg_unbalance = sum(service_unbalances) / len(service_unbalances)

    with open(service_result_path, 'a') as f:
        f.write(f'{avg_reward},{avg_unbalance}\n')
|
||
|
||
|
||
# Listener side: consumes queued inputs and serves them one at a time.
def listener_thread(env, agent, data_queue):
    """Take items from ``data_queue`` forever and serve each with ``run_service_test``.

    Blocks on ``data_queue.get()`` instead of polling, so an item is handled
    as soon as it arrives and the thread does nothing while the queue is
    empty. A failure while serving one item is reported with its traceback
    but does not end the loop. ``task_done()`` is always called, so
    ``data_queue.join()`` cannot hang on a failed item.

    Args:
        env: Environment forwarded to ``run_service_test``.
        agent: Agent forwarded to ``run_service_test``.
        data_queue: Queue of input records (in this script, filled by ``sender_thread``).
    """
    import traceback

    while True:
        new_data = data_queue.get()  # blocks until an item is available
        try:
            print(f"Data received: {new_data}")
            run_service_test(env, agent, new_data)
        except Exception:
            # Keep the listener alive. An uncaught error would end this thread,
            # and later inputs would never be served even though the sender
            # keeps accepting them.
            traceback.print_exc()
        finally:
            data_queue.task_done()
|
||
|
||
|
||
# Sender side: reads one observation per line from stdin and queues it.
def sender_thread(data_queue):
    """Read comma-separated observations from stdin and put them on ``data_queue``.

    Each line must contain exactly six numbers (price, temper, solar, load,
    heat, people, as listed in the prompt). Each valid line is queued as a
    list of floats. Non-numeric or wrongly sized lines are reported and
    skipped. The loop ends when stdin is closed (EOF) instead of crashing the
    thread.

    Args:
        data_queue: Queue that receives each parsed list of six floats.
    """
    n_values = 6  # number of fields named in the prompt

    while True:
        # Presumably gives the listener's output time to print before the next prompt.
        time.sleep(0.5)
        try:
            user_input = input("请输入当前时刻的price, temper, solar, load, heat, people(用逗号分隔): \n")
        except EOFError:
            # stdin closed or exhausted (e.g. no terminal attached): stop sending.
            print("输入已结束,发送端退出。")
            return

        # Split the input string on commas and convert each field to a float.
        try:
            input_data = [float(field) for field in user_input.split(',')]
        except ValueError:
            print("输入格式不正确,请输入数值。")
            continue

        # Check that the input contains exactly six values.
        if len(input_data) != n_values:
            print("输入格式不正确,请输入六个数值。")
            continue

        # Hand the data over to the listener thread.
        print(f"Sending data: {input_data}")
        data_queue.put(input_data)
|
||
|
||
|
||
def main():
    """Load the trained actor and serve stdin inputs via listener/sender threads.

    Builds the agent and environment from ``Arguments``, loads the actor
    weights from ``./data/actor.pth``, and starts two daemon threads.
    ``sender_thread`` reads inputs from stdin; ``listener_thread`` serves
    them. The main thread then only stays alive.

    NOTE(review): the original file lost its indentation. This assumes the
    seed loop wraps only the agent/env construction, so the agent and env
    built for the last seed in ``args.random_seed_list`` are the ones served —
    confirm.
    """
    args = Arguments()
    args.visible_gpu = '0'
    for seed in args.random_seed_list:
        args.random_seed = seed
        args.agent = AgentPPO()
        args.agent.cri_target = True
        args.env = WgzGym()
        args.init_before_training()

    agent = args.agent
    env = args.env
    env.TRAIN = False
    agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)

    act_save_path = './data/actor.pth'
    # map_location lets a checkpoint saved on a GPU be loaded onto the device
    # the agent runs on (e.g. a CPU-only serving machine).
    agent.act.load_state_dict(torch.load(act_save_path, map_location=agent.device))
    # Inference only: switch off training-mode behaviour (dropout/batch-norm, if any).
    agent.act.eval()

    # Queue used to pass inputs from the sender thread to the listener thread.
    data_queue = queue.Queue()

    listener = threading.Thread(target=listener_thread, args=(env, agent, data_queue), daemon=True)
    listener.start()

    sender = threading.Thread(target=sender_thread, args=(data_queue,), daemon=True)
    sender.start()

    # Keep the main thread alive; both workers are daemons and end with it.
    while True:
        time.sleep(10)
|
||
|
||
|
||
if __name__ == '__main__':
    # Start the interactive service only when run as a script, not on import.
    main()
|