building-agents/plott.py

127 lines
5.3 KiB
Python
Raw Permalink Normal View History

2024-11-22 10:03:31 +08:00
import re
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from matplotlib.patches import ConnectionPatch
# Matches log fragments such as "reward: -42.5, unbalance: 1234.0" and captures
# the reward (group 1) and the imbalance (group 2) as text.
_LOG_PATTERN = r"reward: ([\-\d\.]+),\s*unbalance:\s*([\d\.]+)"


def extract(text, pattern=_LOG_PATTERN, window=3):
    """Parse reward / power-imbalance pairs out of a training log.

    Args:
        text: Full log text to scan.
        pattern: Regex with exactly two capture groups: reward, then imbalance.
            Defaults to the same expression the script defines as ``pattern``.
        window: Size of the trailing rolling-mean window used for smoothing.

    Returns:
        A tuple ``(rewards, rewards_s, unbalances, unbalances_s)``:
        the raw rewards as a list of floats, their rolling mean as a float64
        ``pandas.Series``, the raw imbalances divided by 1000 as a list of floats,
        and their rolling mean. With pandas' default ``min_periods`` the first
        ``window - 1`` smoothed entries are NaN. No matches yields empty lists and
        empty float64 Series.

    Raises:
        ValueError: if a captured field is not a valid float (e.g. a lone "-").
    """
    pairs = re.findall(pattern, text)
    rewards = [float(reward) for reward, _ in pairs]
    # Divided by 1000 -- presumably W -> kW, since the plot labels this axis in kW.
    unbalances = [float(unbalance) / 1000 for _, unbalance in pairs]
    # Explicit dtype keeps the Series numeric even when nothing matched.
    rewards_s = pd.Series(rewards, dtype=float).rolling(window=window).mean()
    unbalances_s = pd.Series(unbalances, dtype=float).rolling(window=window).mean()
    return rewards, rewards_s, unbalances, unbalances_s
# Two capture groups per log line: the reward value, then the power imbalance.
pattern = r"reward: ([\-\d\.]+),\s*unbalance:\s*([\d\.]+)"


def _read_log(path, encoding):
    """Return the entire contents of the text file at *path*, decoded with *encoding*."""
    with open(path, 'r', encoding=encoding) as handle:
        return handle.read()


# Raw log text per algorithm. The DDPG log is decoded as UTF-16, the others as UTF-8.
ddpg = _read_log('res/ddpg.txt', 'utf-16')
ppo = _read_log('res/ppo_1016.txt', 'utf-8')
ppo_llm = _read_log('res/ppo_llm.txt', 'utf-8')
sac = _read_log('res/sac.txt', 'utf-8')
td3 = _read_log('res/td3.txt', 'utf-8')

# Raw and rolling-mean-smoothed reward / imbalance series for each algorithm.
ddpg_rewards, ddpg_rewards_s, ddpg_unbalances, ddpg_unbalances_s = extract(ddpg)
td3_rewards, td3_rewards_s, td3_unbalances, td3_unbalances_s = extract(td3)
sac_rewards, sac_rewards_s, sac_unbalances, sac_unbalances_s = extract(sac)
ppo_rewards, ppo_rewards_s, ppo_unbalances, ppo_unbalances_s = extract(ppo)
lmppo_rewards, lmppo_rewards_s, lmppo_unbalances, lmppo_unbalances_s = extract(ppo_llm)
# Grid style shared by both main panels and both zoomed insets.
_GRID_STYLE = dict(which='both', axis='both', linestyle='--', linewidth=0.5, color='gray')


def _plot_panel(ax, series, xlabel, ylabel):
    """Draw one comparison panel on *ax*.

    *series* is a sequence of ``(label, raw, smoothed, color)`` tuples. All raw
    curves are drawn first as faint background lines, then all smoothed curves on
    top with legend labels, so draw order and legend order follow *series*.
    """
    for _, raw, _, color in series:
        ax.plot(raw, color=color, alpha=0.2)
    for label, _, smoothed, color in series:
        ax.plot(smoothed, label=label, color=color, alpha=0.8)
    ax.set_xlabel(xlabel, fontsize=14, labelpad=0)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.grid(True, **_GRID_STYLE)
    ax.legend(fontsize=12, handletextpad=0.5, labelspacing=0.3)


