building-agents/measure.py

35 lines
1.3 KiB
Python
Raw Normal View History

2024-11-22 10:03:31 +08:00
import re
import numpy as np
# Extract (reward, unbalance) pairs from log text and summarise the most
# recent entries with their mean and variance.
def extract_and_calculate_stats(
    text,
    pattern=r"reward: ([\-\d\.]+),\s*unbalance:\s*([\d\.]+)",
    last_n=50,
):
    """Return mean/variance of the trailing rewards and unbalances in *text*.

    Every occurrence of *pattern* in *text* is collected with
    ``re.findall``; capture group 1 is parsed as the reward and capture
    group 2 as the unbalance. Only the last ``last_n`` occurrences (or all
    of them, when fewer exist) contribute to the statistics.

    Args:
        text: Log contents to scan.
        pattern: Regular expression with at least two capture groups
            (reward first, unbalance second).
        last_n: Number of trailing matches to summarise. Must be positive.

    Returns:
        A 4-tuple ``(reward_mean, reward_var, unbalance_mean,
        unbalance_var)``. Variances are NumPy's default population
        variance. All four values are NaN when *pattern* never matches.

    Raises:
        ValueError: If ``last_n`` is not positive, or if a captured group
            cannot be converted with ``float``.
    """
    if last_n <= 0:
        # A slice such as seq[-0:] would silently select *everything*,
        # so reject non-positive windows instead of returning odd results.
        raise ValueError(f"last_n must be a positive integer, got {last_n!r}")

    matches = re.findall(pattern, text)
    if not matches:
        # np.mean/np.var on an empty sequence also yield NaN, but emit
        # RuntimeWarnings along the way; return the same result quietly.
        nan = float("nan")
        return nan, nan, nan, nan

    tail = matches[-last_n:]
    rewards = [float(match[0]) for match in tail]
    # Unbalance values are scaled down by 1000, as in the original analysis
    # (presumably a unit conversion; confirm against the log producer).
    unbalances = [float(match[1]) / 1000 for match in tail]

    reward_mean, reward_var = np.mean(rewards), np.var(rewards)
    unbalance_mean, unbalance_var = np.mean(unbalances), np.var(unbalances)
    return reward_mean, reward_var, unbalance_mean, unbalance_var
# Log file location and text encoding for each algorithm, keyed by the
# label used in the printed report.
# NOTE(review): only the DDPG log is read as UTF-16; presumably it was
# written by a different tool than the others. Confirm.
file_paths = {
    "DDPG": ('res/ddpg.txt', 'utf-16'),
    "TD3": ('res/td3.txt', 'utf-8'),
    "SAC": ('res/sac.txt', 'utf-8'),
    "PPO": ('res/ppo_1016.txt', 'utf-8'),
    "PPO_LLM": ('res/ppo_llm.txt', 'utf-8'),
}

# Read each algorithm's log in turn, compute its statistics and print them.
for algo in file_paths:
    file_path, encoding = file_paths[algo]
    with open(file_path, 'r', encoding=encoding) as f:
        text = f.read()

    # The file is already closed here; only the in-memory text is analysed.
    stats = extract_and_calculate_stats(text)
    reward_mean, reward_var, unbalance_mean, unbalance_var = stats

    print(f"{algo} - Reward Mean: {reward_mean:.4f}, Variance: {reward_var:.4f}")
    print(f"{algo} - Unbalance Mean: {unbalance_mean:.4f}, Variance: {unbalance_var:.4f}")