ai-station-code/meijitancailiao/model_pred.py

107 lines
3.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import itertools
import os
from joblib import dump, load
import pandas as pd
# Registry of serialized regressors. Keys follow "<algorithm>_<target>" and map
# to joblib files given relative to this script's directory (callers join the
# value onto current_script_directory before loading).
# NOTE(review): "ssa"/"tpv" presumably denote the two predicted target
# quantities -- confirm the exact meaning with the model authors.
model_dict = {
    # --- models for the "ssa" target ---
    "adb_ssa": "model/SSA_ADB.joblib",
    "dtr_ssa": "model/SSA_DTR.joblib",
    "en_ssa": "model/SSA_ElasticNet.joblib",
    "gp_ssa": "model/SSA_GaussianProcessRegressor.joblib",
    "kn_ssa": "model/SSA_KNeighborsRegressor.joblib",
    "lasso_ssa": "model/SSA_Lasso.joblib",
    "lr_ssa": "model/SSA_LinearRegression.joblib",
    "rfr_ssa": "model/SSA_RFR.joblib",
    "ridge_ssa": "model/SSA_Ridge.joblib",
    "svr_ssa": "model/SSA_SVR.joblib",
    "xgb_ssa": "model/SSA_XGB.joblib",
    # --- models for the "tpv" target (note: "gdbt" here, no "dtr" entry) ---
    "adb_tpv": "model/TPV_ADB.joblib",
    "gdbt_tpv": "model/TPV_GDBT.joblib",
    "en_tpv": "model/TPV_ElasticNet.joblib",
    "gp_tpv": "model/TPV_GaussianProcessRegressor.joblib",
    "kn_tpv": "model/TPV_KNeighborsRegressor.joblib",
    "lasso_tpv": "model/TPV_Lasso.joblib",
    "lr_tpv": "model/TPV_LinearRegression.joblib",
    "rfr_tpv": "model/TPV_RFR.joblib",
    "ridge_tpv": "model/TPV_Ridge.joblib",
    "svr_tpv": "model/TPV_SVR.joblib",
    "xgb_tpv": "model/TPV_XGB.joblib",
}
# Order in which the SSA models are evaluated. ssa_mae and ssa_r2 below are
# positionally aligned with this list.
model_list_ssa = [
    'xgb_ssa', 'lr_ssa', 'ridge_ssa', 'gp_ssa', 'en_ssa',
    'kn_ssa', 'svr_ssa', 'dtr_ssa', 'rfr_ssa', 'adb_ssa',
]
# Order in which the TPV models are evaluated. tpv_mae and tpv_r2 below are
# positionally aligned with this list.
# The eighth entry is 'gdbt_tpv': model_dict registers no 'dtr_tpv' key, so the
# former 'dtr_tpv' entry made every model_dict lookup over this list raise
# KeyError.
model_list_tpv = [
    'xgb_tpv', 'lr_tpv', 'ridge_tpv', 'gp_tpv', 'en_tpv',
    'kn_tpv', 'svr_tpv', 'gdbt_tpv', 'rfr_tpv', 'adb_tpv',
]

# Hard-coded per-model scores for the SSA models (named as MAE and R2).
# Column order: xgb, lr, ridge, gp, en, kn, svr, dtr, rfr, adb
ssa_mae = [308, 407, 407, 282, 411, 389, 322, 288, 192, 329]
ssa_r2 = [0.92, 0.82, 0.82, 0.89, 0.81, 0.82, 0.87, 0.88, 0.95, 0.88]

# Hard-coded per-model scores for the TPV models (named as MAE and R2).
# Column order: xgb, lr, ridge, gp, en, kn, svr, gdbt, rfr, adb
tpv_mae = [0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.23, 0.23, 0.16, 0.21]
tpv_r2 = [0.81, 0.81, 0.81, 0.8, 0.82, 0.72, 0.78, 0.8, 0.85, 0.84]
# One hand-written sample, wrapped in a single-row DataFrame so it can be fed
# straight to a model's predict().
# NOTE(review): the column names presumably have to match the features the
# serialized models were trained on -- confirm before editing them.
test_content = pd.DataFrame(
    [
        {
            'A': 21.01,
            'VM': 8.17,
            'K/C': 4,
            'MM': 0,
            'AT': 700,
            'At': 2,
            'Rt': 5,
        }
    ]
)

# Absolute directory that contains this script; the relative paths stored in
# model_dict are joined onto it.
current_script_directory = os.path.dirname(os.path.abspath(__file__))
def pred_single_ssa(test_content):
    """Run every SSA model on ``test_content`` and collect one value per model.

    Each model named in ``model_list_ssa`` is loaded from disk on every call;
    nothing is cached between calls.

    Args:
        test_content: Input handed unchanged to each model's ``predict``.

    Returns:
        A list with one item per entry of ``model_list_ssa``, in that order:
        element 0 of the corresponding model's prediction output.
    """
    return [
        load(os.path.join(current_script_directory, model_dict[key]))
        .predict(test_content)[0]
        for key in model_list_ssa
    ]
def pred_single_tpv(test_content):
    """Run every TPV model on ``test_content`` and collect one value per model.

    Each model named in ``model_list_tpv`` is loaded from disk on every call;
    nothing is cached between calls.

    Args:
        test_content: Input handed unchanged to each model's ``predict``.

    Returns:
        A list with one item per entry of ``model_list_tpv``, in that order:
        element 0 of the corresponding model's prediction output.

    Raises:
        KeyError: If a name in ``model_list_tpv`` has no entry in
            ``model_dict``.
    """
    return [
        load(os.path.join(current_script_directory, model_dict[key]))
        .predict(test_content)[0]
        for key in model_list_tpv
    ]
def choose_model(name, data):
    """Load the model registered under ``name`` and predict on ``data``.

    Args:
        name: Key into ``model_dict`` selecting the serialized model.
        data: Input handed unchanged to the model's ``predict``.

    Returns:
        Whatever the loaded model's ``predict`` returns for ``data``.

    Raises:
        KeyError: If ``name`` is not a key of ``model_dict``.
    """
    serialized = os.path.join(current_script_directory, model_dict[name])
    return load(serialized).predict(data)
def get_excel():
    """Predict TPV for the rows of ``data/TPV.xlsx`` and print the result.

    Reads sheet ``sheet2`` of the workbook next to this script, discards its
    existing ``TPV`` column, predicts a replacement with the ``xgb_tpv`` model,
    and prints first the feature table's shape and then the table with the
    predicted ``TPV`` column attached. Returns nothing.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    workbook = os.path.join(base_dir, 'data/TPV.xlsx')
    # Drop the recorded target so only feature columns reach the model.
    features = pd.read_excel(workbook, sheet_name='sheet2').drop('TPV', axis=1)
    print(features.shape)
    features['TPV'] = choose_model('xgb_tpv', features)
    print(features)
if __name__ == "__main__":
    # Script entry point: run the spreadsheet-based TPV prediction.
    # Disabled manual checks, kept for reference:
    #   ssa_result = pred_single_ssa(test_content)
    #   result = [ssa_mae, ssa_r2, ssa_result]
    #   print(result)
    #   print(pred_single_tpv(test_content))
    get_excel()