22-T67/catboost.ipynb

602 lines
19 KiB
Plaintext
Raw Normal View History

2023-03-30 10:25:44 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import catboost\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from catboost import CatBoostRegressor"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PM2.5</th>\n",
" <th>PM10</th>\n",
" <th>SO2</th>\n",
" <th>NO2</th>\n",
" <th>O3</th>\n",
" <th>O3_8h</th>\n",
" <th>CO</th>\n",
" <th>wd</th>\n",
" <th>ws</th>\n",
" <th>air_temp</th>\n",
" <th>...</th>\n",
" <th>PM2.5_transportation</th>\n",
" <th>PM2.5_resdient</th>\n",
" <th>PM2.5_power</th>\n",
" <th>pre_PM2.5</th>\n",
" <th>pre_PM10</th>\n",
" <th>pre_SO2</th>\n",
" <th>pre_NO2</th>\n",
" <th>pre_O3</th>\n",
" <th>pre_O3_8h</th>\n",
" <th>pre_CO</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>4.744932</td>\n",
" <td>5.176150</td>\n",
" <td>5.723585</td>\n",
" <td>3.663562</td>\n",
" <td>2.197225</td>\n",
" <td>2.302585</td>\n",
" <td>1.515127</td>\n",
" <td>58.0</td>\n",
" <td>0.7</td>\n",
" <td>-11.1</td>\n",
" <td>...</td>\n",
" <td>0.081248</td>\n",
" <td>0.827110</td>\n",
" <td>0.418587</td>\n",
" <td>136.0</td>\n",
" <td>214.0</td>\n",
" <td>317.0</td>\n",
" <td>38.0</td>\n",
" <td>8.0</td>\n",
" <td>9.0</td>\n",
" <td>3.71</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.584967</td>\n",
" <td>5.043425</td>\n",
" <td>5.726848</td>\n",
" <td>3.637586</td>\n",
" <td>2.079442</td>\n",
" <td>2.197225</td>\n",
" <td>1.506297</td>\n",
" <td>185.0</td>\n",
" <td>0.5</td>\n",
" <td>-11.7</td>\n",
" <td>...</td>\n",
" <td>0.088313</td>\n",
" <td>0.827110</td>\n",
" <td>0.412773</td>\n",
" <td>114.0</td>\n",
" <td>176.0</td>\n",
" <td>305.0</td>\n",
" <td>38.0</td>\n",
" <td>8.0</td>\n",
" <td>9.0</td>\n",
" <td>3.55</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.477337</td>\n",
" <td>4.955827</td>\n",
" <td>5.758902</td>\n",
" <td>3.663562</td>\n",
" <td>2.079442</td>\n",
" <td>2.197225</td>\n",
" <td>1.515127</td>\n",
" <td>0.0</td>\n",
" <td>0.2</td>\n",
" <td>-12.7</td>\n",
" <td>...</td>\n",
" <td>0.091256</td>\n",
" <td>0.827110</td>\n",
" <td>0.424400</td>\n",
" <td>97.0</td>\n",
" <td>154.0</td>\n",
" <td>306.0</td>\n",
" <td>37.0</td>\n",
" <td>7.0</td>\n",
" <td>8.0</td>\n",
" <td>3.51</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.454347</td>\n",
" <td>4.941642</td>\n",
" <td>5.680173</td>\n",
" <td>3.637586</td>\n",
" <td>2.079442</td>\n",
" <td>2.197225</td>\n",
" <td>1.530395</td>\n",
" <td>199.0</td>\n",
" <td>1.4</td>\n",
" <td>-10.9</td>\n",
" <td>...</td>\n",
" <td>0.092434</td>\n",
" <td>1.746121</td>\n",
" <td>0.459282</td>\n",
" <td>87.0</td>\n",
" <td>141.0</td>\n",
" <td>316.0</td>\n",
" <td>38.0</td>\n",
" <td>7.0</td>\n",
" <td>8.0</td>\n",
" <td>3.55</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4.672829</td>\n",
" <td>5.123964</td>\n",
" <td>5.758902</td>\n",
" <td>3.637586</td>\n",
" <td>2.197225</td>\n",
" <td>2.197225</td>\n",
" <td>1.605430</td>\n",
" <td>359.0</td>\n",
" <td>1.2</td>\n",
" <td>-12.3</td>\n",
" <td>...</td>\n",
" <td>0.170738</td>\n",
" <td>3.446292</td>\n",
" <td>0.514513</td>\n",
" <td>85.0</td>\n",
" <td>139.0</td>\n",
" <td>292.0</td>\n",
" <td>37.0</td>\n",
" <td>7.0</td>\n",
" <td>8.0</td>\n",
" <td>3.62</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 49 columns</p>\n",
"</div>"
],
"text/plain": [
" PM2.5 PM10 SO2 NO2 O3 O3_8h CO \\\n",
"0 4.744932 5.176150 5.723585 3.663562 2.197225 2.302585 1.515127 \n",
"1 4.584967 5.043425 5.726848 3.637586 2.079442 2.197225 1.506297 \n",
"2 4.477337 4.955827 5.758902 3.663562 2.079442 2.197225 1.515127 \n",
"3 4.454347 4.941642 5.680173 3.637586 2.079442 2.197225 1.530395 \n",
"4 4.672829 5.123964 5.758902 3.637586 2.197225 2.197225 1.605430 \n",
"\n",
" wd ws air_temp ... PM2.5_transportation PM2.5_resdient \\\n",
"0 58.0 0.7 -11.1 ... 0.081248 0.827110 \n",
"1 185.0 0.5 -11.7 ... 0.088313 0.827110 \n",
"2 0.0 0.2 -12.7 ... 0.091256 0.827110 \n",
"3 199.0 1.4 -10.9 ... 0.092434 1.746121 \n",
"4 359.0 1.2 -12.3 ... 0.170738 3.446292 \n",
"\n",
" PM2.5_power pre_PM2.5 pre_PM10 pre_SO2 pre_NO2 pre_O3 pre_O3_8h \\\n",
"0 0.418587 136.0 214.0 317.0 38.0 8.0 9.0 \n",
"1 0.412773 114.0 176.0 305.0 38.0 8.0 9.0 \n",
"2 0.424400 97.0 154.0 306.0 37.0 7.0 8.0 \n",
"3 0.459282 87.0 141.0 316.0 38.0 7.0 8.0 \n",
"4 0.514513 85.0 139.0 292.0 37.0 7.0 8.0 \n",
"\n",
" pre_CO \n",
"0 3.71 \n",
"1 3.55 \n",
"2 3.51 \n",
"3 3.55 \n",
"4 3.62 \n",
"\n",
"[5 rows x 49 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Read the training table from the CSV file and preview its first rows.\n",
"train_csv = './data/train_data_mod.csv'\n",
"data = pd.read_csv(train_csv)\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['PM2.5', 'PM10', 'SO2', 'NO2', 'O3', 'O3_8h', 'CO'], dtype='object')"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The first seven columns are used as prediction targets; every remaining\n",
"# column is used as an input feature.\n",
"n_targets = 7\n",
"out_cols = data.columns[:n_targets]\n",
"feature_cols = data.columns[n_targets:]\n",
"out_cols"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Split features/targets into train and test partitions (20% held out, fixed seed).\n",
"train_X, test_X, train_y, test_y = train_test_split(\n",
"    data[feature_cols],\n",
"    data[out_cols],\n",
"    test_size=0.2,\n",
"    random_state=42,\n",
")\n",
"\n",
"# Prepare parameters.\n",
"# NOTE(review): other_params is not referenced by any other visible cell.\n",
"other_params = dict(\n",
"    learning_rate=0.01,\n",
"    n_estimators=300,\n",
"    max_depth=5,\n",
"    min_child_weight=1,\n",
"    seed=0,\n",
"    subsample=0.8,\n",
"    colsample_bytree=0.8,\n",
"    gamma=0,\n",
"    reg_alpha=0,\n",
"    reg_lambda=1,\n",
")\n",
"\n",
"# Keyword arguments unpacked into lgb.LGBMRegressor in a later cell.\n",
"params_gbm = dict(\n",
"    task='train',\n",
"    boosting_type='gbdt',  # boosting type\n",
"    objective='l1',  # objective function\n",
"    metric='rmse',  # evaluation metric\n",
"    max_depth=10,\n",
"    num_leaves=20,  # number of leaves\n",
"    learning_rate=0.09,  # learning rate\n",
"    feature_fraction=0.9,  # fraction of features sampled when building a tree\n",
"    bagging_fraction=0.9,  # fraction of samples drawn when building a tree\n",
"    bagging_freq=10,  # k means bagging is performed every k iterations\n",
"    verbose=-1,  # <0: fatal only, =0: errors (warnings), >0: info\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import lightgbm as lgb\n",
"from sklearn.multioutput import MultiOutputRegressor\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# LightGBM regressor configured from params_gbm.\n",
"# NOTE(review): the later visible cells fit CatBoost and never use base_model.\n",
"base_model = lgb.LGBMRegressor(\n",
"    **params_gbm\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# CatBoost base learner that is handed to MultiOutputRegressor in a later cell.\n",
"# NOTE(review): learning_rate=0.0005 over 1000 iterations is very small and may\n",
"# underfit; od_type/od_wait presumably only take effect when an eval set is\n",
"# supplied at fit time, which the visible cells do not do -- confirm both.\n",
"base_cat = CatBoostRegressor(\n",
"    iterations=1000,\n",
"    learning_rate=0.0005,\n",
"    depth=10,\n",
"    loss_function='RMSE',\n",
"    eval_metric='RMSE',\n",
"    random_seed=99,\n",
"    od_type='Iter',\n",
"    od_wait=50,\n",
"    verbose=0,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Wrap the CatBoost learner so all target columns are fitted in one call.\n",
"multioutputregressor = MultiOutputRegressor(estimator=base_cat)\n",
"multioutputregressor = multioutputregressor.fit(train_X, train_y)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# Predict every target column for the held-out feature rows.\n",
"rst = multioutputregressor.predict(X=test_X)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PM2.5</th>\n",
" <th>PM10</th>\n",
" <th>SO2</th>\n",
" <th>NO2</th>\n",
" <th>O3</th>\n",
" <th>O3_8h</th>\n",
" <th>CO</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3.653185</td>\n",
" <td>4.700200</td>\n",
" <td>2.722381</td>\n",
" <td>3.261589</td>\n",
" <td>3.836444</td>\n",
" <td>3.857181</td>\n",
" <td>0.615219</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.323887</td>\n",
" <td>4.923196</td>\n",
" <td>3.198502</td>\n",
" <td>4.016752</td>\n",
" <td>3.166474</td>\n",
" <td>3.591639</td>\n",
" <td>0.824886</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3.660165</td>\n",
" <td>4.662362</td>\n",
" <td>3.136948</td>\n",
" <td>3.513742</td>\n",
" <td>3.763910</td>\n",
" <td>3.671770</td>\n",
" <td>0.631061</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3.728112</td>\n",
" <td>4.645958</td>\n",
" <td>3.514411</td>\n",
" <td>3.718547</td>\n",
" <td>3.199907</td>\n",
" <td>3.291750</td>\n",
" <td>0.862777</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4.189668</td>\n",
" <td>4.743439</td>\n",
" <td>3.445615</td>\n",
" <td>3.674801</td>\n",
" <td>3.949052</td>\n",
" <td>3.695285</td>\n",
" <td>0.797655</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8902</th>\n",
" <td>4.283530</td>\n",
" <td>4.995899</td>\n",
" <td>4.019444</td>\n",
" <td>3.961054</td>\n",
" <td>3.663294</td>\n",
" <td>3.314231</td>\n",
" <td>0.916339</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8903</th>\n",
" <td>3.674866</td>\n",
" <td>4.606504</td>\n",
" <td>3.470283</td>\n",
" <td>3.307148</td>\n",
" <td>3.608739</td>\n",
" <td>3.626463</td>\n",
" <td>0.860751</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8904</th>\n",
" <td>3.704409</td>\n",
" <td>4.350563</td>\n",
" <td>3.757374</td>\n",
" <td>3.636318</td>\n",
" <td>3.601366</td>\n",
" <td>3.539929</td>\n",
" <td>0.862651</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8905</th>\n",
" <td>3.724967</td>\n",
" <td>4.673218</td>\n",
" <td>3.218182</td>\n",
" <td>3.765976</td>\n",
" <td>3.386151</td>\n",
" <td>2.954136</td>\n",
" <td>0.730686</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8906</th>\n",
" <td>3.419836</td>\n",
" <td>4.188060</td>\n",
" <td>3.019416</td>\n",
" <td>3.307404</td>\n",
" <td>3.861704</td>\n",
" <td>3.746484</td>\n",
" <td>0.702885</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>8907 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" PM2.5 PM10 SO2 NO2 O3 O3_8h CO\n",
"0 3.653185 4.700200 2.722381 3.261589 3.836444 3.857181 0.615219\n",
"1 4.323887 4.923196 3.198502 4.016752 3.166474 3.591639 0.824886\n",
"2 3.660165 4.662362 3.136948 3.513742 3.763910 3.671770 0.631061\n",
"3 3.728112 4.645958 3.514411 3.718547 3.199907 3.291750 0.862777\n",
"4 4.189668 4.743439 3.445615 3.674801 3.949052 3.695285 0.797655\n",
"... ... ... ... ... ... ... ...\n",
"8902 4.283530 4.995899 4.019444 3.961054 3.663294 3.314231 0.916339\n",
"8903 3.674866 4.606504 3.470283 3.307148 3.608739 3.626463 0.860751\n",
"8904 3.704409 4.350563 3.757374 3.636318 3.601366 3.539929 0.862651\n",
"8905 3.724967 4.673218 3.218182 3.765976 3.386151 2.954136 0.730686\n",
"8906 3.419836 4.188060 3.019416 3.307404 3.861704 3.746484 0.702885\n",
"\n",
"[8907 rows x 7 columns]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Label the raw prediction array with the target column names.\n",
"# The frame gets a fresh 0..n-1 index, so compare against test_y by position.\n",
"out_results = pd.DataFrame(data=rst, columns=out_cols)\n",
"out_results"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, mean_absolute_percentage_error\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"COL: PM2.5, MSE: 2.30E-01,RMSE: 0.4799,MAPE: 9.99 %,MAE: 0.3802,R_2: -2.1369\n",
"COL: PM10, MSE: 1.64E-01,RMSE: 0.4053,MAPE: 6.97 %,MAE: 0.3164,R_2: -2.1513\n",
"COL: SO2, MSE: 4.32E-01,RMSE: 0.6574,MAPE: 16.439999999999998 %,MAE: 0.5326,R_2: -2.0811\n",
"COL: NO2, MSE: 1.48E-01,RMSE: 0.3843,MAPE: 8.52 %,MAE: 0.3095,R_2: -2.3884\n",
"COL: O3, MSE: 4.99E-01,RMSE: 0.7061,MAPE: 17.419999999999998 %,MAE: 0.5898,R_2: -2.0369\n",
"COL: O3_8h, MSE: 4.19E-01,RMSE: 0.6471,MAPE: 15.73 %,MAE: 0.5331,R_2: -1.936\n",
"COL: CO, MSE: 3.39E-02,RMSE: 0.1842,MAPE: 18.75 %,MAE: 0.1439,R_2: -2.1239\n"
]
}
],
"source": [
"# Score each target column on the held-out split.\n",
"# scikit-learn metric functions take (y_true, y_pred) in that order. MSE and MAE\n",
"# are symmetric, but MAPE and R^2 are not, so the ground truth must come first.\n",
"# NOTE(review): the targets shown by data.head() look log-scaled, so these\n",
"# metrics are presumably in that transformed space -- confirm.\n",
"for col in out_cols:\n",
"    # Positional comparison: out_results has a fresh index, test_y keeps the\n",
"    # shuffled one, so work on the underlying arrays.\n",
"    y_true = test_y[col].values\n",
"    y_pred = out_results[col].values\n",
"    MSE = mean_squared_error(y_true, y_pred)\n",
"    RMSE = np.sqrt(MSE)\n",
"    MAE = mean_absolute_error(y_true, y_pred)\n",
"    MAPE = mean_absolute_percentage_error(y_true, y_pred)\n",
"    R_2 = r2_score(y_true, y_pred)\n",
"    print(f'COL: {col}, MSE: {MSE:.2E}', end=',')\n",
"    print(f'RMSE: {round(RMSE, 4)}', end=',')\n",
"    # Convert to percent inside the format spec so float artefacts such as\n",
"    # 16.439999999999998 are never printed.\n",
"    print(f'MAPE: {MAPE * 100:.2f} %', end=',')\n",
"    print(f'MAE: {round(MAE, 4)}', end=',')\n",
"    print(f'R_2: {round(R_2, 4)}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py37",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "993bd31d5df1020fab369d79a34ff0a2a159e1798f3e25d3ad4b7751d38184c9"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}