commit dd96a698af11a45cd75c671ee8cd83b503420ce1 Author: tanzekun Date: Tue Dec 3 10:15:18 2024 +0800 上传文件至 / diff --git a/参数调优rejie2.ipynb b/参数调优rejie2.ipynb new file mode 100644 index 0000000..d83d508 --- /dev/null +++ b/参数调优rejie2.ipynb @@ -0,0 +1,1261 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 29, + "id": "35268e29-6824-4e2d-983b-7e8259bd343f", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from hyperopt import hp, fmin, tpe, STATUS_OK, Trials\n", + "from sklearn.model_selection import train_test_split\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "30a72889-e9da-45d3-bb03-cbe1b908a972", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
AVFCCHNSOH/CO/CN/CRtHrdpTTar
025.5645.3254.6869.134.911.130.8124.020.8523070.2605960.01401130.010.00.26003.974958
115.2645.1154.8963.805.111.560.8828.650.9611290.3367950.02095830.010.00.26004.629865
29.9235.1664.8480.254.660.910.5613.620.6968220.1272900.00972020.020.06.04006.452928
39.9235.1664.8480.254.660.910.5613.620.6968220.1272900.00972020.020.06.04508.724672
49.9235.1664.8480.254.660.910.5613.620.6968220.1272900.00972020.020.06.050010.075968
\n", + "
" + ], + "text/plain": [ + " A V FC C H N S O H/C O/C \\\n", + "0 25.56 45.32 54.68 69.13 4.91 1.13 0.81 24.02 0.852307 0.260596 \n", + "1 15.26 45.11 54.89 63.80 5.11 1.56 0.88 28.65 0.961129 0.336795 \n", + "2 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n", + "3 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n", + "4 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n", + "\n", + " N/C Rt Hr dp T Tar \n", + "0 0.014011 30.0 10.0 0.2 600 3.974958 \n", + "1 0.020958 30.0 10.0 0.2 600 4.629865 \n", + "2 0.009720 20.0 20.0 6.0 400 6.452928 \n", + "3 0.009720 20.0 20.0 6.0 450 8.724672 \n", + "4 0.009720 20.0 20.0 6.0 500 10.075968 " + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = pd.read_excel('/mnt/tanzk/mxx/rejie2.xlsx')\n", + "data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "652c035d-937c-4b6b-bbe3-b5ed9ea2ca92", + "metadata": {}, + "outputs": [], + "source": [ + "out_cols = ['Tar']" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "9f00267b-0cf5-4cfc-97ef-8c15cd8d262d", + "metadata": {}, + "outputs": [], + "source": [ + "feature_cols = [x for x in data.columns if x not in out_cols]" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "676d77b7-605b-4a2a-acc0-8d625324e1c2", + "metadata": {}, + "outputs": [], + "source": [ + "train_data = data.reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "7bc4dd3d-3285-4bea-bdd2-c55816674722", + "metadata": {}, + "outputs": [], + "source": [ + "import xgboost as xgb\n", + "\n", + "from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error, r2_score" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "52f19cb8-424e-47eb-a68d-f107a5d194bd", + "metadata": {}, + "outputs": [], + "source": [ + "# 定义超参数的搜索空间\n", + "# space = {\n", + "# 'eta': hp.loguniform('eta', -5, 0), # 学习率,搜索范围是 [1e-5, 1]\n", + "# 'max_depth': hp.choice('max_depth', range(5, 30)), # 树的最大深度,搜索范围是 [1, 10]\n", + "# 'min_child_weight': hp.uniform('min_child_weight', 0, 10), # 子节点最小的权重和\n", + "# 'gamma': hp.loguniform('gamma', -5, 0), # 叶子节点分裂所需的最小损失减少\n", + "# 'subsample': hp.uniform('subsample', 0.5, 1), # 训练集的采样率\n", + "# 'colsample_bytree': hp.uniform('colsample_bytree', 0.5, 1), # 特征的采样率\n", + "# }\n", + "space = {\n", + " 'eta': hp.loguniform('eta', -4, 0.5), # 学习率,搜索范围是 [1e-5, 1]\n", + " 'max_depth': hp.choice('max_depth', range(5, 50)), # 树的最大深度,搜索范围是 [1, 10]\n", + " 'min_child_weight': hp.uniform('min_child_weight', 0, 20), # 子节点最小的权重和\n", + " 'gamma': hp.loguniform('gamma', -3, 0.7), # 叶子节点分裂所需的最小损失减少\n", + " 'subsample': hp.uniform('subsample', 0.5, 0.9), # 训练集的采样率\n", + " 'colsample_bytree': hp.uniform('colsample_bytree', 0.5, 0.9), # 特征的采样率\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "b4209d03-b608-4520-9e40-79e93b358c05", + "metadata": {}, + "outputs": [], + "source": [ + "# 划分训练集和测试集\n", + "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], \n", + " train_data[out_cols], \n", + " test_size=0.1, random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "e0e3e9fd-f85d-4370-8b7d-129556624721", + "metadata": {}, + "outputs": [], + "source": [ + "# 定义目标函数,用于评估模型的性能\n", + "def objective(params):\n", + " # 创建决策树分类器实例\n", + " gbr = xgb.XGBRegressor(**params)\n", + " # 训练模型\n", + " gbr.fit(X_train, y_train)\n", + " # 使用模型进行预测\n", + " y_pred = gbr.predict(X_test)\n", + " mae = mean_absolute_error(y_test, y_pred)\n", + " return {'loss': mae, 'status': STATUS_OK}" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "cc20c5e9-072f-4f6c-b74e-6042c54408ce", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:06<00:00, 15.37trial/s, best loss: 0.927005504531766]\n" + ] + } + ], + "source": [ + "# 创建 Trials 对象来记录搜索历史\n", + "trials = Trials()\n", + "\n", + "# 使用 fmin 函数进行超参数优化\n", + "best_params = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=100, trials=trials)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "498cff26-e385-48ac-82ed-5057ccbc370c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'colsample_bytree': 0.7591178772740766, 'eta': 0.29006097172943296, 'gamma': 0.32020608660889016, 'max_depth': 21, 'min_child_weight': 5.912424330716954, 'subsample': 0.8011115810485918}\n" + ] + } + ], + "source": [ + "print(best_params)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "5eeb907e-cdc1-434a-82d7-3b967287bbb3", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import KFold, train_test_split\n", + "kf = KFold(n_splits=10, shuffle=True, random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "e629cc0e-a9f8-44d7-9958-092a34a5e579", + "metadata": {}, + "outputs": [], + "source": [ + "num_boost_round = 1000" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "8ddd834e-22b8-42dc-8387-b9983b52a71b", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "6e1074cd-88d1-41aa-8d6a-b11b3ae705e0", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams[\"font.sans-serif\"] = [\"WenQuanYi Micro Hei\"] # 设置字体\n", + "plt.rcParams[\"axes.unicode_minus\"] = False # 正常显示负号" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "70490cbb-f9cf-4b3c-b476-37f3721d50e8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSE: 2.336, RMSE: 1.5284, MAE: 1.0563, MAPE: 17.41 %, R_2: 0.7993\n", + "MSE: 2.1045, RMSE: 1.4507, MAE: 0.9808, MAPE: 14.64 %, R_2: 0.7211\n", + "MSE: 1.7837, RMSE: 1.3356, MAE: 0.9841, MAPE: 12.31 %, R_2: 0.7866\n", + "MSE: 3.7098, RMSE: 1.9261, MAE: 1.0768, MAPE: 21.11 %, R_2: 0.7009\n", + "MSE: 1.5407, RMSE: 1.2412, MAE: 0.7912, MAPE: 13.12 %, R_2: 0.8225\n", + "MSE: 2.5428, RMSE: 1.5946, MAE: 1.0797, MAPE: 15.42 %, R_2: 0.5668\n", + "MSE: 1.6981, RMSE: 1.3031, MAE: 0.9132, MAPE: 13.13 %, R_2: 0.7728\n", + "MSE: 1.6532, RMSE: 1.2858, MAE: 0.8018, MAPE: 9.68 %, R_2: 0.7663\n", + "MSE: 2.2342, RMSE: 1.4947, MAE: 1.1072, MAPE: 16.14 %, R_2: 0.7479\n", + "MSE: 1.6337, RMSE: 1.2782, MAE: 0.9285, MAPE: 12.27 %, R_2: 0.7293\n" + ] + } + ], + "source": [ + "eva_list = list()\n", + "eva_cols = ['MSE', 'RMSE', 'MAE', 'MAPE', 'R2']\n", + "for (train_index, test_index) in kf.split(train_data):\n", + " train = train_data.loc[train_index]\n", + " valid = train_data.loc[test_index]\n", + " X_train, Y_train = train[feature_cols], train[out_cols]\n", + " X_valid, Y_valid = valid[feature_cols], valid[out_cols]\n", + " dtrain = xgb.DMatrix(X_train, Y_train)\n", + " dvalid = xgb.DMatrix(X_valid, Y_valid)\n", + " watchlist = [(dvalid, 'eval')]\n", + " gb_model = xgb.train(best_params, dtrain, num_boost_round, evals=watchlist,\n", + " early_stopping_rounds=100, verbose_eval=False)\n", + " y_pred = gb_model.predict(xgb.DMatrix(X_valid))\n", + " y_true = Y_valid.values\n", + " MSE = mean_squared_error(y_true, y_pred)\n", + " RMSE = np.sqrt(mean_squared_error(y_true, y_pred))\n", + " MAE = mean_absolute_error(y_true, y_pred)\n", + " MAPE = mean_absolute_percentage_error(y_true, y_pred)\n", + " R_2 = r2_score(y_true, y_pred)\n", + " print('MSE:', round(MSE, 4), end=', ')\n", + " print('RMSE:', round(RMSE, 4), end=', ')\n", + " print('MAE:', round(MAE, 4), end=', ')\n", + " print('MAPE:', round(MAPE*100, 2), '%', end=', ')\n", + " print('R_2:', round(R_2, 4)) #R方为负就说明拟合效果比平均值差\n", + " eva_list.append([MSE, RMSE, MAE, MAPE, R_2])" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "10886d74-7543-49a6-b4f1-a31a3110fb99", + "metadata": {}, + "outputs": [], + "source": [ + "eva_df = pd.DataFrame.from_records(eva_list, columns=eva_cols)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "87718c37-27a4-44b4-b5ec-fbc6fdbc14fe", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MSE 2.123660\n", + "RMSE 1.443832\n", + "MAE 0.971950\n", + "MAPE 0.145207\n", + "R2 0.741365\n", + "dtype: float64" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eva_df.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "ef6508a2-ca8a-4fd1-8a8b-0e8a91d01d87", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1.861425600429374,\n", + " 1.3643407200656932,\n", + " 0.9369589057201204,\n", + " 0.14918013653153422,\n", + " 0.8400893043071764)" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xgboost as xgb\n", + "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Extract best parameters\n", + "best_params = {'colsample_bytree': 0.7591178772740766, 'eta': 0.29006097172943296, 'gamma': 0.32020608660889016, \n", + " 'max_depth': 21, 'min_child_weight': 5.912424330716954, 'subsample': 0.8011115810485918}\n", + "\n", + "# Re-train the model with the best parameters and evaluate\n", + "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", + "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", + "gbr.fit(X_train, y_train)\n", + "\n", + "# Predict and evaluate\n", + "y_pred = gbr.predict(X_test)\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "rmse = np.sqrt(mse)\n", + "mae = mean_absolute_error(y_test, y_pred)\n", + "mape = mean_absolute_percentage_error(y_test, y_pred)\n", + "r2 = r2_score(y_test, y_pred)\n", + "\n", + "mse, rmse, mae, mape, r2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "f74ccd89-e1b0-44e8-913f-6b33e8115250", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1.78639546955746,\n", + " 1.3365610609162082,\n", + " 0.9437850044721565,\n", + " 0.1579769923858331,\n", + " 0.846534966396966)" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xgboost as xgb\n", + "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Extract best parameters\n", + "best_params = {'colsample_bytree': 0.8182688124328266, 'eta': 0.39669872117044186, 'gamma': 0.67893237292294242, \n", + " 'max_depth': 23, 'min_child_weight': 7.274037788798998, 'subsample': 0.6957233806783182}\n", + "\n", + "# Re-train the model with the best parameters and evaluate\n", + "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", + "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", + "gbr.fit(X_train, y_train)\n", + "\n", + "# Predict and evaluate\n", + "y_pred = gbr.predict(X_test)\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "rmse = np.sqrt(mse)\n", + "mae = mean_absolute_error(y_test, y_pred)\n", + "mape = mean_absolute_percentage_error(y_test, y_pred)\n", + "r2 = r2_score(y_test, y_pred)\n", + "\n", + "mse, rmse, mae, mape, r2" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "8e223a2f-676b-4255-bd7e-b6ea8857a7e2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1.5784385533129173,\n", + " 1.256359245324727,\n", + " 0.8012875783803378,\n", + " 0.12426418160928662,\n", + " 0.8644000560052363)" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xgboost as xgb\n", + "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Extract best parameters\n", + "best_params = {'colsample_bytree': 0.7756001484050402, 'eta': 0.31927224345318256, 'gamma': 0.5049174573053737, \n", + " 'max_depth': 42, 'min_child_weight': 6.449650970113468, 'subsample': 0.7873416063207794}\n", + "\n", + "# Re-train the model with the best parameters and evaluate\n", + "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", + "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", + "gbr.fit(X_train, y_train)\n", + "\n", + "# Predict and evaluate\n", + "y_pred = gbr.predict(X_test)\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "rmse = np.sqrt(mse)\n", + "mae = mean_absolute_error(y_test, y_pred)\n", + "mape = mean_absolute_percentage_error(y_test, y_pred)\n", + "r2 = r2_score(y_test, y_pred)\n", + "\n", + "mse, rmse, mae, mape, r2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "92bf1d55-a20d-4337-b8cf-9f96cb53e71e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0]\tvalidation_0-rmse:2.94022\n", + "[1]\tvalidation_0-rmse:2.32601\n", + "[2]\tvalidation_0-rmse:2.15546\n", + "[3]\tvalidation_0-rmse:1.97899\n", + "[4]\tvalidation_0-rmse:1.76293\n", + "[5]\tvalidation_0-rmse:1.66392\n", + "[6]\tvalidation_0-rmse:1.56808\n", + "[7]\tvalidation_0-rmse:1.57516\n", + "[8]\tvalidation_0-rmse:1.55810\n", + "[9]\tvalidation_0-rmse:1.49683\n", + "[10]\tvalidation_0-rmse:1.42146\n", + "[11]\tvalidation_0-rmse:1.38196\n", + "[12]\tvalidation_0-rmse:1.33497\n", + "[13]\tvalidation_0-rmse:1.29656\n", + "[14]\tvalidation_0-rmse:1.31387\n", + "[15]\tvalidation_0-rmse:1.29630\n", + "[16]\tvalidation_0-rmse:1.32641\n", + "[17]\tvalidation_0-rmse:1.34701\n", + "[18]\tvalidation_0-rmse:1.34493\n", + "[19]\tvalidation_0-rmse:1.34176\n", + "[20]\tvalidation_0-rmse:1.35030\n", + "[21]\tvalidation_0-rmse:1.35103\n", + "[22]\tvalidation_0-rmse:1.36426\n", + "[23]\tvalidation_0-rmse:1.38450\n", + "[24]\tvalidation_0-rmse:1.35743\n", + "[25]\tvalidation_0-rmse:1.34604\n", + "[26]\tvalidation_0-rmse:1.36367\n", + "[27]\tvalidation_0-rmse:1.34772\n", + "[28]\tvalidation_0-rmse:1.35335\n", + "[29]\tvalidation_0-rmse:1.37848\n", + "[30]\tvalidation_0-rmse:1.37084\n", + "[31]\tvalidation_0-rmse:1.36131\n", + "[32]\tvalidation_0-rmse:1.37337\n", + "[33]\tvalidation_0-rmse:1.36706\n", + "[34]\tvalidation_0-rmse:1.35412\n", + "[35]\tvalidation_0-rmse:1.35344\n", + "[36]\tvalidation_0-rmse:1.33891\n", + "[37]\tvalidation_0-rmse:1.31269\n", + "[38]\tvalidation_0-rmse:1.31461\n", + "[39]\tvalidation_0-rmse:1.31811\n", + "[40]\tvalidation_0-rmse:1.32056\n", + "[41]\tvalidation_0-rmse:1.32437\n", + "[42]\tvalidation_0-rmse:1.28383\n", + "[43]\tvalidation_0-rmse:1.28747\n", + "[44]\tvalidation_0-rmse:1.28136\n", + "[45]\tvalidation_0-rmse:1.30384\n", + "[46]\tvalidation_0-rmse:1.31680\n", + "[47]\tvalidation_0-rmse:1.32577\n", + "[48]\tvalidation_0-rmse:1.32435\n", + "[49]\tvalidation_0-rmse:1.34947\n", + "[50]\tvalidation_0-rmse:1.33535\n", + "[51]\tvalidation_0-rmse:1.33444\n", + "[52]\tvalidation_0-rmse:1.32391\n", + "[53]\tvalidation_0-rmse:1.32339\n", + "[54]\tvalidation_0-rmse:1.31571\n", + "[55]\tvalidation_0-rmse:1.32033\n", + "[56]\tvalidation_0-rmse:1.32455\n", + "[57]\tvalidation_0-rmse:1.33108\n", + "[58]\tvalidation_0-rmse:1.33065\n", + "[59]\tvalidation_0-rmse:1.33104\n", + "[60]\tvalidation_0-rmse:1.32995\n", + "[61]\tvalidation_0-rmse:1.33504\n", + "[62]\tvalidation_0-rmse:1.33599\n", + "[63]\tvalidation_0-rmse:1.33493\n", + "[64]\tvalidation_0-rmse:1.33562\n", + "[65]\tvalidation_0-rmse:1.33627\n", + "[66]\tvalidation_0-rmse:1.33196\n", + "[67]\tvalidation_0-rmse:1.32833\n", + "[68]\tvalidation_0-rmse:1.33790\n", + "[69]\tvalidation_0-rmse:1.33838\n", + "[70]\tvalidation_0-rmse:1.34104\n", + "[71]\tvalidation_0-rmse:1.34476\n", + "[72]\tvalidation_0-rmse:1.34435\n", + "[73]\tvalidation_0-rmse:1.33861\n", + "[74]\tvalidation_0-rmse:1.33732\n", + "[75]\tvalidation_0-rmse:1.34115\n", + "[76]\tvalidation_0-rmse:1.33682\n", + "[77]\tvalidation_0-rmse:1.33896\n", + "[78]\tvalidation_0-rmse:1.34560\n", + "[79]\tvalidation_0-rmse:1.34630\n", + "[80]\tvalidation_0-rmse:1.34347\n", + "[81]\tvalidation_0-rmse:1.34940\n", + "[82]\tvalidation_0-rmse:1.34326\n", + "[83]\tvalidation_0-rmse:1.34504\n", + "[84]\tvalidation_0-rmse:1.33780\n", + "[85]\tvalidation_0-rmse:1.34329\n", + "[86]\tvalidation_0-rmse:1.34396\n", + "[87]\tvalidation_0-rmse:1.34863\n", + "[88]\tvalidation_0-rmse:1.34721\n", + "[89]\tvalidation_0-rmse:1.35453\n", + "[90]\tvalidation_0-rmse:1.35999\n", + "[91]\tvalidation_0-rmse:1.34558\n", + "[92]\tvalidation_0-rmse:1.35019\n", + "[93]\tvalidation_0-rmse:1.35153\n", + "[94]\tvalidation_0-rmse:1.34330\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "(1.6418860457242266,\n", + " 1.2813610130342763,\n", + " 0.8241156918341648,\n", + " 0.13793177033015197,\n", + " 0.8589494311459256)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xgboost as xgb\n", + "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Extract best parameters\n", + "best_params = {'colsample_bytree': 0.8997619766112827, 'eta': 0.48900376453173927, 'gamma': 0.09568323449358279, \n", + " 'max_depth': 29, 'min_child_weight': 2.0607020689885673, 'subsample': 0.5621662915587151}\n", + "\n", + "# Re-train the model with the best parameters and evaluate\n", + "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", + "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", + "# gbr.fit(X_train, y_train)\n", + "gbr.fit(X_train, y_train, early_stopping_rounds=50, eval_set=[(X_test, y_test)], verbose=True)\n", + "\n", + "# Predict and evaluate\n", + "y_pred = gbr.predict(X_test)\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "rmse = np.sqrt(mse)\n", + "mae = mean_absolute_error(y_test, y_pred)\n", + "mape = mean_absolute_percentage_error(y_test, y_pred)\n", + "r2 = r2_score(y_test, y_pred)\n", + "\n", + "mse, rmse, mae, mape, r2" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "097fba6f-a6af-4529-9d09-589ccf84b991", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0]\tvalidation_0-rmse:3.03131\n", + "[1]\tvalidation_0-rmse:2.62574\n", + "[2]\tvalidation_0-rmse:2.41405\n", + "[3]\tvalidation_0-rmse:2.21943\n", + "[4]\tvalidation_0-rmse:2.06257\n", + "[5]\tvalidation_0-rmse:1.96808\n", + "[6]\tvalidation_0-rmse:1.73678\n", + "[7]\tvalidation_0-rmse:1.70331\n", + "[8]\tvalidation_0-rmse:1.64866\n", + "[9]\tvalidation_0-rmse:1.62014\n", + "[10]\tvalidation_0-rmse:1.55086\n", + "[11]\tvalidation_0-rmse:1.50959\n", + "[12]\tvalidation_0-rmse:1.47113\n", + "[13]\tvalidation_0-rmse:1.44401\n", + "[14]\tvalidation_0-rmse:1.42153\n", + "[15]\tvalidation_0-rmse:1.39003\n", + "[16]\tvalidation_0-rmse:1.35574\n", + "[17]\tvalidation_0-rmse:1.32175\n", + "[18]\tvalidation_0-rmse:1.31396\n", + "[19]\tvalidation_0-rmse:1.33739\n", + "[20]\tvalidation_0-rmse:1.33013\n", + "[21]\tvalidation_0-rmse:1.32880\n", + "[22]\tvalidation_0-rmse:1.33059\n", + "[23]\tvalidation_0-rmse:1.34923\n", + "[24]\tvalidation_0-rmse:1.31799\n", + "[25]\tvalidation_0-rmse:1.31411\n", + "[26]\tvalidation_0-rmse:1.30925\n", + "[27]\tvalidation_0-rmse:1.29766\n", + "[28]\tvalidation_0-rmse:1.30442\n", + "[29]\tvalidation_0-rmse:1.30966\n", + "[30]\tvalidation_0-rmse:1.28975\n", + "[31]\tvalidation_0-rmse:1.28810\n", + "[32]\tvalidation_0-rmse:1.28587\n", + "[33]\tvalidation_0-rmse:1.28566\n", + "[34]\tvalidation_0-rmse:1.26751\n", + "[35]\tvalidation_0-rmse:1.28569\n", + "[36]\tvalidation_0-rmse:1.27332\n", + "[37]\tvalidation_0-rmse:1.25467\n", + "[38]\tvalidation_0-rmse:1.25508\n", + "[39]\tvalidation_0-rmse:1.24264\n", + "[40]\tvalidation_0-rmse:1.24019\n", + "[41]\tvalidation_0-rmse:1.24022\n", + "[42]\tvalidation_0-rmse:1.24175\n", + "[43]\tvalidation_0-rmse:1.23122\n", + "[44]\tvalidation_0-rmse:1.22942\n", + "[45]\tvalidation_0-rmse:1.24364\n", + "[46]\tvalidation_0-rmse:1.22211\n", + "[47]\tvalidation_0-rmse:1.24425\n", + "[48]\tvalidation_0-rmse:1.23057\n", + "[49]\tvalidation_0-rmse:1.24111\n", + "[50]\tvalidation_0-rmse:1.23988\n", + "[51]\tvalidation_0-rmse:1.23965\n", + "[52]\tvalidation_0-rmse:1.23850\n", + "[53]\tvalidation_0-rmse:1.23933\n", + "[54]\tvalidation_0-rmse:1.23724\n", + "[55]\tvalidation_0-rmse:1.23362\n", + "[56]\tvalidation_0-rmse:1.23369\n", + "[57]\tvalidation_0-rmse:1.23845\n", + "[58]\tvalidation_0-rmse:1.23947\n", + "[59]\tvalidation_0-rmse:1.23491\n", + "[60]\tvalidation_0-rmse:1.23482\n", + "[61]\tvalidation_0-rmse:1.23441\n", + "[62]\tvalidation_0-rmse:1.23183\n", + "[63]\tvalidation_0-rmse:1.23244\n", + "[64]\tvalidation_0-rmse:1.23290\n", + "[65]\tvalidation_0-rmse:1.23486\n", + "[66]\tvalidation_0-rmse:1.23535\n", + "[67]\tvalidation_0-rmse:1.23568\n", + "[68]\tvalidation_0-rmse:1.23613\n", + "[69]\tvalidation_0-rmse:1.23555\n", + "[70]\tvalidation_0-rmse:1.23506\n", + "[71]\tvalidation_0-rmse:1.23512\n", + "[72]\tvalidation_0-rmse:1.23276\n", + "[73]\tvalidation_0-rmse:1.23287\n", + "[74]\tvalidation_0-rmse:1.23235\n", + "[75]\tvalidation_0-rmse:1.23922\n", + "[76]\tvalidation_0-rmse:1.23873\n", + "[77]\tvalidation_0-rmse:1.22865\n", + "[78]\tvalidation_0-rmse:1.23035\n", + "[79]\tvalidation_0-rmse:1.24439\n", + "[80]\tvalidation_0-rmse:1.23953\n", + "[81]\tvalidation_0-rmse:1.24745\n", + "[82]\tvalidation_0-rmse:1.23214\n", + "[83]\tvalidation_0-rmse:1.23968\n", + "[84]\tvalidation_0-rmse:1.24333\n", + "[85]\tvalidation_0-rmse:1.24434\n", + "[86]\tvalidation_0-rmse:1.25339\n", + "[87]\tvalidation_0-rmse:1.25396\n", + "[88]\tvalidation_0-rmse:1.25383\n", + "[89]\tvalidation_0-rmse:1.25276\n", + "[90]\tvalidation_0-rmse:1.25349\n", + "[91]\tvalidation_0-rmse:1.25226\n", + "[92]\tvalidation_0-rmse:1.25328\n", + "[93]\tvalidation_0-rmse:1.25284\n", + "[94]\tvalidation_0-rmse:1.25222\n", + "[95]\tvalidation_0-rmse:1.25212\n", + "[96]\tvalidation_0-rmse:1.25266\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "(1.4935549430378676,\n", + " 1.2221108554619207,\n", + " 0.8199210432780543,\n", + " 0.1256990216917289,\n", + " 0.8716922073374574)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xgboost as xgb\n", + "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Extract best parameters\n", + "best_params = {'colsample_bytree': 0.7756001484050402, 'eta': 0.31927224345318256, 'gamma': 0.5049174573053737, \n", + " 'max_depth': 42, 'min_child_weight': 6.449650970113468, 'subsample': 0.7873416063207794}\n", + "\n", + "# Re-train the model with the best parameters and evaluate\n", + "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", + "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", + "# gbr.fit(X_train, y_train)\n", + "gbr.fit(X_train, y_train, early_stopping_rounds=50, eval_set=[(X_test, y_test)], verbose=True)\n", + "# Predict and evaluate\n", + "y_pred = gbr.predict(X_test)\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "rmse = np.sqrt(mse)\n", + "mae = mean_absolute_error(y_test, y_pred)\n", + "mape = mean_absolute_percentage_error(y_test, y_pred)\n", + "r2 = r2_score(y_test, y_pred)\n", + "\n", + "mse, rmse, mae, mape, r2" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "9479b1af-a005-4abb-9664-9d3d4837133f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0]\tvalidation_0-rmse:2.94796\n", + "[1]\tvalidation_0-rmse:2.47933\n", + "[2]\tvalidation_0-rmse:2.32980\n", + "[3]\tvalidation_0-rmse:2.10618\n", + "[4]\tvalidation_0-rmse:1.97724\n", + "[5]\tvalidation_0-rmse:1.95411\n", + "[6]\tvalidation_0-rmse:1.84506\n", + "[7]\tvalidation_0-rmse:1.84115\n", + "[8]\tvalidation_0-rmse:1.82857\n", + "[9]\tvalidation_0-rmse:1.76881\n", + "[10]\tvalidation_0-rmse:1.66718\n", + "[11]\tvalidation_0-rmse:1.65162\n", + "[12]\tvalidation_0-rmse:1.61560\n", + "[13]\tvalidation_0-rmse:1.57779\n", + "[14]\tvalidation_0-rmse:1.54531\n", + "[15]\tvalidation_0-rmse:1.51655\n", + "[16]\tvalidation_0-rmse:1.49713\n", + "[17]\tvalidation_0-rmse:1.46052\n", + "[18]\tvalidation_0-rmse:1.42253\n", + "[19]\tvalidation_0-rmse:1.40742\n", + "[20]\tvalidation_0-rmse:1.36721\n", + "[21]\tvalidation_0-rmse:1.35093\n", + "[22]\tvalidation_0-rmse:1.35702\n", + "[23]\tvalidation_0-rmse:1.33818\n", + "[24]\tvalidation_0-rmse:1.32813\n", + "[25]\tvalidation_0-rmse:1.35230\n", + "[26]\tvalidation_0-rmse:1.34903\n", + "[27]\tvalidation_0-rmse:1.34344\n", + "[28]\tvalidation_0-rmse:1.34388\n", + "[29]\tvalidation_0-rmse:1.35948\n", + "[30]\tvalidation_0-rmse:1.33423\n", + "[31]\tvalidation_0-rmse:1.35954\n", + "[32]\tvalidation_0-rmse:1.35215\n", + "[33]\tvalidation_0-rmse:1.33783\n", + "[34]\tvalidation_0-rmse:1.31944\n", + "[35]\tvalidation_0-rmse:1.32828\n", + "[36]\tvalidation_0-rmse:1.29862\n", + "[37]\tvalidation_0-rmse:1.27408\n", + "[38]\tvalidation_0-rmse:1.28541\n", + "[39]\tvalidation_0-rmse:1.27269\n", + "[40]\tvalidation_0-rmse:1.26776\n", + "[41]\tvalidation_0-rmse:1.26017\n", + "[42]\tvalidation_0-rmse:1.29093\n", + "[43]\tvalidation_0-rmse:1.28878\n", + "[44]\tvalidation_0-rmse:1.29039\n", + "[45]\tvalidation_0-rmse:1.28695\n", + "[46]\tvalidation_0-rmse:1.28674\n", + "[47]\tvalidation_0-rmse:1.29803\n", + "[48]\tvalidation_0-rmse:1.28585\n", + "[49]\tvalidation_0-rmse:1.29421\n", + "[50]\tvalidation_0-rmse:1.29601\n", + "[51]\tvalidation_0-rmse:1.31381\n", + "[52]\tvalidation_0-rmse:1.32534\n", + "[53]\tvalidation_0-rmse:1.32683\n", + "[54]\tvalidation_0-rmse:1.31673\n", + "[55]\tvalidation_0-rmse:1.32776\n", + "[56]\tvalidation_0-rmse:1.32463\n", + "[57]\tvalidation_0-rmse:1.32801\n", + "[58]\tvalidation_0-rmse:1.32713\n", + "[59]\tvalidation_0-rmse:1.33055\n", + "[60]\tvalidation_0-rmse:1.33246\n", + "[61]\tvalidation_0-rmse:1.32767\n", + "[62]\tvalidation_0-rmse:1.33396\n", + "[63]\tvalidation_0-rmse:1.33038\n", + "[64]\tvalidation_0-rmse:1.32723\n", + "[65]\tvalidation_0-rmse:1.32445\n", + "[66]\tvalidation_0-rmse:1.32666\n", + "[67]\tvalidation_0-rmse:1.32470\n", + "[68]\tvalidation_0-rmse:1.32556\n", + "[69]\tvalidation_0-rmse:1.32294\n", + "[70]\tvalidation_0-rmse:1.32810\n", + "[71]\tvalidation_0-rmse:1.32573\n", + "[72]\tvalidation_0-rmse:1.33002\n", + "[73]\tvalidation_0-rmse:1.33069\n", + "[74]\tvalidation_0-rmse:1.33042\n", + "[75]\tvalidation_0-rmse:1.33060\n", + "[76]\tvalidation_0-rmse:1.33120\n", + "[77]\tvalidation_0-rmse:1.33269\n", + "[78]\tvalidation_0-rmse:1.32919\n", + "[79]\tvalidation_0-rmse:1.32429\n", + "[80]\tvalidation_0-rmse:1.32078\n", + "[81]\tvalidation_0-rmse:1.31978\n", + "[82]\tvalidation_0-rmse:1.30738\n", + "[83]\tvalidation_0-rmse:1.30778\n", + "[84]\tvalidation_0-rmse:1.31087\n", + "[85]\tvalidation_0-rmse:1.30729\n", + "[86]\tvalidation_0-rmse:1.30339\n", + "[87]\tvalidation_0-rmse:1.30681\n", + "[88]\tvalidation_0-rmse:1.30676\n", + "[89]\tvalidation_0-rmse:1.30625\n", + "[90]\tvalidation_0-rmse:1.30712\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "(1.5880184359288727,\n", + " 1.2601660350639803,\n", + " 0.9411508613501586,\n", + " 0.1599000195341022,\n", + " 0.8635770708193553)" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xgboost as xgb\n", + "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Extract best parameters\n", + "best_params = {'colsample_bytree': 0.8182688124328266, 'eta': 0.39669872117044186, 'gamma': 0.67893237292294242, \n", + " 'max_depth': 23, 'min_child_weight': 7.274037788798998, 'subsample': 0.6957233806783182}\n", + "\n", + "# Re-train the model with the best parameters and evaluate\n", + "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", + "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", + "gbr.fit(X_train, y_train,early_stopping_rounds=50, eval_set=[(X_test, y_test)], verbose=True)\n", + "\n", + "# Predict and evaluate\n", + "y_pred = gbr.predict(X_test)\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "rmse = np.sqrt(mse)\n", + "mae = mean_absolute_error(y_test, y_pred)\n", + "mape = mean_absolute_percentage_error(y_test, y_pred)\n", + "r2 = r2_score(y_test, y_pred)\n", + "\n", + "mse, rmse, mae, mape, r2" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "ae2c8a8a-1333-487e-9c2a-fee46bb9c934", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0]\tvalidation_0-rmse:2.99087\n", + "[1]\tvalidation_0-rmse:2.66375\n", + "[2]\tvalidation_0-rmse:2.40788\n", + "[3]\tvalidation_0-rmse:2.21494\n", + "[4]\tvalidation_0-rmse:2.08983\n", + "[5]\tvalidation_0-rmse:2.00727\n", + "[6]\tvalidation_0-rmse:1.85305\n", + "[7]\tvalidation_0-rmse:1.78475\n", + "[8]\tvalidation_0-rmse:1.70319\n", + "[9]\tvalidation_0-rmse:1.65511\n", + "[10]\tvalidation_0-rmse:1.58430\n", + "[11]\tvalidation_0-rmse:1.55495\n", + "[12]\tvalidation_0-rmse:1.53684\n", + "[13]\tvalidation_0-rmse:1.48998\n", + "[14]\tvalidation_0-rmse:1.48409\n", + "[15]\tvalidation_0-rmse:1.42885\n", + "[16]\tvalidation_0-rmse:1.39821\n", + "[17]\tvalidation_0-rmse:1.36639\n", + "[18]\tvalidation_0-rmse:1.36579\n", + "[19]\tvalidation_0-rmse:1.36524\n", + "[20]\tvalidation_0-rmse:1.37165\n", + "[21]\tvalidation_0-rmse:1.37061\n", + "[22]\tvalidation_0-rmse:1.37467\n", + "[23]\tvalidation_0-rmse:1.36140\n", + "[24]\tvalidation_0-rmse:1.34469\n", + "[25]\tvalidation_0-rmse:1.34300\n", + "[26]\tvalidation_0-rmse:1.34283\n", + "[27]\tvalidation_0-rmse:1.33821\n", + "[28]\tvalidation_0-rmse:1.33426\n", + "[29]\tvalidation_0-rmse:1.35185\n", + "[30]\tvalidation_0-rmse:1.34884\n", + "[31]\tvalidation_0-rmse:1.35015\n", + "[32]\tvalidation_0-rmse:1.36925\n", + "[33]\tvalidation_0-rmse:1.36742\n", + "[34]\tvalidation_0-rmse:1.36185\n", + "[35]\tvalidation_0-rmse:1.38217\n", + "[36]\tvalidation_0-rmse:1.37213\n", + "[37]\tvalidation_0-rmse:1.35339\n", + "[38]\tvalidation_0-rmse:1.35314\n", + "[39]\tvalidation_0-rmse:1.35475\n", + "[40]\tvalidation_0-rmse:1.34997\n", + "[41]\tvalidation_0-rmse:1.33195\n", + "[42]\tvalidation_0-rmse:1.33518\n", + "[43]\tvalidation_0-rmse:1.33585\n", + "[44]\tvalidation_0-rmse:1.33598\n", + "[45]\tvalidation_0-rmse:1.34456\n", + "[46]\tvalidation_0-rmse:1.33476\n", + "[47]\tvalidation_0-rmse:1.35722\n", + "[48]\tvalidation_0-rmse:1.36327\n", + "[49]\tvalidation_0-rmse:1.36469\n", + "[50]\tvalidation_0-rmse:1.35926\n", + "[51]\tvalidation_0-rmse:1.36052\n", + "[52]\tvalidation_0-rmse:1.36092\n", + "[53]\tvalidation_0-rmse:1.36165\n", + "[54]\tvalidation_0-rmse:1.34672\n", + "[55]\tvalidation_0-rmse:1.35078\n", + "[56]\tvalidation_0-rmse:1.35164\n", + "[57]\tvalidation_0-rmse:1.35124\n", + "[58]\tvalidation_0-rmse:1.34962\n", + "[59]\tvalidation_0-rmse:1.35293\n", + "[60]\tvalidation_0-rmse:1.34839\n", + "[61]\tvalidation_0-rmse:1.33587\n", + "[62]\tvalidation_0-rmse:1.32711\n", + "[63]\tvalidation_0-rmse:1.32892\n", + "[64]\tvalidation_0-rmse:1.33008\n", + "[65]\tvalidation_0-rmse:1.33023\n", + "[66]\tvalidation_0-rmse:1.33082\n", + "[67]\tvalidation_0-rmse:1.33207\n", + "[68]\tvalidation_0-rmse:1.33262\n", + "[69]\tvalidation_0-rmse:1.33185\n", + "[70]\tvalidation_0-rmse:1.33114\n", + "[71]\tvalidation_0-rmse:1.33287\n", + "[72]\tvalidation_0-rmse:1.34803\n", + "[73]\tvalidation_0-rmse:1.35223\n", + "[74]\tvalidation_0-rmse:1.34266\n", + "[75]\tvalidation_0-rmse:1.34423\n", + "[76]\tvalidation_0-rmse:1.34351\n", + "[77]\tvalidation_0-rmse:1.33684\n", + "[78]\tvalidation_0-rmse:1.33450\n", + "[79]\tvalidation_0-rmse:1.35080\n", + "[80]\tvalidation_0-rmse:1.34307\n", + "[81]\tvalidation_0-rmse:1.33828\n", + "[82]\tvalidation_0-rmse:1.33786\n", + "[83]\tvalidation_0-rmse:1.33990\n", + "[84]\tvalidation_0-rmse:1.34383\n", + "[85]\tvalidation_0-rmse:1.34400\n", + "[86]\tvalidation_0-rmse:1.34246\n", + "[87]\tvalidation_0-rmse:1.34125\n", + "[88]\tvalidation_0-rmse:1.34723\n", + "[89]\tvalidation_0-rmse:1.34835\n", + "[90]\tvalidation_0-rmse:1.34827\n", + "[91]\tvalidation_0-rmse:1.34661\n", + "[92]\tvalidation_0-rmse:1.34880\n", + "[93]\tvalidation_0-rmse:1.34823\n", + "[94]\tvalidation_0-rmse:1.36337\n", + "[95]\tvalidation_0-rmse:1.36356\n", + "[96]\tvalidation_0-rmse:1.36404\n", + "[97]\tvalidation_0-rmse:1.36244\n", + "[98]\tvalidation_0-rmse:1.36225\n", + "[99]\tvalidation_0-rmse:1.36267\n", + "[100]\tvalidation_0-rmse:1.36269\n", + "[101]\tvalidation_0-rmse:1.36312\n", + "[102]\tvalidation_0-rmse:1.37098\n", + "[103]\tvalidation_0-rmse:1.37105\n", + "[104]\tvalidation_0-rmse:1.37414\n", + "[105]\tvalidation_0-rmse:1.37391\n", + "[106]\tvalidation_0-rmse:1.37318\n", + "[107]\tvalidation_0-rmse:1.36753\n", + "[108]\tvalidation_0-rmse:1.36538\n", + "[109]\tvalidation_0-rmse:1.36617\n", + "[110]\tvalidation_0-rmse:1.36542\n", + "[111]\tvalidation_0-rmse:1.36646\n", + "[112]\tvalidation_0-rmse:1.38288\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "(1.7612323261291336,\n", + " 1.3271142852554687,\n", + " 0.9281123720415317,\n", + " 0.14669084489239353,\n", + " 0.8486966728710329)" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import xgboost as xgb\n", + "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Extract best parameters\n", + "best_params = {'colsample_bytree': 0.7591178772740766, 'eta': 0.29006097172943296, 'gamma': 0.32020608660889016, \n", + " 'max_depth': 21, 'min_child_weight': 5.912424330716954, 'subsample': 0.8011115810485918}\n", + "\n", + "# Re-train the model with the best parameters and evaluate\n", + "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", + "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", + "gbr.fit(X_train, y_train,early_stopping_rounds=50, eval_set=[(X_test, y_test)], verbose=True)\n", + "\n", + "# Predict and evaluate\n", + "y_pred = gbr.predict(X_test)\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "rmse = np.sqrt(mse)\n", + "mae = mean_absolute_error(y_test, y_pred)\n", + "mape = mean_absolute_percentage_error(y_test, y_pred)\n", + "r2 = r2_score(y_test, y_pred)\n", + "\n", + "mse, rmse, mae, mape, r2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27880e3b-0149-4611-abda-9e06442704e4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}