This commit is contained in:
chenxiaodong 2024-06-24 14:18:06 +08:00
parent 3f99d287a6
commit a413b3ee0f
4 changed files with 61347 additions and 27 deletions

View File

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="chenxd@124.16.151.196:22121 password" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="chenxd@124.16.151.196:22121 password">
<serverdata>

View File

@ -172,18 +172,6 @@ class AgentSAC(AgentBase):
actions = self.act.get_action(states)
return actions.detach().cpu().numpy()[0]
def explore_env(self, env, target_step):
trajectory = list()
state = self.state
for _ in range(target_step):
action = self.select_action(state)
state, next_state, reward, done, = env.step(action)
trajectory.append((state, (reward, done, *action)))
state = env.reset() if done else next_state
self.state = state
return trajectory
def update_net(self, buffer, batch_size, repeat_times, soft_update_tau):
buffer.update_now_len()

61322
data/action.json Normal file

File diff suppressed because it is too large Load Diff

38
test.py
View File

@ -21,21 +21,31 @@
# print(model)
import pickle
# import pickle
#
# a = 'DDPG'
# b = 'PPO'
# c = 'SAC'
# d = 'TD3'
#
# a1 = '/reward_data.pkl'
# a2 = '/loss_data.pkl'
# a3 = '/test_data.pkl'
#
# filename = './Agent' + a + a3
#
# # 使用 'rb' 模式打开文件,读取二进制数据
# with open(filename, 'rb') as f:
# data = pickle.load(f)
#
# print(data)
a = 'DDPG'
b = 'PPO'
c = 'SAC'
d = 'TD3'
import json
a1 = '/reward_data.pkl'
a2 = '/loss_data.pkl'
a3 = '/test_data.pkl'
with open('data/action.json', 'r') as file:
data = json.load(file)
filename = './Agent' + a + a3
# 遍历每组数据
for group in data:
print(group)
# 使用 'rb' 模式打开文件,读取二进制数据
with open(filename, 'rb') as f:
data = pickle.load(f)
print(data)