action
This commit is contained in:
parent
3f99d287a6
commit
a413b3ee0f
|
@ -1,6 +1,6 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false">
|
<component name="PublishConfigData" autoUpload="Always" serverName="chenxd@124.16.151.196:22121 password" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||||
<serverData>
|
<serverData>
|
||||||
<paths name="chenxd@124.16.151.196:22121 password">
|
<paths name="chenxd@124.16.151.196:22121 password">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
|
|
12
agent.py
12
agent.py
|
@ -172,18 +172,6 @@ class AgentSAC(AgentBase):
|
||||||
actions = self.act.get_action(states)
|
actions = self.act.get_action(states)
|
||||||
return actions.detach().cpu().numpy()[0]
|
return actions.detach().cpu().numpy()[0]
|
||||||
|
|
||||||
def explore_env(self, env, target_step):
|
|
||||||
trajectory = list()
|
|
||||||
|
|
||||||
state = self.state
|
|
||||||
for _ in range(target_step):
|
|
||||||
action = self.select_action(state)
|
|
||||||
state, next_state, reward, done, = env.step(action)
|
|
||||||
trajectory.append((state, (reward, done, *action)))
|
|
||||||
state = env.reset() if done else next_state
|
|
||||||
self.state = state
|
|
||||||
return trajectory
|
|
||||||
|
|
||||||
def update_net(self, buffer, batch_size, repeat_times, soft_update_tau):
|
def update_net(self, buffer, batch_size, repeat_times, soft_update_tau):
|
||||||
buffer.update_now_len()
|
buffer.update_now_len()
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
38
test.py
38
test.py
|
@ -21,21 +21,31 @@
|
||||||
|
|
||||||
# print(model)
|
# print(model)
|
||||||
|
|
||||||
import pickle
|
# import pickle
|
||||||
|
#
|
||||||
|
# a = 'DDPG'
|
||||||
|
# b = 'PPO'
|
||||||
|
# c = 'SAC'
|
||||||
|
# d = 'TD3'
|
||||||
|
#
|
||||||
|
# a1 = '/reward_data.pkl'
|
||||||
|
# a2 = '/loss_data.pkl'
|
||||||
|
# a3 = '/test_data.pkl'
|
||||||
|
#
|
||||||
|
# filename = './Agent' + a + a3
|
||||||
|
#
|
||||||
|
# # 使用 'rb' 模式打开文件,读取二进制数据
|
||||||
|
# with open(filename, 'rb') as f:
|
||||||
|
# data = pickle.load(f)
|
||||||
|
#
|
||||||
|
# print(data)
|
||||||
|
|
||||||
a = 'DDPG'
|
import json
|
||||||
b = 'PPO'
|
|
||||||
c = 'SAC'
|
|
||||||
d = 'TD3'
|
|
||||||
|
|
||||||
a1 = '/reward_data.pkl'
|
with open('data/action.json', 'r') as file:
|
||||||
a2 = '/loss_data.pkl'
|
data = json.load(file)
|
||||||
a3 = '/test_data.pkl'
|
|
||||||
|
|
||||||
filename = './Agent' + a + a3
|
# 遍历每组数据
|
||||||
|
for group in data:
|
||||||
|
print(group)
|
||||||
|
|
||||||
# 使用 'rb' 模式打开文件,读取二进制数据
|
|
||||||
with open(filename, 'rb') as f:
|
|
||||||
data = pickle.load(f)
|
|
||||||
|
|
||||||
print(data)
|
|
||||||
|
|
Loading…
Reference in New Issue