107 lines
3.3 KiB
Python
107 lines
3.3 KiB
Python
import itertools
|
||
import os
|
||
from joblib import dump, load
|
||
import pandas as pd
|
||
|
||
|
||
# Registry of trained models: key "<algorithm>_<target>" -> joblib file path.
# Paths are relative; the prediction helpers join them onto
# current_script_directory before loading.
model_dict = dict(
    # --- SSA target ---
    adb_ssa="model/SSA_ADB.joblib",
    dtr_ssa="model/SSA_DTR.joblib",
    en_ssa="model/SSA_ElasticNet.joblib",
    gp_ssa="model/SSA_GaussianProcessRegressor.joblib",
    kn_ssa="model/SSA_KNeighborsRegressor.joblib",
    lasso_ssa="model/SSA_Lasso.joblib",
    lr_ssa="model/SSA_LinearRegression.joblib",
    rfr_ssa="model/SSA_RFR.joblib",
    ridge_ssa="model/SSA_Ridge.joblib",
    svr_ssa="model/SSA_SVR.joblib",
    xgb_ssa="model/SSA_XGB.joblib",
    # --- TPV target ---
    adb_tpv="model/TPV_ADB.joblib",
    gdbt_tpv="model/TPV_GDBT.joblib",
    en_tpv="model/TPV_ElasticNet.joblib",
    gp_tpv="model/TPV_GaussianProcessRegressor.joblib",
    kn_tpv="model/TPV_KNeighborsRegressor.joblib",
    lasso_tpv="model/TPV_Lasso.joblib",
    lr_tpv="model/TPV_LinearRegression.joblib",
    rfr_tpv="model/TPV_RFR.joblib",
    ridge_tpv="model/TPV_Ridge.joblib",
    svr_tpv="model/TPV_SVR.joblib",
    xgb_tpv="model/TPV_XGB.joblib",
)
|
||
|
||
# Order in which pred_single_ssa runs the SSA models. ssa_mae / ssa_r2 below
# are aligned with this list by position (index i describes model i).
model_list_ssa = [
    'xgb_ssa', 'lr_ssa', 'ridge_ssa', 'gp_ssa', 'en_ssa',
    'kn_ssa', 'svr_ssa', 'dtr_ssa', 'rfr_ssa', 'adb_ssa',
]

# Order in which pred_single_tpv runs the TPV models. model_dict has no
# 'dtr_tpv' entry (looking it up raised KeyError); the TPV slot after 'svr'
# is the gradient-boosting model 'gdbt_tpv', matching the metric comment below.
model_list_tpv = [
    'xgb_tpv', 'lr_tpv', 'ridge_tpv', 'gp_tpv', 'en_tpv',
    'kn_tpv', 'svr_tpv', 'gdbt_tpv', 'rfr_tpv', 'adb_tpv',
]

# Per-model scores for SSA, in model_list_ssa order
# (xgb, lr, ridge, gp, en, kn, svr, dtr, rfr, adb).
# Presumably MAE and R^2 on a held-out set, going by the names; units unverified.
ssa_mae = [308, 407, 407, 282, 411, 389, 322, 288, 192, 329]
ssa_r2 = [0.92, 0.82, 0.82, 0.89, 0.81, 0.82, 0.87, 0.88, 0.95, 0.88]

