{ "cells": [ { "cell_type": "code", "execution_count": 29, "id": "35268e29-6824-4e2d-983b-7e8259bd343f", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from hyperopt import hp, fmin, tpe, STATUS_OK, Trials\n", "from sklearn.model_selection import train_test_split\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 30, "id": "30a72889-e9da-45d3-bb03-cbe1b908a972", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AVFCCHNSOH/CO/CN/CRtHrdpTTar
025.5645.3254.6869.134.911.130.8124.020.8523070.2605960.01401130.010.00.26003.974958
115.2645.1154.8963.805.111.560.8828.650.9611290.3367950.02095830.010.00.26004.629865
29.9235.1664.8480.254.660.910.5613.620.6968220.1272900.00972020.020.06.04006.452928
39.9235.1664.8480.254.660.910.5613.620.6968220.1272900.00972020.020.06.04508.724672
49.9235.1664.8480.254.660.910.5613.620.6968220.1272900.00972020.020.06.050010.075968
\n", "
" ], "text/plain": [ " A V FC C H N S O H/C O/C \\\n", "0 25.56 45.32 54.68 69.13 4.91 1.13 0.81 24.02 0.852307 0.260596 \n", "1 15.26 45.11 54.89 63.80 5.11 1.56 0.88 28.65 0.961129 0.336795 \n", "2 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n", "3 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n", "4 9.92 35.16 64.84 80.25 4.66 0.91 0.56 13.62 0.696822 0.127290 \n", "\n", " N/C Rt Hr dp T Tar \n", "0 0.014011 30.0 10.0 0.2 600 3.974958 \n", "1 0.020958 30.0 10.0 0.2 600 4.629865 \n", "2 0.009720 20.0 20.0 6.0 400 6.452928 \n", "3 0.009720 20.0 20.0 6.0 450 8.724672 \n", "4 0.009720 20.0 20.0 6.0 500 10.075968 " ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_excel('/mnt/tanzk/mxx/rejie2.xlsx')\n", "data.head()" ] },
{ "cell_type": "code", "execution_count": 31, "id": "652c035d-937c-4b6b-bbe3-b5ed9ea2ca92", "metadata": {}, "outputs": [], "source": [ "out_cols = ['Tar']" ] },
{ "cell_type": "code", "execution_count": 32, "id": "9f00267b-0cf5-4cfc-97ef-8c15cd8d262d", "metadata": {}, "outputs": [], "source": [ "feature_cols = [x for x in data.columns if x not in out_cols]" ] },
{ "cell_type": "code", "execution_count": 33, "id": "676d77b7-605b-4a2a-acc0-8d625324e1c2", "metadata": {}, "outputs": [], "source": [ "train_data = data.reset_index(drop=True)" ] },
{ "cell_type": "code", "execution_count": 34, "id": "7bc4dd3d-3285-4bea-bdd2-c55816674722", "metadata": {}, "outputs": [], "source": [ "import xgboost as xgb\n", "\n", "from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error, r2_score" ] },
{ "cell_type": "code", "execution_count": 35, "id": "52f19cb8-424e-47eb-a68d-f107a5d194bd", "metadata": {}, "outputs": [], "source": [
"# Hyperparameter search space explored by hyperopt.\n",
"space = {\n",
"    # learning rate: log-uniform over [exp(-4), exp(0.5)] ~= [0.018, 1.65]\n",
"    'eta': hp.loguniform('eta', -4, 0.5),\n",
"    # maximum tree depth: one of 5..49 (fmin reports the index into this range, not the depth)\n",
"    'max_depth': hp.choice('max_depth', range(5, 50)),\n",
"    # minimum sum of instance weights in a child node: uniform over [0, 20]\n",
"    'min_child_weight': hp.uniform('min_child_weight', 0, 20),\n",
"    # minimum loss reduction needed to split a leaf: log-uniform over [exp(-3), exp(0.7)] ~= [0.05, 2.01]\n",
"    'gamma': hp.loguniform('gamma', -3, 0.7),\n",
"    # sampling ratio of the training rows: uniform over [0.5, 0.9]\n",
"    'subsample': hp.uniform('subsample', 0.5, 0.9),\n",
"    # sampling ratio of the features: uniform over [0.5, 0.9]\n",
"    'colsample_bytree': hp.uniform('colsample_bytree', 0.5, 0.9),\n",
"}"
] },
{ "cell_type": "code", "execution_count": 36, "id": "b4209d03-b608-4520-9e40-79e93b358c05", "metadata": {}, "outputs": [], "source": [
"# Hold out 10% of the rows as the test set (fixed seed so the split is repeatable)\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    train_data[feature_cols],\n",
"    train_data[out_cols],\n",
"    test_size=0.1, random_state=42)"
] },
{ "cell_type": "code", "execution_count": 37, "id": "e0e3e9fd-f85d-4370-8b7d-129556624721", "metadata": {}, "outputs": [], "source": [
"def objective(params):\n",
"    \"\"\"Hyperopt objective: fit an XGBoost regressor with ``params`` and score it by MAE.\n",
"\n",
"    Reads the notebook-level X_train / y_train / X_test / y_test produced by train_test_split.\n",
"    Returns the dict hyperopt expects: the loss to minimise plus STATUS_OK.\n",
"\n",
"    NOTE(review): the loss is computed on X_test, so the search tunes on the same rows that\n",
"    later cells report metrics on; those metrics are therefore optimistic.\n",
"    \"\"\"\n",
"    gbr = xgb.XGBRegressor(**params)\n",
"    gbr.fit(X_train, y_train)\n",
"    y_pred = gbr.predict(X_test)\n",
"    mae = mean_absolute_error(y_test, y_pred)\n",
"    return {'loss': mae, 'status': STATUS_OK}"
] },
{ "cell_type": "code", "execution_count": 38, "id": "cc20c5e9-072f-4f6c-b74e-6042c54408ce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:06<00:00, 15.37trial/s, best loss: 0.927005504531766]\n" ] } ], "source": [
"from hyperopt import space_eval\n",
"\n",
"# Trials keeps the full search history\n",
"trials = Trials()\n",
"\n",
"# fmin reports hp.choice parameters as indices (max_depth=21 means range(5, 50)[21] == 26),\n",
"# so decode the raw result into real parameter values before training with it.\n",
"best_indices = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=100, trials=trials)\n",
"best_params = space_eval(space, best_indices)"
] },
{ "cell_type": "code", "execution_count": 40, "id":
"498cff26-e385-48ac-82ed-5057ccbc370c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'colsample_bytree': 0.7591178772740766, 'eta': 0.29006097172943296, 'gamma': 0.32020608660889016, 'max_depth': 21, 'min_child_weight': 5.912424330716954, 'subsample': 0.8011115810485918}\n" ] } ], "source": [ "print(best_params)\n" ] }, { "cell_type": "code", "execution_count": 41, "id": "5eeb907e-cdc1-434a-82d7-3b967287bbb3", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import KFold, train_test_split\n", "kf = KFold(n_splits=10, shuffle=True, random_state=42)" ] }, { "cell_type": "code", "execution_count": 42, "id": "e629cc0e-a9f8-44d7-9958-092a34a5e579", "metadata": {}, "outputs": [], "source": [ "num_boost_round = 1000" ] }, { "cell_type": "code", "execution_count": 43, "id": "8ddd834e-22b8-42dc-8387-b9983b52a71b", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 44, "id": "6e1074cd-88d1-41aa-8d6a-b11b3ae705e0", "metadata": {}, "outputs": [], "source": [ "plt.rcParams[\"font.sans-serif\"] = [\"WenQuanYi Micro Hei\"] # 设置字体\n", "plt.rcParams[\"axes.unicode_minus\"] = False # 正常显示负号" ] }, { "cell_type": "code", "execution_count": 45, "id": "70490cbb-f9cf-4b3c-b476-37f3721d50e8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MSE: 2.336, RMSE: 1.5284, MAE: 1.0563, MAPE: 17.41 %, R_2: 0.7993\n", "MSE: 2.1045, RMSE: 1.4507, MAE: 0.9808, MAPE: 14.64 %, R_2: 0.7211\n", "MSE: 1.7837, RMSE: 1.3356, MAE: 0.9841, MAPE: 12.31 %, R_2: 0.7866\n", "MSE: 3.7098, RMSE: 1.9261, MAE: 1.0768, MAPE: 21.11 %, R_2: 0.7009\n", "MSE: 1.5407, RMSE: 1.2412, MAE: 0.7912, MAPE: 13.12 %, R_2: 0.8225\n", "MSE: 2.5428, RMSE: 1.5946, MAE: 1.0797, MAPE: 15.42 %, R_2: 0.5668\n", "MSE: 1.6981, RMSE: 1.3031, MAE: 0.9132, MAPE: 13.13 %, R_2: 0.7728\n", "MSE: 1.6532, RMSE: 1.2858, MAE: 0.8018, MAPE: 9.68 %, R_2: 0.7663\n", "MSE: 2.2342, 
RMSE: 1.4947, MAE: 1.1072, MAPE: 16.14 %, R_2: 0.7479\n", "MSE: 1.6337, RMSE: 1.2782, MAE: 0.9285, MAPE: 12.27 %, R_2: 0.7293\n" ] } ], "source": [
"eva_list = []\n",
"eva_cols = ['MSE', 'RMSE', 'MAE', 'MAPE', 'R2']\n",
"\n",
"# Cross-validate best_params over the folds of kf; one metrics row is collected per fold.\n",
"# Fold-local names keep the loop from overwriting the global X_train that objective() reads.\n",
"for train_index, test_index in kf.split(train_data):\n",
"    # KFold.split yields positional indices, so select rows with .iloc\n",
"    fold_train = train_data.iloc[train_index]\n",
"    fold_valid = train_data.iloc[test_index]\n",
"    X_fold_train, Y_fold_train = fold_train[feature_cols], fold_train[out_cols]\n",
"    X_fold_valid, Y_fold_valid = fold_valid[feature_cols], fold_valid[out_cols]\n",
"    dtrain = xgb.DMatrix(X_fold_train, Y_fold_train)\n",
"    dvalid = xgb.DMatrix(X_fold_valid, Y_fold_valid)\n",
"    # NOTE(review): the validation fold also drives early stopping, so fold scores are optimistic\n",
"    gb_model = xgb.train(best_params, dtrain, num_boost_round, evals=[(dvalid, 'eval')],\n",
"                         early_stopping_rounds=100, verbose_eval=False)\n",
"    y_pred = gb_model.predict(dvalid)\n",
"    y_true = Y_fold_valid.values\n",
"    MSE = mean_squared_error(y_true, y_pred)\n",
"    RMSE = np.sqrt(MSE)\n",
"    MAE = mean_absolute_error(y_true, y_pred)\n",
"    MAPE = mean_absolute_percentage_error(y_true, y_pred)\n",
"    R_2 = r2_score(y_true, y_pred)\n",
"    print('MSE:', round(MSE, 4), end=', ')\n",
"    print('RMSE:', round(RMSE, 4), end=', ')\n",
"    print('MAE:', round(MAE, 4), end=', ')\n",
"    print('MAPE:', round(MAPE*100, 2), '%', end=', ')\n",
"    print('R_2:', round(R_2, 4))  # a negative R^2 means the fit is worse than predicting the mean\n",
"    eva_list.append([MSE, RMSE, MAE, MAPE, R_2])"
] }, { "cell_type": "code", "execution_count": 46, "id": "10886d74-7543-49a6-b4f1-a31a3110fb99", "metadata": {}, "outputs": [], "source": [ "eva_df = pd.DataFrame.from_records(eva_list, columns=eva_cols)" ] }, { "cell_type": "code", "execution_count": 47, "id": "87718c37-27a4-44b4-b5ec-fbc6fdbc14fe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MSE 2.123660\n", "RMSE 1.443832\n", "MAE 0.971950\n", "MAPE 0.145207\n", "R2 0.741365\n", "dtype: float64" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eva_df.mean()" ] }, { "cell_type": "code",
"execution_count": 48, "id": "ef6508a2-ca8a-4fd1-8a8b-0e8a91d01d87", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1.861425600429374,\n", " 1.3643407200656932,\n", " 0.9369589057201204,\n", " 0.14918013653153422,\n", " 0.8400893043071764)" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import xgboost as xgb\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Extract best parameters\n", "best_params = {'colsample_bytree': 0.7591178772740766, 'eta': 0.29006097172943296, 'gamma': 0.32020608660889016, \n", " 'max_depth': 21, 'min_child_weight': 5.912424330716954, 'subsample': 0.8011115810485918}\n", "\n", "# Re-train the model with the best parameters and evaluate\n", "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", "gbr.fit(X_train, y_train)\n", "\n", "# Predict and evaluate\n", "y_pred = gbr.predict(X_test)\n", "mse = mean_squared_error(y_test, y_pred)\n", "rmse = np.sqrt(mse)\n", "mae = mean_absolute_error(y_test, y_pred)\n", "mape = mean_absolute_percentage_error(y_test, y_pred)\n", "r2 = r2_score(y_test, y_pred)\n", "\n", "mse, rmse, mae, mape, r2\n" ] }, { "cell_type": "code", "execution_count": 49, "id": "f74ccd89-e1b0-44e8-913f-6b33e8115250", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1.78639546955746,\n", " 1.3365610609162082,\n", " 0.9437850044721565,\n", " 0.1579769923858331,\n", " 0.846534966396966)" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import xgboost as xgb\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Extract 
best parameters\n", "best_params = {'colsample_bytree': 0.8182688124328266, 'eta': 0.39669872117044186, 'gamma': 0.67893237292294242, \n", " 'max_depth': 23, 'min_child_weight': 7.274037788798998, 'subsample': 0.6957233806783182}\n", "\n", "# Re-train the model with the best parameters and evaluate\n", "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", "gbr.fit(X_train, y_train)\n", "\n", "# Predict and evaluate\n", "y_pred = gbr.predict(X_test)\n", "mse = mean_squared_error(y_test, y_pred)\n", "rmse = np.sqrt(mse)\n", "mae = mean_absolute_error(y_test, y_pred)\n", "mape = mean_absolute_percentage_error(y_test, y_pred)\n", "r2 = r2_score(y_test, y_pred)\n", "\n", "mse, rmse, mae, mape, r2" ] }, { "cell_type": "code", "execution_count": 50, "id": "8e223a2f-676b-4255-bd7e-b6ea8857a7e2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1.5784385533129173,\n", " 1.256359245324727,\n", " 0.8012875783803378,\n", " 0.12426418160928662,\n", " 0.8644000560052363)" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import xgboost as xgb\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Extract best parameters\n", "best_params = {'colsample_bytree': 0.7756001484050402, 'eta': 0.31927224345318256, 'gamma': 0.5049174573053737, \n", " 'max_depth': 42, 'min_child_weight': 6.449650970113468, 'subsample': 0.7873416063207794}\n", "\n", "# Re-train the model with the best parameters and evaluate\n", "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", "gbr.fit(X_train, y_train)\n", "\n", "# Predict and 
evaluate\n", "y_pred = gbr.predict(X_test)\n", "mse = mean_squared_error(y_test, y_pred)\n", "rmse = np.sqrt(mse)\n", "mae = mean_absolute_error(y_test, y_pred)\n", "mape = mean_absolute_percentage_error(y_test, y_pred)\n", "r2 = r2_score(y_test, y_pred)\n", "\n", "mse, rmse, mae, mape, r2\n" ] }, { "cell_type": "code", "execution_count": 52, "id": "92bf1d55-a20d-4337-b8cf-9f96cb53e71e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0]\tvalidation_0-rmse:2.94022\n", "[1]\tvalidation_0-rmse:2.32601\n", "[2]\tvalidation_0-rmse:2.15546\n", "[3]\tvalidation_0-rmse:1.97899\n", "[4]\tvalidation_0-rmse:1.76293\n", "[5]\tvalidation_0-rmse:1.66392\n", "[6]\tvalidation_0-rmse:1.56808\n", "[7]\tvalidation_0-rmse:1.57516\n", "[8]\tvalidation_0-rmse:1.55810\n", "[9]\tvalidation_0-rmse:1.49683\n", "[10]\tvalidation_0-rmse:1.42146\n", "[11]\tvalidation_0-rmse:1.38196\n", "[12]\tvalidation_0-rmse:1.33497\n", "[13]\tvalidation_0-rmse:1.29656\n", "[14]\tvalidation_0-rmse:1.31387\n", "[15]\tvalidation_0-rmse:1.29630\n", "[16]\tvalidation_0-rmse:1.32641\n", "[17]\tvalidation_0-rmse:1.34701\n", "[18]\tvalidation_0-rmse:1.34493\n", "[19]\tvalidation_0-rmse:1.34176\n", "[20]\tvalidation_0-rmse:1.35030\n", "[21]\tvalidation_0-rmse:1.35103\n", "[22]\tvalidation_0-rmse:1.36426\n", "[23]\tvalidation_0-rmse:1.38450\n", "[24]\tvalidation_0-rmse:1.35743\n", "[25]\tvalidation_0-rmse:1.34604\n", "[26]\tvalidation_0-rmse:1.36367\n", "[27]\tvalidation_0-rmse:1.34772\n", "[28]\tvalidation_0-rmse:1.35335\n", "[29]\tvalidation_0-rmse:1.37848\n", "[30]\tvalidation_0-rmse:1.37084\n", "[31]\tvalidation_0-rmse:1.36131\n", "[32]\tvalidation_0-rmse:1.37337\n", "[33]\tvalidation_0-rmse:1.36706\n", "[34]\tvalidation_0-rmse:1.35412\n", "[35]\tvalidation_0-rmse:1.35344\n", "[36]\tvalidation_0-rmse:1.33891\n", "[37]\tvalidation_0-rmse:1.31269\n", "[38]\tvalidation_0-rmse:1.31461\n", "[39]\tvalidation_0-rmse:1.31811\n", "[40]\tvalidation_0-rmse:1.32056\n", 
"[41]\tvalidation_0-rmse:1.32437\n", "[42]\tvalidation_0-rmse:1.28383\n", "[43]\tvalidation_0-rmse:1.28747\n", "[44]\tvalidation_0-rmse:1.28136\n", "[45]\tvalidation_0-rmse:1.30384\n", "[46]\tvalidation_0-rmse:1.31680\n", "[47]\tvalidation_0-rmse:1.32577\n", "[48]\tvalidation_0-rmse:1.32435\n", "[49]\tvalidation_0-rmse:1.34947\n", "[50]\tvalidation_0-rmse:1.33535\n", "[51]\tvalidation_0-rmse:1.33444\n", "[52]\tvalidation_0-rmse:1.32391\n", "[53]\tvalidation_0-rmse:1.32339\n", "[54]\tvalidation_0-rmse:1.31571\n", "[55]\tvalidation_0-rmse:1.32033\n", "[56]\tvalidation_0-rmse:1.32455\n", "[57]\tvalidation_0-rmse:1.33108\n", "[58]\tvalidation_0-rmse:1.33065\n", "[59]\tvalidation_0-rmse:1.33104\n", "[60]\tvalidation_0-rmse:1.32995\n", "[61]\tvalidation_0-rmse:1.33504\n", "[62]\tvalidation_0-rmse:1.33599\n", "[63]\tvalidation_0-rmse:1.33493\n", "[64]\tvalidation_0-rmse:1.33562\n", "[65]\tvalidation_0-rmse:1.33627\n", "[66]\tvalidation_0-rmse:1.33196\n", "[67]\tvalidation_0-rmse:1.32833\n", "[68]\tvalidation_0-rmse:1.33790\n", "[69]\tvalidation_0-rmse:1.33838\n", "[70]\tvalidation_0-rmse:1.34104\n", "[71]\tvalidation_0-rmse:1.34476\n", "[72]\tvalidation_0-rmse:1.34435\n", "[73]\tvalidation_0-rmse:1.33861\n", "[74]\tvalidation_0-rmse:1.33732\n", "[75]\tvalidation_0-rmse:1.34115\n", "[76]\tvalidation_0-rmse:1.33682\n", "[77]\tvalidation_0-rmse:1.33896\n", "[78]\tvalidation_0-rmse:1.34560\n", "[79]\tvalidation_0-rmse:1.34630\n", "[80]\tvalidation_0-rmse:1.34347\n", "[81]\tvalidation_0-rmse:1.34940\n", "[82]\tvalidation_0-rmse:1.34326\n", "[83]\tvalidation_0-rmse:1.34504\n", "[84]\tvalidation_0-rmse:1.33780\n", "[85]\tvalidation_0-rmse:1.34329\n", "[86]\tvalidation_0-rmse:1.34396\n", "[87]\tvalidation_0-rmse:1.34863\n", "[88]\tvalidation_0-rmse:1.34721\n", "[89]\tvalidation_0-rmse:1.35453\n", "[90]\tvalidation_0-rmse:1.35999\n", "[91]\tvalidation_0-rmse:1.34558\n", "[92]\tvalidation_0-rmse:1.35019\n", "[93]\tvalidation_0-rmse:1.35153\n", "[94]\tvalidation_0-rmse:1.34330\n" ] 
}, { "name": "stderr", "output_type": "stream", "text": [ "/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "(1.6418860457242266,\n", " 1.2813610130342763,\n", " 0.8241156918341648,\n", " 0.13793177033015197,\n", " 0.8589494311459256)" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hard-coded parameter set to re-train with\n",
"# NOTE(review): if this dict was pasted from fmin's printed result, max_depth is an hp.choice\n",
"# index into range(5, 50) (real depth = 5 + index) rather than the depth itself -- confirm.\n",
"best_params = {'colsample_bytree': 0.8997619766112827, 'eta': 0.48900376453173927, 'gamma': 0.09568323449358279, \n",
"               'max_depth': 29, 'min_child_weight': 2.0607020689885673, 'subsample': 0.5621662915587151}\n",
"\n",
"# Re-train the model with these parameters and evaluate\n",
"X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n",
"# early_stopping_rounds is set on the estimator; passing it to fit() is deprecated\n",
"gbr = xgb.XGBRegressor(**best_params, n_estimators=1500, early_stopping_rounds=50)\n",
"# NOTE(review): X_test is also the early-stopping set, so the metrics below are optimistic\n",
"gbr.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=True)\n",
"\n",
"# Predict and evaluate\n",
"y_pred = gbr.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"mape = mean_absolute_percentage_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"mse, rmse, mae, mape, r2"
] }, { "cell_type": "code", "execution_count": 53, "id": "097fba6f-a6af-4529-9d09-589ccf84b991", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0]\tvalidation_0-rmse:3.03131\n",
"[1]\tvalidation_0-rmse:2.62574\n", "[2]\tvalidation_0-rmse:2.41405\n", "[3]\tvalidation_0-rmse:2.21943\n", "[4]\tvalidation_0-rmse:2.06257\n", "[5]\tvalidation_0-rmse:1.96808\n", "[6]\tvalidation_0-rmse:1.73678\n", "[7]\tvalidation_0-rmse:1.70331\n", "[8]\tvalidation_0-rmse:1.64866\n", "[9]\tvalidation_0-rmse:1.62014\n", "[10]\tvalidation_0-rmse:1.55086\n", "[11]\tvalidation_0-rmse:1.50959\n", "[12]\tvalidation_0-rmse:1.47113\n", "[13]\tvalidation_0-rmse:1.44401\n", "[14]\tvalidation_0-rmse:1.42153\n", "[15]\tvalidation_0-rmse:1.39003\n", "[16]\tvalidation_0-rmse:1.35574\n", "[17]\tvalidation_0-rmse:1.32175\n", "[18]\tvalidation_0-rmse:1.31396\n", "[19]\tvalidation_0-rmse:1.33739\n", "[20]\tvalidation_0-rmse:1.33013\n", "[21]\tvalidation_0-rmse:1.32880\n", "[22]\tvalidation_0-rmse:1.33059\n", "[23]\tvalidation_0-rmse:1.34923\n", "[24]\tvalidation_0-rmse:1.31799\n", "[25]\tvalidation_0-rmse:1.31411\n", "[26]\tvalidation_0-rmse:1.30925\n", "[27]\tvalidation_0-rmse:1.29766\n", "[28]\tvalidation_0-rmse:1.30442\n", "[29]\tvalidation_0-rmse:1.30966\n", "[30]\tvalidation_0-rmse:1.28975\n", "[31]\tvalidation_0-rmse:1.28810\n", "[32]\tvalidation_0-rmse:1.28587\n", "[33]\tvalidation_0-rmse:1.28566\n", "[34]\tvalidation_0-rmse:1.26751\n", "[35]\tvalidation_0-rmse:1.28569\n", "[36]\tvalidation_0-rmse:1.27332\n", "[37]\tvalidation_0-rmse:1.25467\n", "[38]\tvalidation_0-rmse:1.25508\n", "[39]\tvalidation_0-rmse:1.24264\n", "[40]\tvalidation_0-rmse:1.24019\n", "[41]\tvalidation_0-rmse:1.24022\n", "[42]\tvalidation_0-rmse:1.24175\n", "[43]\tvalidation_0-rmse:1.23122\n", "[44]\tvalidation_0-rmse:1.22942\n", "[45]\tvalidation_0-rmse:1.24364\n", "[46]\tvalidation_0-rmse:1.22211\n", "[47]\tvalidation_0-rmse:1.24425\n", "[48]\tvalidation_0-rmse:1.23057\n", "[49]\tvalidation_0-rmse:1.24111\n", "[50]\tvalidation_0-rmse:1.23988\n", "[51]\tvalidation_0-rmse:1.23965\n", "[52]\tvalidation_0-rmse:1.23850\n", "[53]\tvalidation_0-rmse:1.23933\n", "[54]\tvalidation_0-rmse:1.23724\n", 
"[55]\tvalidation_0-rmse:1.23362\n", "[56]\tvalidation_0-rmse:1.23369\n", "[57]\tvalidation_0-rmse:1.23845\n", "[58]\tvalidation_0-rmse:1.23947\n", "[59]\tvalidation_0-rmse:1.23491\n", "[60]\tvalidation_0-rmse:1.23482\n", "[61]\tvalidation_0-rmse:1.23441\n", "[62]\tvalidation_0-rmse:1.23183\n", "[63]\tvalidation_0-rmse:1.23244\n", "[64]\tvalidation_0-rmse:1.23290\n", "[65]\tvalidation_0-rmse:1.23486\n", "[66]\tvalidation_0-rmse:1.23535\n", "[67]\tvalidation_0-rmse:1.23568\n", "[68]\tvalidation_0-rmse:1.23613\n", "[69]\tvalidation_0-rmse:1.23555\n", "[70]\tvalidation_0-rmse:1.23506\n", "[71]\tvalidation_0-rmse:1.23512\n", "[72]\tvalidation_0-rmse:1.23276\n", "[73]\tvalidation_0-rmse:1.23287\n", "[74]\tvalidation_0-rmse:1.23235\n", "[75]\tvalidation_0-rmse:1.23922\n", "[76]\tvalidation_0-rmse:1.23873\n", "[77]\tvalidation_0-rmse:1.22865\n", "[78]\tvalidation_0-rmse:1.23035\n", "[79]\tvalidation_0-rmse:1.24439\n", "[80]\tvalidation_0-rmse:1.23953\n", "[81]\tvalidation_0-rmse:1.24745\n", "[82]\tvalidation_0-rmse:1.23214\n", "[83]\tvalidation_0-rmse:1.23968\n", "[84]\tvalidation_0-rmse:1.24333\n", "[85]\tvalidation_0-rmse:1.24434\n", "[86]\tvalidation_0-rmse:1.25339\n", "[87]\tvalidation_0-rmse:1.25396\n", "[88]\tvalidation_0-rmse:1.25383\n", "[89]\tvalidation_0-rmse:1.25276\n", "[90]\tvalidation_0-rmse:1.25349\n", "[91]\tvalidation_0-rmse:1.25226\n", "[92]\tvalidation_0-rmse:1.25328\n", "[93]\tvalidation_0-rmse:1.25284\n", "[94]\tvalidation_0-rmse:1.25222\n", "[95]\tvalidation_0-rmse:1.25212\n", "[96]\tvalidation_0-rmse:1.25266\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "(1.4935549430378676,\n", " 1.2221108554619207,\n", " 
0.8199210432780543,\n", " 0.1256990216917289,\n", " 0.8716922073374574)" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [
"import xgboost as xgb\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hard-coded parameter set to re-train with\n",
"# NOTE(review): if this dict was pasted from fmin's printed result, max_depth is an hp.choice\n",
"# index into range(5, 50) (real depth = 5 + index) rather than the depth itself -- confirm.\n",
"best_params = {'colsample_bytree': 0.7756001484050402, 'eta': 0.31927224345318256, 'gamma': 0.5049174573053737, \n",
"               'max_depth': 42, 'min_child_weight': 6.449650970113468, 'subsample': 0.7873416063207794}\n",
"\n",
"# Re-train the model with these parameters and evaluate\n",
"X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n",
"# early_stopping_rounds is set on the estimator; passing it to fit() is deprecated\n",
"gbr = xgb.XGBRegressor(**best_params, n_estimators=1500, early_stopping_rounds=50)\n",
"# NOTE(review): X_test is also the early-stopping set, so the metrics below are optimistic\n",
"gbr.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=True)\n",
"\n",
"# Predict and evaluate\n",
"y_pred = gbr.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"mape = mean_absolute_percentage_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"mse, rmse, mae, mape, r2"
] }, { "cell_type": "code", "execution_count": 54, "id": "9479b1af-a005-4abb-9664-9d3d4837133f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0]\tvalidation_0-rmse:2.94796\n", "[1]\tvalidation_0-rmse:2.47933\n", "[2]\tvalidation_0-rmse:2.32980\n", "[3]\tvalidation_0-rmse:2.10618\n", "[4]\tvalidation_0-rmse:1.97724\n", "[5]\tvalidation_0-rmse:1.95411\n", "[6]\tvalidation_0-rmse:1.84506\n", "[7]\tvalidation_0-rmse:1.84115\n", "[8]\tvalidation_0-rmse:1.82857\n", "[9]\tvalidation_0-rmse:1.76881\n", "[10]\tvalidation_0-rmse:1.66718\n", "[11]\tvalidation_0-rmse:1.65162\n", "[12]\tvalidation_0-rmse:1.61560\n",
"[13]\tvalidation_0-rmse:1.57779\n", "[14]\tvalidation_0-rmse:1.54531\n", "[15]\tvalidation_0-rmse:1.51655\n", "[16]\tvalidation_0-rmse:1.49713\n", "[17]\tvalidation_0-rmse:1.46052\n", "[18]\tvalidation_0-rmse:1.42253\n", "[19]\tvalidation_0-rmse:1.40742\n", "[20]\tvalidation_0-rmse:1.36721\n", "[21]\tvalidation_0-rmse:1.35093\n", "[22]\tvalidation_0-rmse:1.35702\n", "[23]\tvalidation_0-rmse:1.33818\n", "[24]\tvalidation_0-rmse:1.32813\n", "[25]\tvalidation_0-rmse:1.35230\n", "[26]\tvalidation_0-rmse:1.34903\n", "[27]\tvalidation_0-rmse:1.34344\n", "[28]\tvalidation_0-rmse:1.34388\n", "[29]\tvalidation_0-rmse:1.35948\n", "[30]\tvalidation_0-rmse:1.33423\n", "[31]\tvalidation_0-rmse:1.35954\n", "[32]\tvalidation_0-rmse:1.35215\n", "[33]\tvalidation_0-rmse:1.33783\n", "[34]\tvalidation_0-rmse:1.31944\n", "[35]\tvalidation_0-rmse:1.32828\n", "[36]\tvalidation_0-rmse:1.29862\n", "[37]\tvalidation_0-rmse:1.27408\n", "[38]\tvalidation_0-rmse:1.28541\n", "[39]\tvalidation_0-rmse:1.27269\n", "[40]\tvalidation_0-rmse:1.26776\n", "[41]\tvalidation_0-rmse:1.26017\n", "[42]\tvalidation_0-rmse:1.29093\n", "[43]\tvalidation_0-rmse:1.28878\n", "[44]\tvalidation_0-rmse:1.29039\n", "[45]\tvalidation_0-rmse:1.28695\n", "[46]\tvalidation_0-rmse:1.28674\n", "[47]\tvalidation_0-rmse:1.29803\n", "[48]\tvalidation_0-rmse:1.28585\n", "[49]\tvalidation_0-rmse:1.29421\n", "[50]\tvalidation_0-rmse:1.29601\n", "[51]\tvalidation_0-rmse:1.31381\n", "[52]\tvalidation_0-rmse:1.32534\n", "[53]\tvalidation_0-rmse:1.32683\n", "[54]\tvalidation_0-rmse:1.31673\n", "[55]\tvalidation_0-rmse:1.32776\n", "[56]\tvalidation_0-rmse:1.32463\n", "[57]\tvalidation_0-rmse:1.32801\n", "[58]\tvalidation_0-rmse:1.32713\n", "[59]\tvalidation_0-rmse:1.33055\n", "[60]\tvalidation_0-rmse:1.33246\n", "[61]\tvalidation_0-rmse:1.32767\n", "[62]\tvalidation_0-rmse:1.33396\n", "[63]\tvalidation_0-rmse:1.33038\n", "[64]\tvalidation_0-rmse:1.32723\n", "[65]\tvalidation_0-rmse:1.32445\n", "[66]\tvalidation_0-rmse:1.32666\n", 
"[67]\tvalidation_0-rmse:1.32470\n", "[68]\tvalidation_0-rmse:1.32556\n", "[69]\tvalidation_0-rmse:1.32294\n", "[70]\tvalidation_0-rmse:1.32810\n", "[71]\tvalidation_0-rmse:1.32573\n", "[72]\tvalidation_0-rmse:1.33002\n", "[73]\tvalidation_0-rmse:1.33069\n", "[74]\tvalidation_0-rmse:1.33042\n", "[75]\tvalidation_0-rmse:1.33060\n", "[76]\tvalidation_0-rmse:1.33120\n", "[77]\tvalidation_0-rmse:1.33269\n", "[78]\tvalidation_0-rmse:1.32919\n", "[79]\tvalidation_0-rmse:1.32429\n", "[80]\tvalidation_0-rmse:1.32078\n", "[81]\tvalidation_0-rmse:1.31978\n", "[82]\tvalidation_0-rmse:1.30738\n", "[83]\tvalidation_0-rmse:1.30778\n", "[84]\tvalidation_0-rmse:1.31087\n", "[85]\tvalidation_0-rmse:1.30729\n", "[86]\tvalidation_0-rmse:1.30339\n", "[87]\tvalidation_0-rmse:1.30681\n", "[88]\tvalidation_0-rmse:1.30676\n", "[89]\tvalidation_0-rmse:1.30625\n", "[90]\tvalidation_0-rmse:1.30712\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "(1.5880184359288727,\n", " 1.2601660350639803,\n", " 0.9411508613501586,\n", " 0.1599000195341022,\n", " 0.8635770708193553)" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import xgboost as xgb\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Extract best parameters\n", "best_params = {'colsample_bytree': 0.8182688124328266, 'eta': 0.39669872117044186, 'gamma': 0.67893237292294242, \n", " 'max_depth': 23, 'min_child_weight': 7.274037788798998, 'subsample': 0.6957233806783182}\n", "\n", "# Re-train the model with the best parameters and 
evaluate\n",
"X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n",
"# early_stopping_rounds is set on the estimator; passing it to fit() is deprecated\n",
"gbr = xgb.XGBRegressor(**best_params, n_estimators=1500, early_stopping_rounds=50)\n",
"# NOTE(review): X_test is also the early-stopping set, so the metrics below are optimistic\n",
"gbr.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=True)\n",
"\n",
"# Predict and evaluate\n",
"y_pred = gbr.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"mape = mean_absolute_percentage_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"mse, rmse, mae, mape, r2"
] }, { "cell_type": "code", "execution_count": 55, "id": "ae2c8a8a-1333-487e-9c2a-fee46bb9c934", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0]\tvalidation_0-rmse:2.99087\n", "[1]\tvalidation_0-rmse:2.66375\n", "[2]\tvalidation_0-rmse:2.40788\n", "[3]\tvalidation_0-rmse:2.21494\n", "[4]\tvalidation_0-rmse:2.08983\n", "[5]\tvalidation_0-rmse:2.00727\n", "[6]\tvalidation_0-rmse:1.85305\n", "[7]\tvalidation_0-rmse:1.78475\n", "[8]\tvalidation_0-rmse:1.70319\n", "[9]\tvalidation_0-rmse:1.65511\n", "[10]\tvalidation_0-rmse:1.58430\n", "[11]\tvalidation_0-rmse:1.55495\n", "[12]\tvalidation_0-rmse:1.53684\n", "[13]\tvalidation_0-rmse:1.48998\n", "[14]\tvalidation_0-rmse:1.48409\n", "[15]\tvalidation_0-rmse:1.42885\n", "[16]\tvalidation_0-rmse:1.39821\n", "[17]\tvalidation_0-rmse:1.36639\n", "[18]\tvalidation_0-rmse:1.36579\n", "[19]\tvalidation_0-rmse:1.36524\n", "[20]\tvalidation_0-rmse:1.37165\n", "[21]\tvalidation_0-rmse:1.37061\n", "[22]\tvalidation_0-rmse:1.37467\n", "[23]\tvalidation_0-rmse:1.36140\n", "[24]\tvalidation_0-rmse:1.34469\n", "[25]\tvalidation_0-rmse:1.34300\n", "[26]\tvalidation_0-rmse:1.34283\n", "[27]\tvalidation_0-rmse:1.33821\n", "[28]\tvalidation_0-rmse:1.33426\n", "[29]\tvalidation_0-rmse:1.35185\n", "[30]\tvalidation_0-rmse:1.34884\n", "[31]\tvalidation_0-rmse:1.35015\n",
"[32]\tvalidation_0-rmse:1.36925\n", "[33]\tvalidation_0-rmse:1.36742\n", "[34]\tvalidation_0-rmse:1.36185\n", "[35]\tvalidation_0-rmse:1.38217\n", "[36]\tvalidation_0-rmse:1.37213\n", "[37]\tvalidation_0-rmse:1.35339\n", "[38]\tvalidation_0-rmse:1.35314\n", "[39]\tvalidation_0-rmse:1.35475\n", "[40]\tvalidation_0-rmse:1.34997\n", "[41]\tvalidation_0-rmse:1.33195\n", "[42]\tvalidation_0-rmse:1.33518\n", "[43]\tvalidation_0-rmse:1.33585\n", "[44]\tvalidation_0-rmse:1.33598\n", "[45]\tvalidation_0-rmse:1.34456\n", "[46]\tvalidation_0-rmse:1.33476\n", "[47]\tvalidation_0-rmse:1.35722\n", "[48]\tvalidation_0-rmse:1.36327\n", "[49]\tvalidation_0-rmse:1.36469\n", "[50]\tvalidation_0-rmse:1.35926\n", "[51]\tvalidation_0-rmse:1.36052\n", "[52]\tvalidation_0-rmse:1.36092\n", "[53]\tvalidation_0-rmse:1.36165\n", "[54]\tvalidation_0-rmse:1.34672\n", "[55]\tvalidation_0-rmse:1.35078\n", "[56]\tvalidation_0-rmse:1.35164\n", "[57]\tvalidation_0-rmse:1.35124\n", "[58]\tvalidation_0-rmse:1.34962\n", "[59]\tvalidation_0-rmse:1.35293\n", "[60]\tvalidation_0-rmse:1.34839\n", "[61]\tvalidation_0-rmse:1.33587\n", "[62]\tvalidation_0-rmse:1.32711\n", "[63]\tvalidation_0-rmse:1.32892\n", "[64]\tvalidation_0-rmse:1.33008\n", "[65]\tvalidation_0-rmse:1.33023\n", "[66]\tvalidation_0-rmse:1.33082\n", "[67]\tvalidation_0-rmse:1.33207\n", "[68]\tvalidation_0-rmse:1.33262\n", "[69]\tvalidation_0-rmse:1.33185\n", "[70]\tvalidation_0-rmse:1.33114\n", "[71]\tvalidation_0-rmse:1.33287\n", "[72]\tvalidation_0-rmse:1.34803\n", "[73]\tvalidation_0-rmse:1.35223\n", "[74]\tvalidation_0-rmse:1.34266\n", "[75]\tvalidation_0-rmse:1.34423\n", "[76]\tvalidation_0-rmse:1.34351\n", "[77]\tvalidation_0-rmse:1.33684\n", "[78]\tvalidation_0-rmse:1.33450\n", "[79]\tvalidation_0-rmse:1.35080\n", "[80]\tvalidation_0-rmse:1.34307\n", "[81]\tvalidation_0-rmse:1.33828\n", "[82]\tvalidation_0-rmse:1.33786\n", "[83]\tvalidation_0-rmse:1.33990\n", "[84]\tvalidation_0-rmse:1.34383\n", "[85]\tvalidation_0-rmse:1.34400\n", 
"[86]\tvalidation_0-rmse:1.34246\n", "[87]\tvalidation_0-rmse:1.34125\n", "[88]\tvalidation_0-rmse:1.34723\n", "[89]\tvalidation_0-rmse:1.34835\n", "[90]\tvalidation_0-rmse:1.34827\n", "[91]\tvalidation_0-rmse:1.34661\n", "[92]\tvalidation_0-rmse:1.34880\n", "[93]\tvalidation_0-rmse:1.34823\n", "[94]\tvalidation_0-rmse:1.36337\n", "[95]\tvalidation_0-rmse:1.36356\n", "[96]\tvalidation_0-rmse:1.36404\n", "[97]\tvalidation_0-rmse:1.36244\n", "[98]\tvalidation_0-rmse:1.36225\n", "[99]\tvalidation_0-rmse:1.36267\n", "[100]\tvalidation_0-rmse:1.36269\n", "[101]\tvalidation_0-rmse:1.36312\n", "[102]\tvalidation_0-rmse:1.37098\n", "[103]\tvalidation_0-rmse:1.37105\n", "[104]\tvalidation_0-rmse:1.37414\n", "[105]\tvalidation_0-rmse:1.37391\n", "[106]\tvalidation_0-rmse:1.37318\n", "[107]\tvalidation_0-rmse:1.36753\n", "[108]\tvalidation_0-rmse:1.36538\n", "[109]\tvalidation_0-rmse:1.36617\n", "[110]\tvalidation_0-rmse:1.36542\n", "[111]\tvalidation_0-rmse:1.36646\n", "[112]\tvalidation_0-rmse:1.38288\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/root/miniconda3/envs/python38/lib/python3.8/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "(1.7612323261291336,\n", " 1.3271142852554687,\n", " 0.9281123720415317,\n", " 0.14669084489239353,\n", " 0.8486966728710329)" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import xgboost as xgb\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Extract best parameters\n", "best_params = {'colsample_bytree': 0.7591178772740766, 'eta': 0.29006097172943296, 'gamma': 0.32020608660889016, \n", " 'max_depth': 21, 
'min_child_weight': 5.912424330716954, 'subsample': 0.8011115810485918}\n", "\n", "# Re-train the model with the best parameters and evaluate\n", "X_train, X_test, y_train, y_test = train_test_split(train_data[feature_cols], train_data[out_cols], test_size=0.1, random_state=42)\n", "gbr = xgb.XGBRegressor(**best_params, n_estimators=1500)\n", "gbr.fit(X_train, y_train,early_stopping_rounds=50, eval_set=[(X_test, y_test)], verbose=True)\n", "\n", "# Predict and evaluate\n", "y_pred = gbr.predict(X_test)\n", "mse = mean_squared_error(y_test, y_pred)\n", "rmse = np.sqrt(mse)\n", "mae = mean_absolute_error(y_test, y_pred)\n", "mape = mean_absolute_percentage_error(y_test, y_pred)\n", "r2 = r2_score(y_test, y_pred)\n", "\n", "mse, rmse, mae, mape, r2" ] }, { "cell_type": "code", "execution_count": null, "id": "27880e3b-0149-4611-abda-9e06442704e4", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" } }, "nbformat": 4, "nbformat_minor": 5 }