123 lines
5.2 KiB
Python
123 lines
5.2 KiB
Python
import itertools
|
|
import os
|
|
from joblib import dump, load
|
|
import pandas as pd
|
|
|
|
def float_range(start, end, step):
    """Yield evenly spaced values from ``start`` up to and including ``end``.

    Each value is computed as ``start + index * step`` instead of by
    repeatedly adding ``step``, so rounding error does not accumulate and the
    end point is not silently dropped (``float_range(0.1, 0.3, 0.1)`` yields
    0.1, 0.2 and 0.3).

    Args:
        start: First value of the sequence.
        end: Inclusive upper bound of the sequence.
        step: Positive increment between consecutive values.

    Yields:
        The sequence values, rounded to 10 decimal places.

    Raises:
        ValueError: If ``start <= end`` and ``step`` is not positive, because
            such a sequence would never terminate.
    """
    if start > end:
        # Nothing to yield when the range is empty.
        return
    if step <= 0:
        raise ValueError("step must be positive, got {!r}".format(step))

    # The small tolerance keeps an end point whose quotient lands a hair
    # below a whole number because of binary floating-point representation.
    count = int((end - start) / step + 1e-9) + 1
    for index in range(count):
        # Rounding hides representation noise such as 0.30000000000000004.
        yield round(start + index * step, 10)
|
|
|
|
|
|
# Default (min, max, step) for each grid variable, keyed by parameter prefix.
_DEFAULT_RANGES = {
    'A': (4, 48, 4),        # ash content
    'VM': (5, 50, 5),       # volatile matter
    'KC': (1, 4, 0.5),      # activating-agent ratio
    'MM': (0, 1, 1),        # mixing method
    'AT': (600, 900, 50),   # activation temperature
    'At': (0.5, 2, 0.5),    # activation time
    'Rt': (5, 10, 5),       # heating rate
}


def get_params(params):
    """Fill unset range parameters with their default values.

    For every variable prefix (A, VM, KC, MM, AT, At, Rt) the keys
    ``<prefix>_min``, ``<prefix>_max`` and ``<prefix>_step`` are examined.
    A key that is missing or holds ``None`` receives its default; any other
    value, including 0, is kept unchanged.

    Args:
        params: Dict of range settings. It is updated in place.

    Returns:
        The same ``params`` dict, with every range key populated.
    """
    for prefix, defaults in _DEFAULT_RANGES.items():
        for suffix, default in zip(('min', 'max', 'step'), defaults):
            key = prefix + '_' + suffix
            # Compare against None so that a legitimate 0 is not overwritten.
            if params.get(key) is None:
                params[key] = default
    return params
|
|
|
|
def create_pred_data(params):
    """Build the full grid of candidate input rows from the range settings.

    Args:
        params: Dict holding ``<prefix>_min``, ``<prefix>_max`` and
            ``<prefix>_step`` for the prefixes A, VM, KC, MM, AT, At and Rt.

    Returns:
        A list of 7-tuples ``(A, VM, KC, MM, AT, At, Rt)`` covering every
        combination of the per-variable sequences.
    """
    def axis_values(prefix, fractional):
        # Fractional axes use float_range (inclusive end); the others use the
        # built-in range with the end shifted by one to make it inclusive.
        if fractional:
            return list(float_range(params[prefix + '_min'],
                                    params[prefix + '_max'],
                                    params[prefix + '_step']))
        return list(range(params[prefix + '_min'],
                          params[prefix + '_max'] + 1,
                          params[prefix + '_step']))

    # Axis order fixes the position of each variable inside the tuples.
    axis_layout = (
        ('A', False),
        ('VM', False),
        ('KC', True),
        ('MM', False),
        ('AT', False),
        ('At', True),
        ('Rt', False),
    )
    axes = [axis_values(prefix, fractional) for prefix, fractional in axis_layout]
    return list(itertools.product(*axes))
|
|
|
|
# (short name, file tag) for every available model pair, in lookup order.
_MODEL_FILE_TAGS = (
    ("adb", "ADB"),
    ("dtr", "DTR"),
    ("en", "ElasticNet"),
    ("gp", "GaussianProcessRegressor"),
    ("kn", "KNeighborsRegressor"),
    ("lasso", "Lasso"),
    ("lr", "LinearRegression"),
    ("rfr", "RFR"),
    ("ridge", "Ridge"),
    ("svr", "SVR"),
    ("xgb", "XGB"),
)

# Maps each model short name to [SSA model path, TPV model path].
model_dict = {
    short_name: [
        "model/SSA_" + file_tag + ".joblib",
        "model/TPV_" + file_tag + ".joblib",
    ]
    for short_name, file_tag in _MODEL_FILE_TAGS
}
|
|
|
|
|
|
def pred_func(func_name, pred_data):
    """Predict SSA and TPV for every input row with one saved model pair.

    Args:
        func_name: Key of ``model_dict`` selecting the pair of model files.
        pred_data: Row data convertible to a seven-column DataFrame, with
            columns in the order A, VM, K/C, MM, AT, At, Rt.

    Returns:
        A DataFrame holding the seven input columns followed by the "SSA"
        and "TPV" prediction columns.

    Raises:
        KeyError: If ``func_name`` is not a key of ``model_dict``.
    """
    # Model files are resolved relative to this script, not the working
    # directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_files = model_dict[func_name]
    ssa_path = os.path.join(script_dir, model_files[0])
    tpv_path = os.path.join(script_dir, model_files[1])

    # NOTE(review): joblib's load unpickles the file contents, so only
    # trusted model files should be placed in the model directory.
    ssa_model = load(ssa_path)
    tpv_model = load(tpv_path)

    # Predict with both models on the same named feature frame.
    features = pd.DataFrame(
        pred_data,
        columns=['A', 'VM', 'K/C', 'MM', 'AT', 'At', 'Rt'],
    )
    predictions = pd.DataFrame({
        "SSA": ssa_model.predict(features),
        "TPV": tpv_model.predict(features),
    })

    # Place the predictions next to the inputs that produced them.
    return pd.concat([features, predictions], axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Leave every range unset so that get_params supplies its defaults.
    params = {
        prefix + "_" + suffix: None
        for prefix in ("A", "VM", "KC", "MM", "AT", "At", "Rt")
        for suffix in ("min", "max", "step")
    }
    params = get_params(params)
    pred_data = create_pred_data(params)
    result = pred_func("ridge", pred_data)

    # Best candidates first: highest SSA, ties broken by highest TPV.
    sorted_result = result.sort_values(by=['SSA', 'TPV'], ascending=[False, False])

    # Number of rows to keep; None means no row limit.
    num = 6
    if num is None:
        # head() with no argument would cut the output to 5 rows, so the
        # frame is printed directly when no limit is requested.
        print(sorted_result)
    else:
        # Honour the configured count instead of a hard-coded literal.
        print(sorted_result.head(num))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|