def _add_zoom_inset(fig, ax, rect, series, xlim, ylim, anchors):
    """Add an inset showing the raw curves of *series* and link it to *ax*.

    Args:
        fig: Figure that receives the inset axes and the connector lines.
        ax: Main axes the connectors point to.
        rect: ``[left, bottom, width, height]`` of the inset in figure-fraction units.
        series: ``(label, raw, smoothed, color)`` tuples; only the raw curves are drawn.
        xlim, ylim: Data limits displayed by the inset.
        anchors: ``((x_inset, y_inset), (x_main, y_main))`` pairs in data coordinates;
            one dashed gray connector is drawn from each inset point to its main-axes point.

    Returns:
        The inset axes.
    """
    inset = fig.add_axes(rect)
    for _, raw, _, color in series:
        inset.plot(raw, color=color, alpha=0.8)
    inset.set_xlim(*xlim)
    inset.set_ylim(*ylim)
    inset.grid(True, **_GRID_STYLE)
    inset.tick_params(axis='both', which='major', labelsize=10)
    for xy_inset, xy_main in anchors:
        connector = ConnectionPatch(xyA=xy_inset, xyB=xy_main, coordsA="data", coordsB="data",
                                    axesA=inset, axesB=ax, linestyle='--', linewidth=1, color='gray')
        # Added to the figure (not an axes) so the line can cross between axes.
        fig.add_artist(connector)
    return inset


# (legend label, raw series, smoothed series, colour) per algorithm, in drawing order.
reward_series = [
    ('DDPG', ddpg_rewards, ddpg_rewards_s, '#E03C31'),
    ('TD3', td3_rewards, td3_rewards_s, '#1D8C55'),
    ('SAC', sac_rewards, sac_rewards_s, '#FFA500'),
    ('PPO', ppo_rewards, ppo_rewards_s, '#3498DB'),
    ('LMPPO', lmppo_rewards, lmppo_rewards_s, '#A52A2A'),
]
imbalance_series = [
    ('DDPG', ddpg_unbalances, ddpg_unbalances_s, '#E03C31'),
    ('TD3', td3_unbalances, td3_unbalances_s, '#1D8C55'),
    ('SAC', sac_unbalances, sac_unbalances_s, '#FFA500'),
    ('PPO', ppo_unbalances, ppo_unbalances_s, '#3498DB'),
    ('LMPPO', lmppo_unbalances, lmppo_unbalances_s, '#A52A2A'),
]

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10))

# Panel (a): reward, with an inset zoomed on x in [900, 1000], y in [-150, -20].
_plot_panel(ax1, reward_series, 'Time [h]\n(a)', 'Reward[-]')
inset_ax1 = _add_zoom_inset(
    fig, ax1, [0.45, 0.6, 0.25, 0.15], reward_series,
    xlim=(900, 1000), ylim=(-150, -20),
    anchors=[((900, -20), (900, 0)), ((1000, -20), (1000, 0))],
)

# Panel (b): power imbalance, with an inset zoomed on x in [900, 1000], y in [0, 2].
_plot_panel(ax2, imbalance_series, 'Time [h]\n(b)', 'Power imbalance[kW]')
inset_ax2 = _add_zoom_inset(
    fig, ax2, [0.45, 0.25, 0.25, 0.15], imbalance_series,
    xlim=(900, 1000), ylim=(0, 2),
    anchors=[((900, 0), (900, 0)), ((1000, 0), (1000, 0))],
)

# Save before show(): matplotlib recommends saving before a blocking show(), which
# closes the figure when its window is dismissed; saving first also guarantees the
# file is written even if the interactive session is interrupted.
fig.savefig("compare.png", format='png', dpi=600, bbox_inches='tight')
plt.show()