35 lines
1.3 KiB
Python
35 lines
1.3 KiB
Python
import re
|
|
import numpy as np
|
|
|
|
|
|
# Extract reward/unbalance pairs from a log and summarise the most recent ones.


def extract_and_calculate_stats(
    text,
    pattern=r"reward: ([\-\d\.]+),\s*unbalance:\s*([\d\.]+)",
    *,
    last_n=50,
    unbalance_scale=1000,
):
    """Return mean and variance of the last ``last_n`` reward/unbalance values.

    Every match of *pattern* in *text* contributes one sample.
    ``re.findall`` is used, so the pattern must contain at least two capture
    groups: group 1 is the reward and group 2 is the unbalance value.

    Args:
        text: Log contents to scan.
        pattern: Regex whose first two groups capture the reward and the
            unbalance as float-parsable strings.
        last_n: Number of most recent samples to aggregate. Must be positive.
            If fewer samples exist, all of them are used.
        unbalance_scale: Each unbalance value is divided by this before
            aggregation. The default of 1000 presumably converts between
            units; NOTE(review): confirm the unit.

    Returns:
        Tuple ``(reward_mean, reward_var, unbalance_mean, unbalance_var)``.
        Variances are population variances (``np.var`` with ``ddof=0``).
        All four values are ``nan`` when *text* contains no matches.

    Raises:
        ValueError: If ``last_n`` is not positive, or if a captured group
            cannot be parsed as a float.
    """
    if last_n <= 0:
        # rewards[-0:] would select the whole list, not an empty window.
        raise ValueError(f"last_n must be positive, got {last_n}")

    matches = re.findall(pattern, text)
    if not matches:
        # Avoid np.mean([]) RuntimeWarnings; report "no data" as nan explicitly.
        return np.nan, np.nan, np.nan, np.nan

    # Parse every match, so that malformed numbers anywhere still raise.
    rewards = np.array([float(m[0]) for m in matches])
    unbalances = np.array([float(m[1]) for m in matches]) / unbalance_scale

    recent_rewards = rewards[-last_n:]
    recent_unbalances = unbalances[-last_n:]

    reward_mean, reward_var = np.mean(recent_rewards), np.var(recent_rewards)
    unbalance_mean, unbalance_var = np.mean(recent_unbalances), np.var(recent_unbalances)

    return reward_mean, reward_var, unbalance_mean, unbalance_var
|
|
|
|
|
|
# Result log for each algorithm, as (file path, text encoding).
# The DDPG log is decoded as UTF-16; the others are decoded as UTF-8.
file_paths = {
    "DDPG": ('res/ddpg.txt', 'utf-16'),
    "TD3": ('res/td3.txt', 'utf-8'),
    "SAC": ('res/sac.txt', 'utf-8'),
    "PPO": ('res/ppo_1016.txt', 'utf-8'),
    "PPO_LLM": ('res/ppo_llm.txt', 'utf-8'),
}


def main():
    """Read each algorithm's log and print its reward/unbalance statistics."""
    for algo, (file_path, encoding) in file_paths.items():
        with open(file_path, 'r', encoding=encoding) as f:
            text = f.read()

        reward_mean, reward_var, unbalance_mean, unbalance_var = extract_and_calculate_stats(text)
        print(f"{algo} - Reward Mean: {reward_mean:.4f}, Variance: {reward_var:.4f}")
        print(f"{algo} - Unbalance Mean: {unbalance_mean:.4f}, Variance: {unbalance_var:.4f}")


# Only do file I/O and printing when run as a script, not on import.
if __name__ == "__main__":
    main()
|