shanxi_test/参数调优rejie2.ipynb

1262 lines
45 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 29,
"id": "35268e29-6824-4e2d-983b-7e8259bd343f",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from hyperopt import hp, fmin, tpe, STATUS_OK, Trials\n",
"from sklearn.model_selection import train_test_split\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "30a72889-e9da-45d3-bb03-cbe1b908a972",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>A</th>\n",
" <th>V</th>\n",
" <th>FC</th>\n",
" <th>C</th>\n",
" <th>H</th>\n",
" <th>N</th>\n",
" <th>S</th>\n",
" <th>O</th>\n",
" <th>H/C</th>\n",
" <th>O/C</th>\n",
" <th>N/C</th>\n",
" <th>Rt</th>\n",
" <th>Hr</th>\n",
" <th>dp</th>\n",
" <th>T</th>\n",
" <th>Tar</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>25.56</td>\n",
" <td>45.32</td>\n",
" <td>54.68</td>\n",
" <td>69.13</td>\n",
" <td>4.91</td>\n",
" <td>1.13</td>\n",
" <td>0.81</td>\n",
" <td>24.02</td>\n",
" <td>0.852307</td>\n",
" <td>0.260596</td>\n",
" <td>0.014011</td>\n",
" <td>30.0</td>\n",
" <td>10.0</td>\n",
" <td>0.2</td>\n",
" <td>600</td>\n",
" <td>3.974958</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>15.26</td>\n",
" <td>45.11</td>\n",
" <td>54.89</td>\n",
" <td>63.80</td>\n",
" <td>5.11</td>\n",
" <td>1.56</td>\n",
" <td>0.88</td>\n",
" <td>28.65</td>\n",
" <td>0.961129</td>\n",
" <td>0.336795</td>\n",
" <td>0.020958</td>\n",
" <td>30.0</td>\n",
" <td>10.0</td>\n",
" <td>0.2</td>\n",
" <td>600</td>\n",
" <td>4.629865</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>9.92</td>\n",
" <td>35.16</td>\n",
" <td>64.84</td>\n",
" <td>80.25</td>\n",
" <td>4.66</td>\n",
" <td>0.91</td>\n",
" <td>0.56</td>\n",
" <td>13.62</td>\n",
" <td>0.696822</td>\n",
" <td>0.127290</td>\n",
" <td>0.009720</td>\n",
" <td>20.0</td>\n",
" <td>20.0</td>\n",
" <td>6.0</td>\n",
" <td>400</td>\n",
" <td>6.452928</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>9.92</td>\n",
" <td>35.16</td>\n",
" <td>64.84</td>\n",
" <td>80.25</td>\n",
" <td>4.66</td>\n",
" <td>0.91</td>\n",
" <td>0.56</td>\n",
" <td>13.62</td>\n",
" <td>0.696822</td>\n",
" <td>0.127290</td>\n",
" <td>0.009720</td>\n",
" <td>20.0</td>\n",
" <td>20.0</td>\n",
" <td>6.0</td>\n",
" <td>450</td>\n",
" <td>8.724672</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>9.92</td>\n",
" <td>35.16</td>\n",
" <td>64.84</td>\n",
" <td>80.25</td>\n",
" <td>4.66</td>\n",
" <td>0.91</td>\n",
" <td>0.56</td>\n",
" <td>13.62</td>\n",
" <td>0.696822</td>\n",
" <td>0.127290</td>\n",
" <td>0.009720</td>\n",
" <td>20.0</td>\n",
" <td>20.0</td>\n",
" <td>6.0</td>\n",
" <td>500</td>\n",
" <td>10.075968</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" A V FC C H N S O H/C O/C \\\n",
"0 25.56 45.32 54.68 69.13 4.91 1.13 0.81 24.02 0.852307 0.260596 \n",
"1 15.26 45.11 54.89 63.80 5.11 1.56 0.88 28.65 0.961129 0.336795 \n",
"2 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n",
"3 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n",
"4 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n",
"\n",
" N/C Rt Hr dp T Tar \n",
"0 0.014011 30.0 10.0 0.2 600 3.974958 \n",
"1 0.020958 30.0 10.0 0.2 600 4.629865 \n",
"2 0.009720 20.0 20.0 6.0 400 6.452928 \n",
"3 0.009720 20.0 20.0 6.0 450 8.724672 \n",
"4 0.009720 20.0 20.0 6.0 500 10.075968 "
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Read the Excel workbook into a DataFrame and preview its first five rows.\n",
"data = pd.read_excel(io='/mnt/tanzk/mxx/rejie2.xlsx')\n",
"data.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "652c035d-937c-4b6b-bbe3-b5ed9ea2ca92",
"metadata": {},
"outputs": [],
"source": [
"out_cols = ['Tar']"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "9f00267b-0cf5-4cfc-97ef-8c15cd8d262d",
"metadata": {},
"outputs": [],
"source": [
"feature_cols = [x for x in data.columns if x not in out_cols]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "676d77b7-605b-4a2a-acc0-8d625324e1c2",
"metadata": {},
"outputs": [],
"source": [
"train_data = data.reset_index(drop=True)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "7bc4dd3d-3285-4bea-bdd2-c55816674722",
"metadata": {},
"outputs": [],
"source": [
"import xgboost as xgb\n",
"\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error, r2_score"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "52f19cb8-424e-47eb-a68d-f107a5d194bd",
"metadata": {},
"outputs": [],
"source": [
"# Hyperopt search space for the XGBoost regressor.\n",
"# hp.loguniform(label, low, high) samples exp(uniform(low, high)), so the bounds\n",
"# passed below are exponents, not the parameter values themselves.\n",
"space = {\n",
"    # Learning rate in [e^-4, e^0.5] ~= [0.018, 1.65].\n",
"    # NOTE(review): XGBoost documents eta as ranging over [0, 1]; consider an upper exponent of 0.\n",
"    'eta': hp.loguniform('eta', -4, 0.5),\n",
"    # Maximum tree depth, one of 5..49. fmin() reports hp.choice parameters as the\n",
"    # index of the chosen option, so its raw result has to be mapped back with\n",
"    # hyperopt.space_eval before it is used as a depth.\n",
"    'max_depth': hp.choice('max_depth', range(5, 50)),\n",
"    # Minimum sum of instance weights required in a child, uniform on [0, 20].\n",
"    'min_child_weight': hp.uniform('min_child_weight', 0, 20),\n",
"    # Minimum loss reduction needed to split a leaf, in [e^-3, e^0.7] ~= [0.05, 2.01].\n",
"    'gamma': hp.loguniform('gamma', -3, 0.7),\n",
"    # Row subsampling ratio, uniform on [0.5, 0.9].\n",
"    'subsample': hp.uniform('subsample', 0.5, 0.9),\n",
"    # Column subsampling ratio per tree, uniform on [0.5, 0.9].\n",
"    'colsample_bytree': hp.uniform('colsample_bytree', 0.5, 0.9),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "b4209d03-b608-4520-9e40-79e93b358c05",
"metadata": {},
"outputs": [],
"source": [
"# Hold out 10% of the rows as a test set; the fixed seed keeps the split repeatable.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    train_data[feature_cols],\n",
"    train_data[out_cols],\n",
"    random_state=42,\n",
"    test_size=0.1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "e0e3e9fd-f85d-4370-8b7d-129556624721",
"metadata": {},
"outputs": [],
"source": [
"# Objective function scored by hyperopt for each sampled parameter set.\n",
"def objective(params):\n",
"    \"\"\"Return the loss hyperopt should minimise for one candidate.\n",
"\n",
"    Args:\n",
"        params: dict of XGBRegressor keyword arguments sampled from the search space.\n",
"\n",
"    Returns:\n",
"        dict with 'loss' (mean absolute error averaged over 5 cross-validation\n",
"        folds of the training split) and 'status' (STATUS_OK).\n",
"    \"\"\"\n",
"    # Imported here so the cell does not rely on imports made later in the notebook.\n",
"    from sklearn.model_selection import KFold, cross_val_score\n",
"\n",
"    gbr = xgb.XGBRegressor(**params)\n",
"    # Score on folds of the training split only. Selecting hyperparameters by their\n",
"    # error on X_test would leak the hold-out set into the model choice and make\n",
"    # every later metric reported on X_test optimistic.\n",
"    folds = KFold(n_splits=5, shuffle=True, random_state=42)\n",
"    neg_mae = cross_val_score(gbr, X_train, np.ravel(y_train),\n",
"                              scoring='neg_mean_absolute_error', cv=folds)\n",
"    # The scorer returns negated MAE (higher is better); flip the sign for hyperopt.\n",
"    return {'loss': float(-neg_mae.mean()), 'status': STATUS_OK}"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "cc20c5e9-072f-4f6c-b74e-6042c54408ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [00:06<00:00, 15.37trial/s, best loss: 0.927005504531766]\n"
]
}
],
"source": [
"# Local import: space_eval is not part of the notebook's top-level hyperopt import.\n",
"from hyperopt import space_eval\n",
"\n",
"# Trials records every evaluated candidate together with its loss.\n",
"trials = Trials()\n",
"\n",
"# fmin returns the best point in hyperopt's internal encoding: for hp.choice\n",
"# parameters (max_depth here) that is the index of the chosen option, not its value.\n",
"best_raw = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=100, trials=trials)\n",
"\n",
"# Map the raw point back through the search space so best_params holds real\n",
"# values (e.g. an actual tree depth) that can be handed to XGBoost.\n",
"best_params = space_eval(space, best_raw)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "498cff26-e385-48ac-82ed-5057ccbc370c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'colsample_bytree': 0.7591178772740766, 'eta': 0.29006097172943296, 'gamma': 0.32020608660889016, 'max_depth': 21, 'min_child_weight': 5.912424330716954, 'subsample': 0.8011115810485918}\n"
]
}
],
"source": [
"print(best_params)\n"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "5eeb907e-cdc1-434a-82d7-3b967287bbb3",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import KFold, train_test_split\n",
"kf = KFold(n_splits=10, shuffle=True, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "e629cc0e-a9f8-44d7-9958-092a34a5e579",
"metadata": {},
"outputs": [],
"source": [
"num_boost_round = 1000"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "8ddd834e-22b8-42dc-8387-b9983b52a71b",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "6e1074cd-88d1-41aa-8d6a-b11b3ae705e0",
"metadata": {},
"outputs": [],
"source": [
"# Matplotlib text settings: choose the sans-serif font and make minus signs render correctly.\n",
"plt.rcParams.update({\n",
"    'font.sans-serif': ['WenQuanYi Micro Hei'],\n",
"    'axes.unicode_minus': False,\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "70490cbb-f9cf-4b3c-b476-37f3721d50e8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE: 2.336, RMSE: 1.5284, MAE: 1.0563, MAPE: 17.41 %, R_2: 0.7993\n",
"MSE: 2.1045, RMSE: 1.4507, MAE: 0.9808, MAPE: 14.64 %, R_2: 0.7211\n",
"MSE: 1.7837, RMSE: 1.3356, MAE: 0.9841, MAPE: 12.31 %, R_2: 0.7866\n",
"MSE: 3.7098, RMSE: 1.9261, MAE: 1.0768, MAPE: 21.11 %, R_2: 0.7009\n",
"MSE: 1.5407, RMSE: 1.2412, MAE: 0.7912, MAPE: 13.12 %, R_2: 0.8225\n",
"MSE: 2.5428, RMSE: 1.5946, MAE: 1.0797, MAPE: 15.42 %, R_2: 0.5668\n",
"MSE: 1.6981, RMSE: 1.3031, MAE: 0.9132, MAPE: 13.13 %, R_2: 0.7728\n",
"MSE: 1.6532, RMSE: 1.2858, MAE: 0.8018, MAPE: 9.68 %, R_2: 0.7663\n",
"MSE: 2.2342, RMSE: 1.4947, MAE: 1.1072, MAPE: 16.14 %, R_2: 0.7479\n",
"MSE: 1.6337, RMSE: 1.2782, MAE: 0.9285, MAPE: 12.27 %, R_2: 0.7293\n"
]
}
],
"source": [
"# K-fold evaluation of best_params: one row of metrics is collected per fold.\n",
"# NOTE: the loop rebinds the notebook-level X_train that objective() reads; re-run\n",
"# the train/test split cell before re-running the hyperparameter search.\n",
"eva_list = []\n",
"eva_cols = ['MSE', 'RMSE', 'MAE', 'MAPE', 'R2']\n",
"for train_index, test_index in kf.split(train_data):\n",
"    # KFold yields positional indices, so rows are selected by position (.iloc);\n",
"    # this stays correct even when the frame's index labels are not 0..n-1.\n",
"    train = train_data.iloc[train_index]\n",
"    valid = train_data.iloc[test_index]\n",
"    X_train, Y_train = train[feature_cols], train[out_cols]\n",
"    X_valid, Y_valid = valid[feature_cols], valid[out_cols]\n",
"    dtrain = xgb.DMatrix(X_train, Y_train)\n",
"    dvalid = xgb.DMatrix(X_valid, Y_valid)\n",
"    watchlist = [(dvalid, 'eval')]\n",
"    # NOTE(review): early stopping monitors the same fold that is scored below, so\n",
"    # the fold metrics are not fully independent of the training procedure.\n",
"    gb_model = xgb.train(best_params, dtrain, num_boost_round, evals=watchlist,\n",
"                         early_stopping_rounds=100, verbose_eval=False)\n",
"    # Reuse the validation DMatrix for prediction instead of building a second one.\n",
"    y_pred = gb_model.predict(dvalid)\n",
"    y_true = Y_valid.values\n",
"    MSE = mean_squared_error(y_true, y_pred)\n",
"    RMSE = np.sqrt(MSE)  # derived from the MSE computed above\n",
"    MAE = mean_absolute_error(y_true, y_pred)\n",
"    MAPE = mean_absolute_percentage_error(y_true, y_pred)\n",
"    # A negative R^2 means the fit is worse than always predicting the mean.\n",
"    R_2 = r2_score(y_true, y_pred)\n",
"    print(f'MSE: {round(MSE, 4)!s}, RMSE: {round(RMSE, 4)!s}, MAE: {round(MAE, 4)!s}, '\n",
"          f'MAPE: {round(MAPE * 100, 2)!s} %, R_2: {round(R_2, 4)!s}')\n",
"    eva_list.append([MSE, RMSE, MAE, MAPE, R_2])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "10886d74-7543-49a6-b4f1-a31a3110fb99",
"metadata": {},
"outputs": [],
"source": [
"eva_df = pd.DataFrame.from_records(eva_list, columns=eva_cols)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "87718c37-27a4-44b4-b5ec-fbc6fdbc14fe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MSE 2.123660\n",
"RMSE 1.443832\n",
"MAE 0.971950\n",
"MAPE 0.145207\n",
"R2 0.741365\n",
"dtype: float64"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eva_df.mean()"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "ef6508a2-ca8a-4fd1-8a8b-0e8a91d01d87",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1.861425600429374,\n",
" 1.3643407200656932,\n",
" 0.9369589057201204,\n",
" 0.14918013653153422,\n",
" 0.8400893043071764)"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hard-coded hyperparameters; they equal the search result printed earlier in this notebook.\n",
"# NOTE(review): if max_depth was copied from raw fmin() output it is an index into\n",
"# range(5, 50) rather than a depth (hyperopt reports hp.choice picks by index) -- confirm.\n",
"best_params = dict(\n",
"    colsample_bytree=0.7591178772740766,\n",
"    eta=0.29006097172943296,\n",
"    gamma=0.32020608660889016,\n",
"    max_depth=21,\n",
"    min_child_weight=5.912424330716954,\n",
"    subsample=0.8011115810485918,\n",
")\n",
"\n",
"# Refit on the 90% training split with 1500 boosting rounds.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42\n",
")\n",
"gbr = xgb.XGBRegressor(n_estimators=1500, **best_params)\n",
"gbr.fit(X_train, y_train)\n",
"\n",
"# Score the held-out 10% split.\n",
"y_pred = gbr.predict(X_test)\n",
"mse, mae, mape, r2 = (\n",
"    metric(y_test, y_pred)\n",
"    for metric in (mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score)\n",
")\n",
"rmse = np.sqrt(mse)\n",
"\n",
"mse, rmse, mae, mape, r2\n"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "f74ccd89-e1b0-44e8-913f-6b33e8115250",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1.78639546955746,\n",
" 1.3365610609162082,\n",
" 0.9437850044721565,\n",
" 0.1579769923858331,\n",
" 0.846534966396966)"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hard-coded hyperparameters; presumably pasted from a separate hyperopt run -- confirm.\n",
"# NOTE(review): if max_depth was copied from raw fmin() output it is an index into\n",
"# range(5, 50) rather than a depth (hyperopt reports hp.choice picks by index) -- confirm.\n",
"best_params = dict(\n",
"    colsample_bytree=0.8182688124328266,\n",
"    eta=0.39669872117044186,\n",
"    gamma=0.67893237292294242,\n",
"    max_depth=23,\n",
"    min_child_weight=7.274037788798998,\n",
"    subsample=0.6957233806783182,\n",
")\n",
"\n",
"# Refit on the 90% training split with 1500 boosting rounds.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42\n",
")\n",
"gbr = xgb.XGBRegressor(n_estimators=1500, **best_params)\n",
"gbr.fit(X_train, y_train)\n",
"\n",
"# Score the held-out 10% split.\n",
"y_pred = gbr.predict(X_test)\n",
"mse, mae, mape, r2 = (\n",
"    metric(y_test, y_pred)\n",
"    for metric in (mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score)\n",
")\n",
"rmse = np.sqrt(mse)\n",
"\n",
"mse, rmse, mae, mape, r2"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "8e223a2f-676b-4255-bd7e-b6ea8857a7e2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1.5784385533129173,\n",
" 1.256359245324727,\n",
" 0.8012875783803378,\n",
" 0.12426418160928662,\n",
" 0.8644000560052363)"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hard-coded hyperparameters; presumably pasted from a separate hyperopt run -- confirm.\n",
"# NOTE(review): if max_depth was copied from raw fmin() output it is an index into\n",
"# range(5, 50) rather than a depth (hyperopt reports hp.choice picks by index) -- confirm.\n",
"best_params = dict(\n",
"    colsample_bytree=0.7756001484050402,\n",
"    eta=0.31927224345318256,\n",
"    gamma=0.5049174573053737,\n",
"    max_depth=42,\n",
"    min_child_weight=6.449650970113468,\n",
"    subsample=0.7873416063207794,\n",
")\n",
"\n",
"# Refit on the 90% training split with 1500 boosting rounds.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42\n",
")\n",
"gbr = xgb.XGBRegressor(n_estimators=1500, **best_params)\n",
"gbr.fit(X_train, y_train)\n",
"\n",
"# Score the held-out 10% split.\n",
"y_pred = gbr.predict(X_test)\n",
"mse, mae, mape, r2 = (\n",
"    metric(y_test, y_pred)\n",
"    for metric in (mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score)\n",
")\n",
"rmse = np.sqrt(mse)\n",
"\n",
"mse, rmse, mae, mape, r2\n"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "92bf1d55-a20d-4337-b8cf-9f96cb53e71e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]\tvalidation_0-rmse:2.94022\n",
"[1]\tvalidation_0-rmse:2.32601\n",
"[2]\tvalidation_0-rmse:2.15546\n",
"[3]\tvalidation_0-rmse:1.97899\n",
"[4]\tvalidation_0-rmse:1.76293\n",
"[5]\tvalidation_0-rmse:1.66392\n",
"[6]\tvalidation_0-rmse:1.56808\n",
"[7]\tvalidation_0-rmse:1.57516\n",
"[8]\tvalidation_0-rmse:1.55810\n",
"[9]\tvalidation_0-rmse:1.49683\n",
"[10]\tvalidation_0-rmse:1.42146\n",
"[11]\tvalidation_0-rmse:1.38196\n",
"[12]\tvalidation_0-rmse:1.33497\n",
"[13]\tvalidation_0-rmse:1.29656\n",
"[14]\tvalidation_0-rmse:1.31387\n",
"[15]\tvalidation_0-rmse:1.29630\n",
"[16]\tvalidation_0-rmse:1.32641\n",
"[17]\tvalidation_0-rmse:1.34701\n",
"[18]\tvalidation_0-rmse:1.34493\n",
"[19]\tvalidation_0-rmse:1.34176\n",
"[20]\tvalidation_0-rmse:1.35030\n",
"[21]\tvalidation_0-rmse:1.35103\n",
"[22]\tvalidation_0-rmse:1.36426\n",
"[23]\tvalidation_0-rmse:1.38450\n",
"[24]\tvalidation_0-rmse:1.35743\n",
"[25]\tvalidation_0-rmse:1.34604\n",
"[26]\tvalidation_0-rmse:1.36367\n",
"[27]\tvalidation_0-rmse:1.34772\n",
"[28]\tvalidation_0-rmse:1.35335\n",
"[29]\tvalidation_0-rmse:1.37848\n",
"[30]\tvalidation_0-rmse:1.37084\n",
"[31]\tvalidation_0-rmse:1.36131\n",
"[32]\tvalidation_0-rmse:1.37337\n",
"[33]\tvalidation_0-rmse:1.36706\n",
"[34]\tvalidation_0-rmse:1.35412\n",
"[35]\tvalidation_0-rmse:1.35344\n",
"[36]\tvalidation_0-rmse:1.33891\n",
"[37]\tvalidation_0-rmse:1.31269\n",
"[38]\tvalidation_0-rmse:1.31461\n",
"[39]\tvalidation_0-rmse:1.31811\n",
"[40]\tvalidation_0-rmse:1.32056\n",
"[41]\tvalidation_0-rmse:1.32437\n",
"[42]\tvalidation_0-rmse:1.28383\n",
"[43]\tvalidation_0-rmse:1.28747\n",
"[44]\tvalidation_0-rmse:1.28136\n",
"[45]\tvalidation_0-rmse:1.30384\n",
"[46]\tvalidation_0-rmse:1.31680\n",
"[47]\tvalidation_0-rmse:1.32577\n",
"[48]\tvalidation_0-rmse:1.32435\n",
"[49]\tvalidation_0-rmse:1.34947\n",
"[50]\tvalidation_0-rmse:1.33535\n",
"[51]\tvalidation_0-rmse:1.33444\n",
"[52]\tvalidation_0-rmse:1.32391\n",
"[53]\tvalidation_0-rmse:1.32339\n",
"[54]\tvalidation_0-rmse:1.31571\n",
"[55]\tvalidation_0-rmse:1.32033\n",
"[56]\tvalidation_0-rmse:1.32455\n",
"[57]\tvalidation_0-rmse:1.33108\n",
"[58]\tvalidation_0-rmse:1.33065\n",
"[59]\tvalidation_0-rmse:1.33104\n",
"[60]\tvalidation_0-rmse:1.32995\n",
"[61]\tvalidation_0-rmse:1.33504\n",
"[62]\tvalidation_0-rmse:1.33599\n",
"[63]\tvalidation_0-rmse:1.33493\n",
"[64]\tvalidation_0-rmse:1.33562\n",
"[65]\tvalidation_0-rmse:1.33627\n",
"[66]\tvalidation_0-rmse:1.33196\n",
"[67]\tvalidation_0-rmse:1.32833\n",
"[68]\tvalidation_0-rmse:1.33790\n",
"[69]\tvalidation_0-rmse:1.33838\n",
"[70]\tvalidation_0-rmse:1.34104\n",
"[71]\tvalidation_0-rmse:1.34476\n",
"[72]\tvalidation_0-rmse:1.34435\n",
"[73]\tvalidation_0-rmse:1.33861\n",
"[74]\tvalidation_0-rmse:1.33732\n",
"[75]\tvalidation_0-rmse:1.34115\n",
"[76]\tvalidation_0-rmse:1.33682\n",
"[77]\tvalidation_0-rmse:1.33896\n",
"[78]\tvalidation_0-rmse:1.34560\n",
"[79]\tvalidation_0-rmse:1.34630\n",
"[80]\tvalidation_0-rmse:1.34347\n",
"[81]\tvalidation_0-rmse:1.34940\n",
"[82]\tvalidation_0-rmse:1.34326\n",
"[83]\tvalidation_0-rmse:1.34504\n",
"[84]\tvalidation_0-rmse:1.33780\n",
"[85]\tvalidation_0-rmse:1.34329\n",
"[86]\tvalidation_0-rmse:1.34396\n",
"[87]\tvalidation_0-rmse:1.34863\n",
"[88]\tvalidation_0-rmse:1.34721\n",
"[89]\tvalidation_0-rmse:1.35453\n",
"[90]\tvalidation_0-rmse:1.35999\n",
"[91]\tvalidation_0-rmse:1.34558\n",
"[92]\tvalidation_0-rmse:1.35019\n",
"[93]\tvalidation_0-rmse:1.35153\n",
"[94]\tvalidation_0-rmse:1.34330\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"(1.6418860457242266,\n",
" 1.2813610130342763,\n",
" 0.8241156918341648,\n",
" 0.13793177033015197,\n",
" 0.8589494311459256)"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hard-coded hyperparameters; presumably pasted from a separate hyperopt run -- confirm.\n",
"# NOTE(review): if max_depth was copied from raw fmin() output it is an index into\n",
"# range(5, 50) rather than a depth (hyperopt reports hp.choice picks by index) -- confirm.\n",
"best_params = {'colsample_bytree': 0.8997619766112827, 'eta': 0.48900376453173927, 'gamma': 0.09568323449358279,\n",
"               'max_depth': 29, 'min_child_weight': 2.0607020689885673, 'subsample': 0.5621662915587151}\n",
"\n",
"# Refit on the 90% training split, with up to 1500 boosting rounds and early stopping.\n",
"X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n",
"# early_stopping_rounds is configured on the estimator: passing it to fit() is\n",
"# deprecated in XGBoost and produced a UserWarning on every run of this cell.\n",
"gbr = xgb.XGBRegressor(**best_params, n_estimators=1500, early_stopping_rounds=50)\n",
"# NOTE(review): the test split doubles as the early-stopping eval_set, so the metrics\n",
"# below are measured on data that also chose the stopping round and are optimistic.\n",
"gbr.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=True)\n",
"\n",
"# Predict on the held-out split and compute the regression metrics.\n",
"y_pred = gbr.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"mape = mean_absolute_percentage_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"mse, rmse, mae, mape, r2"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "097fba6f-a6af-4529-9d09-589ccf84b991",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]\tvalidation_0-rmse:3.03131\n",
"[1]\tvalidation_0-rmse:2.62574\n",
"[2]\tvalidation_0-rmse:2.41405\n",
"[3]\tvalidation_0-rmse:2.21943\n",
"[4]\tvalidation_0-rmse:2.06257\n",
"[5]\tvalidation_0-rmse:1.96808\n",
"[6]\tvalidation_0-rmse:1.73678\n",
"[7]\tvalidation_0-rmse:1.70331\n",
"[8]\tvalidation_0-rmse:1.64866\n",
"[9]\tvalidation_0-rmse:1.62014\n",
"[10]\tvalidation_0-rmse:1.55086\n",
"[11]\tvalidation_0-rmse:1.50959\n",
"[12]\tvalidation_0-rmse:1.47113\n",
"[13]\tvalidation_0-rmse:1.44401\n",
"[14]\tvalidation_0-rmse:1.42153\n",
"[15]\tvalidation_0-rmse:1.39003\n",
"[16]\tvalidation_0-rmse:1.35574\n",
"[17]\tvalidation_0-rmse:1.32175\n",
"[18]\tvalidation_0-rmse:1.31396\n",
"[19]\tvalidation_0-rmse:1.33739\n",
"[20]\tvalidation_0-rmse:1.33013\n",
"[21]\tvalidation_0-rmse:1.32880\n",
"[22]\tvalidation_0-rmse:1.33059\n",
"[23]\tvalidation_0-rmse:1.34923\n",
"[24]\tvalidation_0-rmse:1.31799\n",
"[25]\tvalidation_0-rmse:1.31411\n",
"[26]\tvalidation_0-rmse:1.30925\n",
"[27]\tvalidation_0-rmse:1.29766\n",
"[28]\tvalidation_0-rmse:1.30442\n",
"[29]\tvalidation_0-rmse:1.30966\n",
"[30]\tvalidation_0-rmse:1.28975\n",
"[31]\tvalidation_0-rmse:1.28810\n",
"[32]\tvalidation_0-rmse:1.28587\n",
"[33]\tvalidation_0-rmse:1.28566\n",
"[34]\tvalidation_0-rmse:1.26751\n",
"[35]\tvalidation_0-rmse:1.28569\n",
"[36]\tvalidation_0-rmse:1.27332\n",
"[37]\tvalidation_0-rmse:1.25467\n",
"[38]\tvalidation_0-rmse:1.25508\n",
"[39]\tvalidation_0-rmse:1.24264\n",
"[40]\tvalidation_0-rmse:1.24019\n",
"[41]\tvalidation_0-rmse:1.24022\n",
"[42]\tvalidation_0-rmse:1.24175\n",
"[43]\tvalidation_0-rmse:1.23122\n",
"[44]\tvalidation_0-rmse:1.22942\n",
"[45]\tvalidation_0-rmse:1.24364\n",
"[46]\tvalidation_0-rmse:1.22211\n",
"[47]\tvalidation_0-rmse:1.24425\n",
"[48]\tvalidation_0-rmse:1.23057\n",
"[49]\tvalidation_0-rmse:1.24111\n",
"[50]\tvalidation_0-rmse:1.23988\n",
"[51]\tvalidation_0-rmse:1.23965\n",
"[52]\tvalidation_0-rmse:1.23850\n",
"[53]\tvalidation_0-rmse:1.23933\n",
"[54]\tvalidation_0-rmse:1.23724\n",
"[55]\tvalidation_0-rmse:1.23362\n",
"[56]\tvalidation_0-rmse:1.23369\n",
"[57]\tvalidation_0-rmse:1.23845\n",
"[58]\tvalidation_0-rmse:1.23947\n",
"[59]\tvalidation_0-rmse:1.23491\n",
"[60]\tvalidation_0-rmse:1.23482\n",
"[61]\tvalidation_0-rmse:1.23441\n",
"[62]\tvalidation_0-rmse:1.23183\n",
"[63]\tvalidation_0-rmse:1.23244\n",
"[64]\tvalidation_0-rmse:1.23290\n",
"[65]\tvalidation_0-rmse:1.23486\n",
"[66]\tvalidation_0-rmse:1.23535\n",
"[67]\tvalidation_0-rmse:1.23568\n",
"[68]\tvalidation_0-rmse:1.23613\n",
"[69]\tvalidation_0-rmse:1.23555\n",
"[70]\tvalidation_0-rmse:1.23506\n",
"[71]\tvalidation_0-rmse:1.23512\n",
"[72]\tvalidation_0-rmse:1.23276\n",
"[73]\tvalidation_0-rmse:1.23287\n",
"[74]\tvalidation_0-rmse:1.23235\n",
"[75]\tvalidation_0-rmse:1.23922\n",
"[76]\tvalidation_0-rmse:1.23873\n",
"[77]\tvalidation_0-rmse:1.22865\n",
"[78]\tvalidation_0-rmse:1.23035\n",
"[79]\tvalidation_0-rmse:1.24439\n",
"[80]\tvalidation_0-rmse:1.23953\n",
"[81]\tvalidation_0-rmse:1.24745\n",
"[82]\tvalidation_0-rmse:1.23214\n",
"[83]\tvalidation_0-rmse:1.23968\n",
"[84]\tvalidation_0-rmse:1.24333\n",
"[85]\tvalidation_0-rmse:1.24434\n",
"[86]\tvalidation_0-rmse:1.25339\n",
"[87]\tvalidation_0-rmse:1.25396\n",
"[88]\tvalidation_0-rmse:1.25383\n",
"[89]\tvalidation_0-rmse:1.25276\n",
"[90]\tvalidation_0-rmse:1.25349\n",
"[91]\tvalidation_0-rmse:1.25226\n",
"[92]\tvalidation_0-rmse:1.25328\n",
"[93]\tvalidation_0-rmse:1.25284\n",
"[94]\tvalidation_0-rmse:1.25222\n",
"[95]\tvalidation_0-rmse:1.25212\n",
"[96]\tvalidation_0-rmse:1.25266\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"(1.4935549430378676,\n",
" 1.2221108554619207,\n",
" 0.8199210432780543,\n",
" 0.1256990216917289,\n",
" 0.8716922073374574)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hard-coded hyperparameters; presumably pasted from a separate hyperopt run -- confirm.\n",
"# NOTE(review): if max_depth was copied from raw fmin() output it is an index into\n",
"# range(5, 50) rather than a depth (hyperopt reports hp.choice picks by index) -- confirm.\n",
"best_params = {'colsample_bytree': 0.7756001484050402, 'eta': 0.31927224345318256, 'gamma': 0.5049174573053737,\n",
"               'max_depth': 42, 'min_child_weight': 6.449650970113468, 'subsample': 0.7873416063207794}\n",
"\n",
"# Refit on the 90% training split, with up to 1500 boosting rounds and early stopping.\n",
"X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n",
"# early_stopping_rounds is configured on the estimator: passing it to fit() is\n",
"# deprecated in XGBoost and produced a UserWarning on every run of this cell.\n",
"gbr = xgb.XGBRegressor(**best_params, n_estimators=1500, early_stopping_rounds=50)\n",
"# NOTE(review): the test split doubles as the early-stopping eval_set, so the metrics\n",
"# below are measured on data that also chose the stopping round and are optimistic.\n",
"gbr.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=True)\n",
"\n",
"# Predict on the held-out split and compute the regression metrics.\n",
"y_pred = gbr.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"mape = mean_absolute_percentage_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"mse, rmse, mae, mape, r2"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "9479b1af-a005-4abb-9664-9d3d4837133f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]\tvalidation_0-rmse:2.94796\n",
"[1]\tvalidation_0-rmse:2.47933\n",
"[2]\tvalidation_0-rmse:2.32980\n",
"[3]\tvalidation_0-rmse:2.10618\n",
"[4]\tvalidation_0-rmse:1.97724\n",
"[5]\tvalidation_0-rmse:1.95411\n",
"[6]\tvalidation_0-rmse:1.84506\n",
"[7]\tvalidation_0-rmse:1.84115\n",
"[8]\tvalidation_0-rmse:1.82857\n",
"[9]\tvalidation_0-rmse:1.76881\n",
"[10]\tvalidation_0-rmse:1.66718\n",
"[11]\tvalidation_0-rmse:1.65162\n",
"[12]\tvalidation_0-rmse:1.61560\n",
"[13]\tvalidation_0-rmse:1.57779\n",
"[14]\tvalidation_0-rmse:1.54531\n",
"[15]\tvalidation_0-rmse:1.51655\n",
"[16]\tvalidation_0-rmse:1.49713\n",
"[17]\tvalidation_0-rmse:1.46052\n",
"[18]\tvalidation_0-rmse:1.42253\n",
"[19]\tvalidation_0-rmse:1.40742\n",
"[20]\tvalidation_0-rmse:1.36721\n",
"[21]\tvalidation_0-rmse:1.35093\n",
"[22]\tvalidation_0-rmse:1.35702\n",
"[23]\tvalidation_0-rmse:1.33818\n",
"[24]\tvalidation_0-rmse:1.32813\n",
"[25]\tvalidation_0-rmse:1.35230\n",
"[26]\tvalidation_0-rmse:1.34903\n",
"[27]\tvalidation_0-rmse:1.34344\n",
"[28]\tvalidation_0-rmse:1.34388\n",
"[29]\tvalidation_0-rmse:1.35948\n",
"[30]\tvalidation_0-rmse:1.33423\n",
"[31]\tvalidation_0-rmse:1.35954\n",
"[32]\tvalidation_0-rmse:1.35215\n",
"[33]\tvalidation_0-rmse:1.33783\n",
"[34]\tvalidation_0-rmse:1.31944\n",
"[35]\tvalidation_0-rmse:1.32828\n",
"[36]\tvalidation_0-rmse:1.29862\n",
"[37]\tvalidation_0-rmse:1.27408\n",
"[38]\tvalidation_0-rmse:1.28541\n",
"[39]\tvalidation_0-rmse:1.27269\n",
"[40]\tvalidation_0-rmse:1.26776\n",
"[41]\tvalidation_0-rmse:1.26017\n",
"[42]\tvalidation_0-rmse:1.29093\n",
"[43]\tvalidation_0-rmse:1.28878\n",
"[44]\tvalidation_0-rmse:1.29039\n",
"[45]\tvalidation_0-rmse:1.28695\n",
"[46]\tvalidation_0-rmse:1.28674\n",
"[47]\tvalidation_0-rmse:1.29803\n",
"[48]\tvalidation_0-rmse:1.28585\n",
"[49]\tvalidation_0-rmse:1.29421\n",
"[50]\tvalidation_0-rmse:1.29601\n",
"[51]\tvalidation_0-rmse:1.31381\n",
"[52]\tvalidation_0-rmse:1.32534\n",
"[53]\tvalidation_0-rmse:1.32683\n",
"[54]\tvalidation_0-rmse:1.31673\n",
"[55]\tvalidation_0-rmse:1.32776\n",
"[56]\tvalidation_0-rmse:1.32463\n",
"[57]\tvalidation_0-rmse:1.32801\n",
"[58]\tvalidation_0-rmse:1.32713\n",
"[59]\tvalidation_0-rmse:1.33055\n",
"[60]\tvalidation_0-rmse:1.33246\n",
"[61]\tvalidation_0-rmse:1.32767\n",
"[62]\tvalidation_0-rmse:1.33396\n",
"[63]\tvalidation_0-rmse:1.33038\n",
"[64]\tvalidation_0-rmse:1.32723\n",
"[65]\tvalidation_0-rmse:1.32445\n",
"[66]\tvalidation_0-rmse:1.32666\n",
"[67]\tvalidation_0-rmse:1.32470\n",
"[68]\tvalidation_0-rmse:1.32556\n",
"[69]\tvalidation_0-rmse:1.32294\n",
"[70]\tvalidation_0-rmse:1.32810\n",
"[71]\tvalidation_0-rmse:1.32573\n",
"[72]\tvalidation_0-rmse:1.33002\n",
"[73]\tvalidation_0-rmse:1.33069\n",
"[74]\tvalidation_0-rmse:1.33042\n",
"[75]\tvalidation_0-rmse:1.33060\n",
"[76]\tvalidation_0-rmse:1.33120\n",
"[77]\tvalidation_0-rmse:1.33269\n",
"[78]\tvalidation_0-rmse:1.32919\n",
"[79]\tvalidation_0-rmse:1.32429\n",
"[80]\tvalidation_0-rmse:1.32078\n",
"[81]\tvalidation_0-rmse:1.31978\n",
"[82]\tvalidation_0-rmse:1.30738\n",
"[83]\tvalidation_0-rmse:1.30778\n",
"[84]\tvalidation_0-rmse:1.31087\n",
"[85]\tvalidation_0-rmse:1.30729\n",
"[86]\tvalidation_0-rmse:1.30339\n",
"[87]\tvalidation_0-rmse:1.30681\n",
"[88]\tvalidation_0-rmse:1.30676\n",
"[89]\tvalidation_0-rmse:1.30625\n",
"[90]\tvalidation_0-rmse:1.30712\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"(1.5880184359288727,\n",
" 1.2601660350639803,\n",
" 0.9411508613501586,\n",
" 0.1599000195341022,\n",
" 0.8635770708193553)"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hard-coded hyperparameters; presumably pasted from a separate hyperopt run -- confirm.\n",
"# NOTE(review): if max_depth was copied from raw fmin() output it is an index into\n",
"# range(5, 50) rather than a depth (hyperopt reports hp.choice picks by index) -- confirm.\n",
"best_params = {'colsample_bytree': 0.8182688124328266, 'eta': 0.39669872117044186, 'gamma': 0.67893237292294242,\n",
"               'max_depth': 23, 'min_child_weight': 7.274037788798998, 'subsample': 0.6957233806783182}\n",
"\n",
"# Refit on the 90% training split, with up to 1500 boosting rounds and early stopping.\n",
"X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n",
"# early_stopping_rounds is configured on the estimator: passing it to fit() is\n",
"# deprecated in XGBoost and produced a UserWarning on every run of this cell.\n",
"gbr = xgb.XGBRegressor(**best_params, n_estimators=1500, early_stopping_rounds=50)\n",
"# NOTE(review): the test split doubles as the early-stopping eval_set, so the metrics\n",
"# below are measured on data that also chose the stopping round and are optimistic.\n",
"gbr.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=True)\n",
"\n",
"# Predict on the held-out split and compute the regression metrics.\n",
"y_pred = gbr.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"mape = mean_absolute_percentage_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"mse, rmse, mae, mape, r2"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "ae2c8a8a-1333-487e-9c2a-fee46bb9c934",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]\tvalidation_0-rmse:2.99087\n",
"[1]\tvalidation_0-rmse:2.66375\n",
"[2]\tvalidation_0-rmse:2.40788\n",
"[3]\tvalidation_0-rmse:2.21494\n",
"[4]\tvalidation_0-rmse:2.08983\n",
"[5]\tvalidation_0-rmse:2.00727\n",
"[6]\tvalidation_0-rmse:1.85305\n",
"[7]\tvalidation_0-rmse:1.78475\n",
"[8]\tvalidation_0-rmse:1.70319\n",
"[9]\tvalidation_0-rmse:1.65511\n",
"[10]\tvalidation_0-rmse:1.58430\n",
"[11]\tvalidation_0-rmse:1.55495\n",
"[12]\tvalidation_0-rmse:1.53684\n",
"[13]\tvalidation_0-rmse:1.48998\n",
"[14]\tvalidation_0-rmse:1.48409\n",
"[15]\tvalidation_0-rmse:1.42885\n",
"[16]\tvalidation_0-rmse:1.39821\n",
"[17]\tvalidation_0-rmse:1.36639\n",
"[18]\tvalidation_0-rmse:1.36579\n",
"[19]\tvalidation_0-rmse:1.36524\n",
"[20]\tvalidation_0-rmse:1.37165\n",
"[21]\tvalidation_0-rmse:1.37061\n",
"[22]\tvalidation_0-rmse:1.37467\n",
"[23]\tvalidation_0-rmse:1.36140\n",
"[24]\tvalidation_0-rmse:1.34469\n",
"[25]\tvalidation_0-rmse:1.34300\n",
"[26]\tvalidation_0-rmse:1.34283\n",
"[27]\tvalidation_0-rmse:1.33821\n",
"[28]\tvalidation_0-rmse:1.33426\n",
"[29]\tvalidation_0-rmse:1.35185\n",
"[30]\tvalidation_0-rmse:1.34884\n",
"[31]\tvalidation_0-rmse:1.35015\n",
"[32]\tvalidation_0-rmse:1.36925\n",
"[33]\tvalidation_0-rmse:1.36742\n",
"[34]\tvalidation_0-rmse:1.36185\n",
"[35]\tvalidation_0-rmse:1.38217\n",
"[36]\tvalidation_0-rmse:1.37213\n",
"[37]\tvalidation_0-rmse:1.35339\n",
"[38]\tvalidation_0-rmse:1.35314\n",
"[39]\tvalidation_0-rmse:1.35475\n",
"[40]\tvalidation_0-rmse:1.34997\n",
"[41]\tvalidation_0-rmse:1.33195\n",
"[42]\tvalidation_0-rmse:1.33518\n",
"[43]\tvalidation_0-rmse:1.33585\n",
"[44]\tvalidation_0-rmse:1.33598\n",
"[45]\tvalidation_0-rmse:1.34456\n",
"[46]\tvalidation_0-rmse:1.33476\n",
"[47]\tvalidation_0-rmse:1.35722\n",
"[48]\tvalidation_0-rmse:1.36327\n",
"[49]\tvalidation_0-rmse:1.36469\n",
"[50]\tvalidation_0-rmse:1.35926\n",
"[51]\tvalidation_0-rmse:1.36052\n",
"[52]\tvalidation_0-rmse:1.36092\n",
"[53]\tvalidation_0-rmse:1.36165\n",
"[54]\tvalidation_0-rmse:1.34672\n",
"[55]\tvalidation_0-rmse:1.35078\n",
"[56]\tvalidation_0-rmse:1.35164\n",
"[57]\tvalidation_0-rmse:1.35124\n",
"[58]\tvalidation_0-rmse:1.34962\n",
"[59]\tvalidation_0-rmse:1.35293\n",
"[60]\tvalidation_0-rmse:1.34839\n",
"[61]\tvalidation_0-rmse:1.33587\n",
"[62]\tvalidation_0-rmse:1.32711\n",
"[63]\tvalidation_0-rmse:1.32892\n",
"[64]\tvalidation_0-rmse:1.33008\n",
"[65]\tvalidation_0-rmse:1.33023\n",
"[66]\tvalidation_0-rmse:1.33082\n",
"[67]\tvalidation_0-rmse:1.33207\n",
"[68]\tvalidation_0-rmse:1.33262\n",
"[69]\tvalidation_0-rmse:1.33185\n",
"[70]\tvalidation_0-rmse:1.33114\n",
"[71]\tvalidation_0-rmse:1.33287\n",
"[72]\tvalidation_0-rmse:1.34803\n",
"[73]\tvalidation_0-rmse:1.35223\n",
"[74]\tvalidation_0-rmse:1.34266\n",
"[75]\tvalidation_0-rmse:1.34423\n",
"[76]\tvalidation_0-rmse:1.34351\n",
"[77]\tvalidation_0-rmse:1.33684\n",
"[78]\tvalidation_0-rmse:1.33450\n",
"[79]\tvalidation_0-rmse:1.35080\n",
"[80]\tvalidation_0-rmse:1.34307\n",
"[81]\tvalidation_0-rmse:1.33828\n",
"[82]\tvalidation_0-rmse:1.33786\n",
"[83]\tvalidation_0-rmse:1.33990\n",
"[84]\tvalidation_0-rmse:1.34383\n",
"[85]\tvalidation_0-rmse:1.34400\n",
"[86]\tvalidation_0-rmse:1.34246\n",
"[87]\tvalidation_0-rmse:1.34125\n",
"[88]\tvalidation_0-rmse:1.34723\n",
"[89]\tvalidation_0-rmse:1.34835\n",
"[90]\tvalidation_0-rmse:1.34827\n",
"[91]\tvalidation_0-rmse:1.34661\n",
"[92]\tvalidation_0-rmse:1.34880\n",
"[93]\tvalidation_0-rmse:1.34823\n",
"[94]\tvalidation_0-rmse:1.36337\n",
"[95]\tvalidation_0-rmse:1.36356\n",
"[96]\tvalidation_0-rmse:1.36404\n",
"[97]\tvalidation_0-rmse:1.36244\n",
"[98]\tvalidation_0-rmse:1.36225\n",
"[99]\tvalidation_0-rmse:1.36267\n",
"[100]\tvalidation_0-rmse:1.36269\n",
"[101]\tvalidation_0-rmse:1.36312\n",
"[102]\tvalidation_0-rmse:1.37098\n",
"[103]\tvalidation_0-rmse:1.37105\n",
"[104]\tvalidation_0-rmse:1.37414\n",
"[105]\tvalidation_0-rmse:1.37391\n",
"[106]\tvalidation_0-rmse:1.37318\n",
"[107]\tvalidation_0-rmse:1.36753\n",
"[108]\tvalidation_0-rmse:1.36538\n",
"[109]\tvalidation_0-rmse:1.36617\n",
"[110]\tvalidation_0-rmse:1.36542\n",
"[111]\tvalidation_0-rmse:1.36646\n",
"[112]\tvalidation_0-rmse:1.38288\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"(1.7612323261291336,\n",
" 1.3271142852554687,\n",
" 0.9281123720415317,\n",
" 0.14669084489239353,\n",
" 0.8486966728710329)"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Extract best parameters\n",
"best_params = {'colsample_bytree': 0.7591178772740766, 'eta': 0.29006097172943296, 'gamma': 0.32020608660889016, \n",
" 'max_depth': 21, 'min_child_weight': 5.912424330716954, 'subsample': 0.8011115810485918}\n",
"\n",
"# Re-train the model with the best parameters and evaluate\n",
"X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n",
"gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n",
"gbr.fit(X_train, y_train,early_stopping_rounds=50, eval_set=[(X_test, y_test)], verbose=True)\n",
"\n",
"# Predict and evaluate\n",
"y_pred = gbr.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"mape = mean_absolute_percentage_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"mse, rmse, mae, mape, r2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27880e3b-0149-4611-abda-9e06442704e4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}