action
This commit is contained in:
parent
3f99d287a6
commit
a413b3ee0f
|
@ -1,6 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="chenxd@124.16.151.196:22121 password" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="chenxd@124.16.151.196:22121 password">
|
||||
<serverdata>
|
||||
|
|
12
agent.py
12
agent.py
|
@ -172,18 +172,6 @@ class AgentSAC(AgentBase):
|
|||
actions = self.act.get_action(states)
|
||||
return actions.detach().cpu().numpy()[0]
|
||||
|
||||
def explore_env(self, env, target_step):
|
||||
trajectory = list()
|
||||
|
||||
state = self.state
|
||||
for _ in range(target_step):
|
||||
action = self.select_action(state)
|
||||
state, next_state, reward, done, = env.step(action)
|
||||
trajectory.append((state, (reward, done, *action)))
|
||||
state = env.reset() if done else next_state
|
||||
self.state = state
|
||||
return trajectory
|
||||
|
||||
def update_net(self, buffer, batch_size, repeat_times, soft_update_tau):
|
||||
buffer.update_now_len()
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
38
test.py
38
test.py
|
@ -21,21 +21,31 @@
|
|||
|
||||
# print(model)
|
||||
|
||||
import pickle
|
||||
# import pickle
|
||||
#
|
||||
# a = 'DDPG'
|
||||
# b = 'PPO'
|
||||
# c = 'SAC'
|
||||
# d = 'TD3'
|
||||
#
|
||||
# a1 = '/reward_data.pkl'
|
||||
# a2 = '/loss_data.pkl'
|
||||
# a3 = '/test_data.pkl'
|
||||
#
|
||||
# filename = './Agent' + a + a3
|
||||
#
|
||||
# # 使用 'rb' 模式打开文件,读取二进制数据
|
||||
# with open(filename, 'rb') as f:
|
||||
# data = pickle.load(f)
|
||||
#
|
||||
# print(data)
|
||||
|
||||
a = 'DDPG'
|
||||
b = 'PPO'
|
||||
c = 'SAC'
|
||||
d = 'TD3'
|
||||
import json
|
||||
|
||||
a1 = '/reward_data.pkl'
|
||||
a2 = '/loss_data.pkl'
|
||||
a3 = '/test_data.pkl'
|
||||
with open('data/action.json', 'r') as file:
|
||||
data = json.load(file)
|
||||
|
||||
filename = './Agent' + a + a3
|
||||
# 遍历每组数据
|
||||
for group in data:
|
||||
print(group)
|
||||
|
||||
# 使用 'rb' 模式打开文件,读取二进制数据
|
||||
with open(filename, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
print(data)
|
||||
|
|
Loading…
Reference in New Issue