{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import catboost\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from catboost import CatBoostRegressor" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PM2.5PM10SO2NO2O3O3_8hCOwdwsair_temp...PM2.5_transportationPM2.5_resdientPM2.5_powerpre_PM2.5pre_PM10pre_SO2pre_NO2pre_O3pre_O3_8hpre_CO
04.7449325.1761505.7235853.6635622.1972252.3025851.51512758.00.7-11.1...0.0812480.8271100.418587136.0214.0317.038.08.09.03.71
14.5849675.0434255.7268483.6375862.0794422.1972251.506297185.00.5-11.7...0.0883130.8271100.412773114.0176.0305.038.08.09.03.55
24.4773374.9558275.7589023.6635622.0794422.1972251.5151270.00.2-12.7...0.0912560.8271100.42440097.0154.0306.037.07.08.03.51
34.4543474.9416425.6801733.6375862.0794422.1972251.530395199.01.4-10.9...0.0924341.7461210.45928287.0141.0316.038.07.08.03.55
44.6728295.1239645.7589023.6375862.1972252.1972251.605430359.01.2-12.3...0.1707383.4462920.51451385.0139.0292.037.07.08.03.62
\n", "

5 rows × 49 columns

\n", "
" ], "text/plain": [ " PM2.5 PM10 SO2 NO2 O3 O3_8h CO \\\n", "0 4.744932 5.176150 5.723585 3.663562 2.197225 2.302585 1.515127 \n", "1 4.584967 5.043425 5.726848 3.637586 2.079442 2.197225 1.506297 \n", "2 4.477337 4.955827 5.758902 3.663562 2.079442 2.197225 1.515127 \n", "3 4.454347 4.941642 5.680173 3.637586 2.079442 2.197225 1.530395 \n", "4 4.672829 5.123964 5.758902 3.637586 2.197225 2.197225 1.605430 \n", "\n", " wd ws air_temp ... PM2.5_transportation PM2.5_resdient \\\n", "0 58.0 0.7 -11.1 ... 0.081248 0.827110 \n", "1 185.0 0.5 -11.7 ... 0.088313 0.827110 \n", "2 0.0 0.2 -12.7 ... 0.091256 0.827110 \n", "3 199.0 1.4 -10.9 ... 0.092434 1.746121 \n", "4 359.0 1.2 -12.3 ... 0.170738 3.446292 \n", "\n", " PM2.5_power pre_PM2.5 pre_PM10 pre_SO2 pre_NO2 pre_O3 pre_O3_8h \\\n", "0 0.418587 136.0 214.0 317.0 38.0 8.0 9.0 \n", "1 0.412773 114.0 176.0 305.0 38.0 8.0 9.0 \n", "2 0.424400 97.0 154.0 306.0 37.0 7.0 8.0 \n", "3 0.459282 87.0 141.0 316.0 38.0 7.0 8.0 \n", "4 0.514513 85.0 139.0 292.0 37.0 7.0 8.0 \n", "\n", " pre_CO \n", "0 3.71 \n", "1 3.55 \n", "2 3.51 \n", "3 3.55 \n", "4 3.62 \n", "\n", "[5 rows x 49 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_csv('./data/train_data_mod.csv')\n", "data.head()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['PM2.5', 'PM10', 'SO2', 'NO2', 'O3', 'O3_8h', 'CO'], dtype='object')" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_cols = data.columns[7:]\n", "out_cols = data.columns[:7]\n", "out_cols" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "\n", "train_X, test_X, train_y, test_y = train_test_split(data[feature_cols], 
data[out_cols],\n",
 "    test_size=0.2,    # hold out 20% of the rows for evaluation\n",
 "    random_state=42,  # fixed seed so the split is reproducible\n",
 ")\n",
 "\n",
 "# Prepare the parameters.\n",
 "# NOTE(review): other_params is not referenced by any later cell of this notebook.\n",
 "other_params = dict(\n",
 "    learning_rate=0.01,\n",
 "    n_estimators=300,\n",
 "    max_depth=5,\n",
 "    min_child_weight=1,\n",
 "    seed=0,\n",
 "    subsample=0.8,\n",
 "    colsample_bytree=0.8,\n",
 "    gamma=0,\n",
 "    reg_alpha=0,\n",
 "    reg_lambda=1,\n",
 ")\n",
 "\n",
 "# Keyword arguments that are unpacked into lgb.LGBMRegressor to build base_model.\n",
 "params_gbm = dict(\n",
 "    task='train',\n",
 "    boosting_type='gbdt',    # boosting type\n",
 "    objective='l1',          # objective function\n",
 "    metric='rmse',           # evaluation metric\n",
 "    max_depth=10,\n",
 "    num_leaves=20,           # number of leaves per tree\n",
 "    learning_rate=0.09,      # learning rate\n",
 "    feature_fraction=0.9,    # fraction of features selected when building a tree\n",
 "    bagging_fraction=0.9,    # fraction of samples drawn when building a tree\n",
 "    bagging_freq=10,         # k means bagging is performed every k iterations\n",
 "    verbose=-1,              # <0: fatal only, =0: errors (warnings), >0: info\n",
 ")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import lightgbm as lgb\n", "from sklearn.multioutput import MultiOutputRegressor\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [
 "# LightGBM regressor configured from params_gbm.\n",
 "# NOTE(review): base_model is not used by any later cell; the fit below wraps base_cat.\n",
 "base_model = lgb.LGBMRegressor(**params_gbm)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [
 "# CatBoost regressor used as the base estimator of the multi-output wrapper.\n",
 "# NOTE(review): learning_rate=0.0005 over 1000 iterations looks very small, and\n",
 "# od_type/od_wait presumably have nothing to monitor because fit() receives no\n",
 "# eval set -- confirm both are intended.\n",
 "base_cat = CatBoostRegressor(\n",
 "    iterations=1000,\n",
 "    learning_rate=0.0005,\n",
 "    depth=10,\n",
 "    loss_function='RMSE',\n",
 "    eval_metric='RMSE',\n",
 "    random_seed=99,\n",
 "    od_type='Iter',\n",
 "    od_wait=50,\n",
 "    verbose=0,\n",
 ")" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [
 "# Wrap base_cat for multi-target regression and fit it on the training split.\n",
 "multioutputregressor = (\n",
 "    MultiOutputRegressor(base_cat)\n",
 "    .fit(train_X, train_y)\n",
 ")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [
 "# Predict all target columns for the held-out rows.\n",
 "rst = multioutputregressor.predict(test_X)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PM2.5PM10SO2NO2O3O3_8hCO
03.6531854.7002002.7223813.2615893.8364443.8571810.615219
14.3238874.9231963.1985024.0167523.1664743.5916390.824886
23.6601654.6623623.1369483.5137423.7639103.6717700.631061
33.7281124.6459583.5144113.7185473.1999073.2917500.862777
44.1896684.7434393.4456153.6748013.9490523.6952850.797655
........................
89024.2835304.9958994.0194443.9610543.6632943.3142310.916339
89033.6748664.6065043.4702833.3071483.6087393.6264630.860751
89043.7044094.3505633.7573743.6363183.6013663.5399290.862651
89053.7249674.6732183.2181823.7659763.3861512.9541360.730686
89063.4198364.1880603.0194163.3074043.8617043.7464840.702885
\n", "

8907 rows × 7 columns

\n", "
" ], "text/plain": [ " PM2.5 PM10 SO2 NO2 O3 O3_8h CO\n", "0 3.653185 4.700200 2.722381 3.261589 3.836444 3.857181 0.615219\n", "1 4.323887 4.923196 3.198502 4.016752 3.166474 3.591639 0.824886\n", "2 3.660165 4.662362 3.136948 3.513742 3.763910 3.671770 0.631061\n", "3 3.728112 4.645958 3.514411 3.718547 3.199907 3.291750 0.862777\n", "4 4.189668 4.743439 3.445615 3.674801 3.949052 3.695285 0.797655\n", "... ... ... ... ... ... ... ...\n", "8902 4.283530 4.995899 4.019444 3.961054 3.663294 3.314231 0.916339\n", "8903 3.674866 4.606504 3.470283 3.307148 3.608739 3.626463 0.860751\n", "8904 3.704409 4.350563 3.757374 3.636318 3.601366 3.539929 0.862651\n", "8905 3.724967 4.673218 3.218182 3.765976 3.386151 2.954136 0.730686\n", "8906 3.419836 4.188060 3.019416 3.307404 3.861704 3.746484 0.702885\n", "\n", "[8907 rows x 7 columns]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out_results = pd.DataFrame(rst, columns=out_cols)\n", "out_results" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, mean_absolute_percentage_error\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "COL: PM2.5, MSE: 2.30E-01,RMSE: 0.4799,MAPE: 9.99 %,MAE: 0.3802,R_2: -2.1369\n", "COL: PM10, MSE: 1.64E-01,RMSE: 0.4053,MAPE: 6.97 %,MAE: 0.3164,R_2: -2.1513\n", "COL: SO2, MSE: 4.32E-01,RMSE: 0.6574,MAPE: 16.439999999999998 %,MAE: 0.5326,R_2: -2.0811\n", "COL: NO2, MSE: 1.48E-01,RMSE: 0.3843,MAPE: 8.52 %,MAE: 0.3095,R_2: -2.3884\n", "COL: O3, MSE: 4.99E-01,RMSE: 0.7061,MAPE: 17.419999999999998 %,MAE: 0.5898,R_2: -2.0369\n", "COL: O3_8h, MSE: 4.19E-01,RMSE: 0.6471,MAPE: 15.73 %,MAE: 0.5331,R_2: -1.936\n", "COL: CO, MSE: 3.39E-02,RMSE: 0.1842,MAPE: 18.75 %,MAE: 0.1439,R_2: -2.1239\n" ] } ], "source": [ "for col in 
out_cols:\n",
 "    # scikit-learn metric functions take (y_true, y_pred) in that order.\n",
 "    # MSE and MAE are symmetric, but MAPE divides by |y_true| and R^2 is\n",
 "    # normalised by the variance of y_true, so the ground truth must come first.\n",
 "    # .values compares by position; test_y keeps the shuffled index of the split\n",
 "    # while out_results was built with a fresh 0..n-1 index.\n",
 "    y_true = test_y[col].values\n",
 "    y_pred = out_results[col].values\n",
 "\n",
 "    MSE = mean_squared_error(y_true, y_pred)\n",
 "    RMSE = np.sqrt(MSE)\n",
 "    MAE = mean_absolute_error(y_true, y_pred)\n",
 "    MAPE = mean_absolute_percentage_error(y_true, y_pred)\n",
 "    R_2 = r2_score(y_true, y_pred)\n",
 "\n",
 "    # One comma-separated report line per target column.\n",
 "    print(f\"COL: {col}, MSE: {format(MSE, '.2E')}\", end=',')\n",
 "    print(f'RMSE: {round(RMSE, 4)}', end=',')\n",
 "    # Convert to percent before rounding; rounding first and multiplying by 100\n",
 "    # afterwards leaks float noise such as 16.439999999999998.\n",
 "    print(f'MAPE: {round(MAPE * 100, 2)} %', end=',')\n",
 "    print(f'MAE: {round(MAE, 4)}', end=',')\n",
 "    print(f'R_2: {round(R_2, 4)}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py37", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "993bd31d5df1020fab369d79a34ff0a2a159e1798f3e25d3ad4b7751d38184c9" } } }, "nbformat": 4, "nbformat_minor": 2 }