# Per-model scores for TPV, in model_list_tpv order
# (xgb, lr, ridge, gp, en, kn, svr, gdbt, rfr, adb).
tpv_mae = [0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.23, 0.23, 0.16, 0.21]
tpv_r2 = [0.81, 0.81, 0.81, 0.8, 0.82, 0.72, 0.78, 0.8, 0.85, 0.84]
|
||
|
||
|
||
|
||
|
||
# One hand-entered sample row (1 x 7 DataFrame) for ad-hoc predictions.
# NOTE(review): column names/order presumably must match the training
# features of the saved models — confirm against the training data.
test_content = pd.DataFrame(
    {
        'A': [21.01],
        'VM': [8.17],
        'K/C': [4],
        'MM': [0],
        'AT': [700],
        'At': [2],
        'Rt': [5],
    }
)
|
||
# Absolute directory of this file. Model and data paths are resolved against
# it, so lookups do not depend on the process's working directory.
current_script_directory = os.path.split(os.path.abspath(__file__))[0]
|
||
|
||
def pred_single_ssa(test_content):
    """Run every SSA model on ``test_content`` and collect one prediction each.

    Each model named in ``model_list_ssa`` is loaded from disk (via
    ``model_dict``) on every call, then asked to predict ``test_content``;
    only the first element of each prediction is kept.

    Args:
        test_content: feature rows passed straight to ``model.predict``.

    Returns:
        list: first prediction of each model, in ``model_list_ssa`` order.
    """
    def _first_prediction(key):
        # Resolve the relative joblib path against the script directory.
        estimator = load(os.path.join(current_script_directory, model_dict[key]))
        return estimator.predict(test_content)[0]

    return [_first_prediction(key) for key in model_list_ssa]
|
||
|
||
|
||
|
||
def pred_single_tpv(test_content):
    """Run every TPV model on ``test_content`` and collect one prediction each.

    Each model named in ``model_list_tpv`` is loaded from disk (via
    ``model_dict``) on every call, then asked to predict ``test_content``;
    only the first element of each prediction is kept.

    Args:
        test_content: feature rows passed straight to ``model.predict``.

    Returns:
        list: first prediction of each model, in ``model_list_tpv`` order.

    Raises:
        KeyError: if a name in ``model_list_tpv`` has no entry in ``model_dict``.
    """
    def _first_prediction(key):
        # Resolve the relative joblib path against the script directory.
        estimator = load(os.path.join(current_script_directory, model_dict[key]))
        return estimator.predict(test_content)[0]

    return [_first_prediction(key) for key in model_list_tpv]
|
||
|
||
|
||
def choose_model(name, data):
    """Load the model registered as ``name`` and return its predictions for ``data``.

    Args:
        name: key into ``model_dict`` (e.g. ``'xgb_tpv'``).
        data: feature rows passed straight to ``model.predict``.

    Returns:
        Whatever the loaded model's ``predict`` returns (all rows, not just the first).

    Raises:
        KeyError: if ``name`` is not a key of ``model_dict``.
    """
    estimator = load(os.path.join(current_script_directory, model_dict[name]))
    return estimator.predict(data)
|
||
|
||
|
||
def get_excel(data_file='data/TPV.xlsx', sheet_name='sheet2',
              model_name='xgb_tpv', target='TPV'):
    """Predict ``target`` for every row of an Excel sheet, print and return it.

    The sheet is read from ``data_file`` (relative to this script's
    directory). Its existing ``target`` column is dropped, and the remaining
    columns are fed to the model ``model_name`` through ``choose_model``. The
    predictions are then stored back into the ``target`` column. The defaults
    reproduce the original hard-coded behavior.

    Args:
        data_file: Excel workbook path, relative to ``current_script_directory``.
        sheet_name: sheet to read from the workbook.
        model_name: key into ``model_dict`` of the model used for prediction.
        target: name of the column to drop before predicting and to fill
            with the predictions afterwards.

    Returns:
        pandas.DataFrame: the sheet's feature columns plus the predicted
        ``target`` column. It is also printed, as before.

    Raises:
        KeyError: if ``target`` is not a column of the sheet (from
            ``DataFrame.drop``) or ``model_name`` is not in ``model_dict``.
    """
    # Use the module-level directory rather than recomputing (and shadowing) it.
    data_path = os.path.join(current_script_directory, data_file)
    test_data = pd.read_excel(data_path, sheet_name=sheet_name)
    # Remove the target so the model only sees feature columns.
    test_data = test_data.drop(target, axis=1)
    print(test_data.shape)
    # The right-hand side is evaluated before assignment, so prediction still
    # runs on the feature-only frame.
    test_data[target] = choose_model(model_name, test_data)
    print(test_data)
    return test_data
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Default entry point: batch-predict TPV for the Excel sheet.
    # For the single hand-entered sample, use pred_single_ssa(test_content)
    # or pred_single_tpv(test_content) instead.
    get_excel()
|
||
|