2025-02-17 09:29:08 +08:00
|
|
|
|
import queue
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from train import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_one_step(env, act, device, data, action_path):
    """Run one inference step on a freshly reset environment and log the action.

    Args:
        env: Environment object. Its ``rec_data`` attribute is set to ``data``
            before ``reset()`` is called; it must also provide ``step()``,
            ``current_time`` and ``unbalance``.
        act: Callable policy mapping a state tensor to an action tensor.
        device: Torch device on which the state tensor is created.
        data: Observation handed to the environment through ``env.rec_data``.
        action_path: Path of the text file the chosen action is appended to.

    Returns:
        Tuple ``(reward, env.unbalance)``, where ``reward`` is the third of the
        four items returned by ``env.step``.
    """
    env.rec_data = data
    state = env.reset()

    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        s_tensor = torch.as_tensor((state,), device=device)
        a_tensor = act(s_tensor)
    # The batch holds a single state, so take row 0 as the action.
    action = a_tensor.detach().cpu().numpy()[0]

    # Only the reward of the four-item step result is used here.
    _, _, reward, _ = env.step(action)
    print(f'The action of {env.current_time} is {action}')

    with open(action_path, 'a') as af:
        af.write(f'{env.current_time},{action}\n')

    return reward, env.unbalance


# The name starts with ``test_`` but this is a service helper, not a test case;
# opt out of pytest collection so its parameters are not mistaken for fixtures.
test_one_step.__test__ = False
|
2025-02-17 09:29:08 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_service_test(env, agent, data):
    """Evaluate one incoming observation and append the outcome to CSV logs.

    The observation is passed to ``test_one_step`` together with ``agent.act``
    and ``agent.device``. The resulting reward and unbalance are appended to
    ``data/service_result.csv``; ``test_one_step`` appends the chosen action to
    ``data/service_actions.csv``. Each file gets its header line the first
    time it is created.

    Args:
        env: Environment object forwarded to ``test_one_step``.
        agent: Object exposing ``act`` and ``device``.
        data: Observation forwarded to ``test_one_step``.
    """
    # Imported locally because `os` is not imported explicitly at the top of
    # this file; it was only reachable through the star import.
    import os

    service_result_path = 'data/service_result.csv'
    action_path = 'data/service_actions.csv'

    # Make sure the output directory exists, then write each header only once.
    csv_headers = (
        (service_result_path, 'reward,unbalance\n'),
        (action_path, 'time,action\n'),
    )
    for path, header in csv_headers:
        os.makedirs(os.path.dirname(path), exist_ok=True)
        if not os.path.exists(path):
            with open(path, 'w') as f:
                f.write(header)

    service_reward, service_unbalance = test_one_step(env, agent.act, agent.device, data, action_path)

    # A single sample is collected per call. The averaging arithmetic is kept
    # as-is so the values written to the log are formatted exactly as before.
    service_rewards = [service_reward]
    service_unbalances = [service_unbalance]
    avg_reward = sum(service_rewards) / len(service_rewards)
    avg_unbalance = sum(service_unbalances) / len(service_unbalances)

    with open(service_result_path, 'a') as f:
        f.write(f'{avg_reward},{avg_unbalance}\n')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Listener side
def listener_thread(env, agent, data_queue):
    """Consume observations from ``data_queue`` forever and evaluate each one.

    Args:
        env: Environment object forwarded to ``run_service_test``.
        agent: Agent object forwarded to ``run_service_test``.
        data_queue: Queue from which observations are taken; ``task_done()``
            is called once for every item retrieved.

    This function never returns normally; it is meant to run in its own thread.
    """
    while True:
        # Block until an item is available instead of polling with sleeps.
        new_data = data_queue.get()
        try:
            print(f"Data received: {new_data}")
            run_service_test(env, agent, new_data)
        finally:
            # Balance every get(), even when the evaluation raises.
            data_queue.task_done()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sender side
def sender_thread(data_queue):
    """Read observations from standard input and push them onto ``data_queue``.

    Each accepted line must hold exactly six comma-separated numbers (the
    prompt names them price, temper, solar, load, heat, people) and is queued
    as a list of floats. Lines that cannot be parsed, or that do not contain
    six values, are reported and skipped.

    Args:
        data_queue: Queue receiving each validated list of six floats.

    The loop runs until standard input reaches end-of-file, at which point a
    notice is printed and the function returns.
    """
    while True:
        # Short pause before prompting (presumably so output printed by the
        # listener appears ahead of the next prompt — confirm).
        time.sleep(0.5)
        try:
            user_input = input("请输入当前时刻的price, temper, solar, load, heat, people(用逗号分隔): \n")
        except EOFError:
            # stdin is closed: no further input can arrive, so stop cleanly
            # instead of ending the thread with an unhandled exception.
            print("输入已结束,发送线程退出。")
            return

        # Split the input string and convert each field to a float.
        try:
            input_data = [float(field) for field in user_input.split(',')]
        except ValueError:
            print("输入格式不正确,请输入数值。")
            continue

        # Require exactly six values.
        if len(input_data) != 6:
            print("输入格式不正确,请输入六个数值。")
            continue

        # Hand the data over to the consumer through the queue.
        print(f"Sending data: {input_data}")
        data_queue.put(input_data)
|
2025-02-17 09:29:08 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Load the saved actor weights and start the interactive service threads.

    Builds ``Arguments``, an ``AgentPPO`` agent and a ``WgzGym`` environment,
    initialises the agent, loads the actor state dict from
    ``./data/actor.pth``, then starts two daemon threads connected by a queue:
    ``sender_thread`` (reads input) and ``listener_thread`` (evaluates it).
    The main thread then sleeps forever so the daemon threads keep running.

    NOTE(review): the original indentation of this function was not
    recoverable. This version assumes the whole setup lives inside the seed
    loop, which means the final wait loop prevents any seed after the first
    from being used — confirm against the intended behaviour.
    """
    args = Arguments()
    args.visible_gpu = '0'
    for seed in args.random_seed_list:
        args.random_seed = seed
        args.agent = AgentPPO()
        args.agent.cri_target = True
        args.env = WgzGym()
        args.init_before_training()

        agent = args.agent
        env = args.env
        env.TRAIN = False
        agent.init(args.net_dim, env.state_space.shape[0], env.action_space.shape[0], args.learning_rate)

        act_save_path = './data/actor.pth'
        # map_location remaps the checkpoint onto the agent's device, so
        # weights saved on another device (e.g. a GPU) still load here.
        actor_state = torch.load(act_save_path, map_location=agent.device)
        agent.act.load_state_dict(actor_state)

        # Queue used to pass data between the two threads.
        data_queue = queue.Queue()

        listener = threading.Thread(
            target=listener_thread,
            args=(env, agent, data_queue),
            daemon=True,
        )
        listener.start()

        sender = threading.Thread(
            target=sender_thread,
            args=(data_queue,),
            daemon=True,
        )
        sender.start()

        # Keep the main thread alive; the daemon threads stop when it exits.
        while True:
            time.sleep(10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Start the service only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|