Compare commits


No commits in common. "b5a1842147aac4c66a94b2db7cfc4069dd849157" and "53d3ac9ca828eca346ce980fedd887acb3007219" have entirely different histories.

10 changed files with 87 additions and 53 deletions


@@ -24,13 +24,6 @@
           </Attribute>
         </value>
       </entry>
-      <entry key="\data\station.csv">
-        <value>
-          <Attribute>
-            <option name="separator" value="," />
-          </Attribute>
-        </value>
-      </entry>
       <entry key="\data\temper.csv">
         <value>
           <Attribute>

DDPG.py

@@ -7,9 +7,22 @@ from environment import ESSEnv
 from tools import *

+def update_buffer(_trajectory):
+    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
+    ary_other = torch.as_tensor([item[1] for item in _trajectory])
+    ary_other[:, 0] = ary_other[:, 0] # ten_reward
+    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma # ten_mask = (1.0 - ary_done) * gamma
+    buffer.extend_buffer(ten_state, ary_other)
+    _steps = ten_state.shape[0]
+    _r_exp = ary_other[:, 0].mean() # other = (reward, mask, action)
+    return _steps, _r_exp
+
 if __name__ == '__main__':
     args = Arguments()
-    '''record real unbalance'''
+    '''here record real unbalance'''
     reward_record = {'episode': [], 'steps': [], 'mean_episode_reward': [], 'unbalance': []}
     loss_record = {'episode': [], 'steps': [], 'critic_loss': [], 'actor_loss': [], 'entropy_loss': []}
     args.visible_gpu = '0'
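The update_buffer helper added above flattens a collected trajectory into tensors and pushes it into the replay buffer; column 0 keeps the raw reward and column 1 overwrites the done flag with (1.0 - done) * gamma, so each stored "mask" already carries both the terminal cut-off and the discount. It also reads the script-level gamma and buffer variables, which is presumably why this commit copies it into DDPG.py, SAC.py and TD3.py instead of keeping it in tools.py. A minimal sketch (not part of this diff; critic_target and next_q are illustrative names, not taken from the repository) of how such a mask is typically consumed when forming the critic target:

import torch

def critic_target(reward, mask, next_q):
    # mask already equals (1.0 - done) * gamma, so terminal steps contribute
    # only their reward and non-terminal steps are gamma-discounted.
    return reward + mask * next_q

reward = torch.tensor([1.0, 0.5])
mask = torch.tensor([0.995, 0.0])   # second transition ends the episode
next_q = torch.tensor([10.0, 10.0])
print(critic_target(reward, mask, next_q))  # tensor([10.9500,  0.5000])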
@@ -19,6 +32,7 @@ if __name__ == '__main__':
     agent_name = f'{args.agent.__class__.__name__}'
     args.agent.cri_target = True
     args.env = ESSEnv()
+    # creat lists of lists/or creat a long list?
     args.init_before_training(if_main=True)
     '''init agent and environment'''
     agent = args.agent
@@ -43,7 +57,7 @@ if __name__ == '__main__':
     # args.save_network=False
     # args.test_network=False
     # args.save_test_data=False
-    # args.compare_with_gurobi=False
+    # args.compare_with_pyomo=False
     if args.train:
         collect_data = True
@@ -97,8 +111,8 @@ if __name__ == '__main__':
         with open(test_data_save_path, 'wb') as tf:
             pickle.dump(record, tf)
-    '''compare with gurobi data and results'''
-    if args.compare_with_gurobi:
+    '''compare with pyomo data and results'''
+    if args.compare_with_pyomo:
         month = record['init_info'][0][0]
         day = record['init_info'][0][1]
         initial_soc = record['init_info'][0][3]
@@ -112,7 +126,7 @@ if __name__ == '__main__':
         plot_dir = make_dir(args.cwd, plot_args.feature_change)
         plot_optimization_result(base_result, plot_dir)
         plot_evaluation_information(args.cwd + '/' + 'test_data.pkl', plot_dir)
-        '''compare the different cost get from gurobi and DDPG'''
+        '''compare the different cost get from pyomo and DDPG'''
         ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
         print('operation_cost_sum:', sum(eval_data['operation_cost']))
         print('step_cost_sum:', sum(base_result['step_cost']))
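The closing comparison block reports how far the learned policy is from the optimization benchmark: ration is the DRL operation-cost sum divided by the pyomo (formerly gurobi) step-cost sum, so, for example, a DRL cost of 525 against a benchmark cost of 500 gives ration = 1.05, i.e. 5% above the optimal schedule. The same block is repeated in PPO.py, SAC.py and TD3.py below.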

PPO.py

@@ -241,7 +241,7 @@ class Arguments:
         self.num_threads = 32 # cpu_num for evaluate model, torch.set_num_threads(self.num_threads)
         '''Arguments for training'''
-        self.num_episode = 1000 # to control the train episodes for PPO
+        self.num_episode = 2000 # to control the train episodes for PPO
         self.gamma = 0.995 # discount factor of future rewards
         self.learning_rate = 2 ** -14 # 2e-4
         self.soft_update_tau = 2 ** -8 # 2 ** -8 ~= 5e-3
@@ -261,7 +261,7 @@ class Arguments:
         self.save_network = True
         self.test_network = True
         self.save_test_data = True
-        self.compare_with_gurobi = True
+        self.compare_with_pyomo = True
         self.plot_on = True

     def init_before_training(self, if_main):
@@ -336,7 +336,7 @@ if __name__ == '__main__':
     # args.save_network=False
     # args.test_network=False
     # args.save_test_data=False
-    # args.compare_with_gurobi=False
+    # args.compare_with_pyomo=False
     if args.train:
         for i_episode in range(num_episode):
             with torch.no_grad():
@@ -379,8 +379,8 @@ if __name__ == '__main__':
         with open(test_data_save_path, 'wb') as tf:
             pickle.dump(record, tf)
-    '''compare with gurobi data and results'''
-    if args.compare_with_gurobi:
+    '''compare with pyomo data and results'''
+    if args.compare_with_pyomo:
         month = record['init_info'][0][0]
         day = record['init_info'][0][1]
         initial_soc = record['init_info'][0][3]
@@ -394,7 +394,7 @@ if __name__ == '__main__':
         plot_dir = make_dir(args.cwd, plot_args.feature_change)
         plot_optimization_result(base_result, plot_dir)
         plot_evaluation_information(args.cwd + '/' + 'test_data.pkl', plot_dir)
-        '''compare the different cost get from gurobi and PPO'''
+        '''compare the different cost get from pyomo and PPO'''
         ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
         print('operation_cost_sum:', sum(eval_data['operation_cost']))
         print('step_cost_sum:', sum(base_result['step_cost']))

SAC.py

@@ -7,6 +7,19 @@ from environment import ESSEnv
 from tools import *

+def update_buffer(_trajectory):
+    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
+    ary_other = torch.as_tensor([item[1] for item in _trajectory])
+    ary_other[:, 0] = ary_other[:, 0] # ten_reward
+    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma # ten_mask = (1.0 - ary_done) * gamma
+    buffer.extend_buffer(ten_state, ary_other)
+    _steps = ten_state.shape[0]
+    _r_exp = ary_other[:, 0].mean() # other = (reward, mask, action)
+    return _steps, _r_exp
+
 if __name__ == '__main__':
     args = Arguments()
     reward_record = {'episode': [], 'steps': [], 'mean_episode_reward': [], 'unbalance': []}
@@ -46,7 +59,7 @@ if __name__ == '__main__':
     # args.save_network=False
     # args.test_network=False
     # args.save_test_data=False
-    # args.compare_with_gurobi=False
+    # args.compare_with_pyomo=False
     #
     if args.train:
         collect_data = True
@@ -102,8 +115,8 @@ if __name__ == '__main__':
         with open(test_data_save_path, 'wb') as tf:
             pickle.dump(record, tf)
-    '''compare with gurobi data and results'''
-    if args.compare_with_gurobi:
+    '''compare with pyomo data and results'''
+    if args.compare_with_pyomo:
         month = record['init_info'][0][0]
         day = record['init_info'][0][1]
         initial_soc = record['init_info'][0][3]
@@ -117,7 +130,7 @@ if __name__ == '__main__':
         plot_dir = make_dir(args.cwd, plot_args.feature_change)
         plot_optimization_result(base_result, plot_dir)
         plot_evaluation_information(args.cwd + '/' + 'test_data.pkl', plot_dir)
-        '''compare the different cost get from gurobi and SAC'''
+        '''compare the different cost get from pyomo and SAC'''
         ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
         print('operation_cost_sum:', sum(eval_data['operation_cost']))
         print('step_cost_sum:', sum(base_result['step_cost']))

TD3.py

@@ -7,6 +7,19 @@ from environment import ESSEnv
 from tools import *

+def update_buffer(_trajectory):
+    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
+    ary_other = torch.as_tensor([item[1] for item in _trajectory])
+    ary_other[:, 0] = ary_other[:, 0] # ten_reward
+    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma # ten_mask = (1.0 - ary_done) * gamma
+    buffer.extend_buffer(ten_state, ary_other)
+    _steps = ten_state.shape[0]
+    _r_exp = ary_other[:, 0].mean() # other = (reward, mask, action)
+    return _steps, _r_exp
+
 if __name__ == '__main__':
     args = Arguments()
     reward_record = {'episode': [], 'steps': [], 'mean_episode_reward': [], 'unbalance': []}
@@ -43,7 +56,7 @@ if __name__ == '__main__':
     # args.save_network=False
     # args.test_network=False
     # args.save_test_data=False
-    # args.compare_with_gurobi=False
+    # args.compare_with_pyomo=False
     if args.train:
         collect_data = True
@@ -98,8 +111,8 @@ if __name__ == '__main__':
         with open(test_data_save_path, 'wb') as tf:
             pickle.dump(record, tf)
-    '''compare with gurobi data and results'''
-    if args.compare_with_gurobi:
+    '''compare with pyomo data and results'''
+    if args.compare_with_pyomo:
         month = record['init_info'][0][0]
         day = record['init_info'][0][1]
         initial_soc = record['init_info'][0][3]
@@ -113,7 +126,7 @@ if __name__ == '__main__':
         plot_dir = make_dir(args.cwd, plot_args.feature_change)
         plot_optimization_result(base_result, plot_dir)
         plot_evaluation_information(args.cwd + '/' + 'test_data.pkl', plot_dir)
-        '''compare the different cost get from gurobi and TD3'''
+        '''compare the different cost get from pyomo and TD3'''
         ration = sum(eval_data['operation_cost']) / sum(base_result['step_cost'])
         print('operation_cost_sum:', sum(eval_data['operation_cost']))
         print('step_cost_sum:', sum(base_result['step_cost']))


@@ -5,12 +5,15 @@ class Constant:
 class DataManager:
     def __init__(self) -> None:
+        self.Pv = []
         self.Prices = []
         self.Load_Consumption = []
         self.Temperature = []
         self.Irradiance = []
         self.Wind = []

+    def add_pv_element(self, element): self.Pv.append(element)
     def add_price_element(self, element): self.Prices.append(element)
     def add_load_element(self, element): self.Load_Consumption.append(element)
@@ -22,6 +25,9 @@ class DataManager:
     def add_wind_element(self, element): self.Wind.append(element)

     # get current time data based on given month day, and day_time
+    def get_pv_data(self, month, day, day_time):
+        return self.Pv[(sum(Constant.MONTHS_LEN[:month - 1]) + day - 1) * 24 + day_time]
     def get_price_data(self, month, day, day_time):
         return self.Prices[(sum(Constant.MONTHS_LEN[:month - 1]) + day - 1) * 24 + day_time]
@@ -38,6 +44,10 @@ class DataManager:
         return self.Wind[(sum(Constant.MONTHS_LEN[:month - 1]) + day - 1) * 24 + day_time]

     # get series data for one episode
+    def get_series_pv_data(self, month, day):
+        return self.Pv[(sum(Constant.MONTHS_LEN[:month - 1]) + day - 1) * 24:
+                       (sum(Constant.MONTHS_LEN[:month - 1]) + day - 1) * 24 + 24]
     def get_series_price_data(self, month, day):
         return self.Prices[(sum(Constant.MONTHS_LEN[:month - 1]) + day - 1) * 24:
                            (sum(Constant.MONTHS_LEN[:month - 1]) + day - 1) * 24 + 24]
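The new pv getters reuse the flat hourly index that the other series already use: with 24 values per day, the offset of (month, day, day_time) is the number of days elapsed before that day times 24, plus the hour. A short worked check, assuming Constant.MONTHS_LEN is the usual non-leap-year month-length list (the constant itself is not shown in this diff):

MONTHS_LEN = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]  # assumed Constant.MONTHS_LEN
month, day, day_time = 3, 2, 10                  # 2 March, 10:00
start = (sum(MONTHS_LEN[:month - 1]) + day - 1) * 24
print(start + day_time)   # 1450 -> index used by get_pv_data
print(start, start + 24)  # 1440 1464 -> slice returned by get_series_pv_data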


@@ -11,8 +11,8 @@ class ESSEnv(gym.Env):
     def __init__(self, **kwargs):
         super(ESSEnv, self).__init__()
         self.excess = None
-        self.shedding = None
         self.unbalance = None
+        self.shedding = None
         self.real_unbalance = None
         self.operation_cost = None
         self.current_output = None
@@ -113,7 +113,7 @@ class ESSEnv(gym.Env):
                 sell_benefit = self.grid.get_cost(price, unbalance) * self.sell_coefficient
             else:
                 sell_benefit = self.grid.get_cost(price, self.grid.exchange_ability) * self.sell_coefficient
-                # real unbalance that grid could not meet
+                # real unbalance that even grid could not meet
                 self.excess = unbalance - self.grid.exchange_ability
                 excess_penalty = self.excess * self.penalty_coefficient
         else:  # unbalance <0, its load shedding model, deficient penalty is used
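To make the branch above concrete with hypothetical numbers (the actual exchange_ability and penalty_coefficient values are not part of this hunk): if the grid can absorb 100 kW and the surplus unbalance is 120 kW, sell_benefit is computed on the 100 kW exchange_ability only, excess = 120 - 100 = 20 kW, and excess_penalty = 20 * penalty_coefficient; a surplus of 80 kW would instead fall into the first branch and be sold in full with no penalty.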
@@ -153,12 +153,14 @@
             # format(self.day, self.current_time, current_obs, next_obs, reward, finish))

     def _load_year_data(self):
+        # pv_df = pd.read_csv('data/pv.csv', sep=',')
         price_df = pd.read_csv('data/prices.csv', sep=';')
         load_df = pd.read_csv('data/houseload.csv', sep=',')
         irradiance_df = pd.read_csv('data/irradiance.csv', sep=',')
         temperature_df = pd.read_csv('data/temper.csv', sep=',')
         wind_df = pd.read_csv('data/wind.csv', sep=',')
+        # pv = pv_df['pv'].to_numpy(dtype=float)
         price = price_df['Price'].apply(lambda x: x.replace(',', '.')).to_numpy(dtype=float)
         load = load_df['houseload'].to_numpy(dtype=float)
         irradiance = irradiance_df['irradiance'].to_numpy(dtype=float)
@@ -171,6 +173,7 @@
             transformed_element = transform_function(element)
             add_function(transformed_element)

+        # process_elements(pv, lambda x: x, self.data_manager.add_pv_element)
         process_elements(price, lambda x: max(x / 10, 0.5), self.data_manager.add_price_element)
         process_elements(load, lambda x: x * 3, self.data_manager.add_load_element)
         process_elements(irradiance, lambda x: x, self.data_manager.add_irradiance_element)


@@ -78,9 +78,9 @@ class Solar:
         self.oc_voltage = parameters['V_oc0']
         self.s_resistance = parameters['R_s']
         self.sh_resistance = parameters['R_sh']
-        self.temper_coefficient = parameters['T_c']
-        self.opex_cofficient = parameters['O_c']
-        self.refer_irradiance = parameters['I_ref']
+        self.temper_coefficient = parameters['k_v']
+        self.opex_cofficient = parameters['k_o']
+        self.refer_irradiance = parameters['G_ref']
         self.refer_temperature = parameters['T_ref']

     def step(self, temperature, irradiance, action_voltage=0):
@@ -126,6 +126,7 @@ class Wind:
                                  self.power_coefficient * self.generator_efficiency) / 1e3
         else:
             self.current_power = 0
         return self.current_power

     def gen_cost(self, current_power):


@@ -2,13 +2,13 @@ import numpy as np
 solar_parameters = {
     'I_sc0': 8.0,  # short-circuit current under reference conditions (A)
-    'V_b': 25,  # base voltage
+    'V_b': 24,  # base voltage
     'V_oc0': 36.0,  # open-circuit voltage under reference conditions (V)
     'R_s': 0.1,  # series resistance (Ω)
     'R_sh': 100.0,  # shunt resistance (Ω)
-    'T_c': -0.2,  # temperature coefficient of open-circuit voltage (V/°C)
-    'O_c': 0.001,  # variable cost coefficient (yuan/kWh)
-    'I_ref': 1000,  # reference irradiance (W/m²)
+    'k_v': -0.2,  # temperature coefficient of open-circuit voltage (V/°C)
+    'k_o': 0.001,  # variable cost coefficient (yuan/kWh)
+    'G_ref': 1000,  # reference irradiance (W/m²)
     'T_ref': 25,  # reference temperature (°C)
 }
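The rename only changes the dictionary keys that module.py reads (see the Solar hunk above); the numeric values are unchanged apart from V_b. As a rough illustration of what these parameters conventionally feed — an assumption, since Solar.step itself is not part of this diff — the open-circuit voltage shifts by k_v per degree away from T_ref and the photocurrent scales with irradiance relative to G_ref:

params = {'I_sc0': 8.0, 'V_oc0': 36.0, 'k_v': -0.2, 'G_ref': 1000, 'T_ref': 25}

def corrected_operating_point(temperature, irradiance, p=params):
    i_sc = p['I_sc0'] * irradiance / p['G_ref']                # photocurrent scales with irradiance
    v_oc = p['V_oc0'] + p['k_v'] * (temperature - p['T_ref'])  # voltage drops as the cell heats up
    return i_sc, v_oc

print(corrected_operating_point(45, 800))  # (6.4, 32.0)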
@@ -35,15 +35,15 @@ battery_parameters = {
 }

 dg_parameters = {
-    'gen_1': {'a': 0.0034, 'b': 3, 'c': 30,
+    'gen_1': {'a': 0.0034, 'b': 3, 'c': 30, 'd': 0.03, 'e': 4.2, 'f': 0.031,
              'power_output_max': 150, 'power_output_min': 10,
              'ramping_up': 100, 'ramping_down': 100, 'min_up': 2, 'min_down': 1},
-    'gen_2': {'a': 0.001, 'b': 10, 'c': 40,
+    'gen_2': {'a': 0.001, 'b': 10, 'c': 40, 'd': 0.03, 'e': 4.2, 'f': 0.031,
              'power_output_max': 375, 'power_output_min': 50,
              'ramping_up': 100, 'ramping_down': 100, 'min_up': 2, 'min_down': 1},
-    'gen_3': {'a': 0.001, 'b': 15, 'c': 70,
+    'gen_3': {'a': 0.001, 'b': 15, 'c': 70, 'd': 0.03, 'e': 4.2, 'f': 0.031,
              'power_output_max': 500, 'power_output_min': 100,
              'ramping_up': 200, 'ramping_down': 200, 'min_up': 2, 'min_down': 1}
 }
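In the generator dictionaries, a, b and c are the usual quadratic fuel-cost coefficients (cost ≈ a·P² + b·P + c for output P while the unit is running); what the newly added d, e and f feed is not shown in this diff, so the sketch below only illustrates the conventional quadratic form with gen_1's values and leaves d/e/f uninterpreted:

gen_1 = {'a': 0.0034, 'b': 3, 'c': 30, 'd': 0.03, 'e': 4.2, 'f': 0.031}

def quadratic_cost(p_out, coeff):
    # conventional generator cost curve: a * P^2 + b * P + c
    return coeff['a'] * p_out ** 2 + coeff['b'] * p_out + coeff['c']

print(quadratic_cost(100, gen_1))  # 0.0034 * 100**2 + 3 * 100 + 30 = 364.0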


@@ -129,7 +129,7 @@ class Arguments:
         self.num_threads = 32 # cpu_num for evaluate model, torch.set_num_threads(self.num_threads)
         '''Arguments for training'''
-        self.num_episode = 1000
+        self.num_episode = 2000
         self.gamma = 0.995 # discount factor of future rewards
         # self.reward_scale = 1 # an approximate target reward usually be closed to 256
         self.learning_rate = 2 ** -14 # 2 ** -14 ~= 6e-5
@@ -152,7 +152,7 @@ class Arguments:
         self.save_network = True
         self.test_network = True
         self.save_test_data = True
-        self.compare_with_gurobi = True
+        self.compare_with_pyomo = True
         self.plot_on = True

     def init_before_training(self, if_main):
@@ -233,19 +233,6 @@ def get_episode_return(env, act, device):
     return episode_return, episode_unbalance

-def update_buffer(_trajectory):
-    ten_state = torch.as_tensor([item[0] for item in _trajectory], dtype=torch.float32)
-    ary_other = torch.as_tensor([item[1] for item in _trajectory])
-    ary_other[:, 0] = ary_other[:, 0] # ten_reward
-    ary_other[:, 1] = (1.0 - ary_other[:, 1]) * gamma # ten_mask = (1.0 - ary_done) * gamma
-    buffer.extend_buffer(ten_state, ary_other)
-    _steps = ten_state.shape[0]
-    _r_exp = ary_other[:, 0].mean() # other = (reward, mask, action)
-    return _steps, _r_exp
-
 class ReplayBuffer:
     def __init__(self, max_len, state_dim, action_dim, gpu_id=0):
         self.now_len = 0