22-T67/keras_multi-attention_multi...

1747 lines
7.1 MiB
Plaintext
Raw Normal View History

2023-03-30 10:25:44 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib.pyplot as plt\n",
"#新增加的两行\n",
"from pylab import mpl\n",
"# 设置显示中文字体\n",
"mpl.rcParams[\"font.sans-serif\"] = [\"SimHei\"]\n",
"\n",
"mpl.rcParams[\"axes.unicode_minus\"] = False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>24_PM2.5</th>\n",
" <th>24_PM10</th>\n",
" <th>24_SO2</th>\n",
" <th>24_NO2</th>\n",
" <th>24_O3</th>\n",
" <th>24_CO</th>\n",
" <th>23_PM2.5</th>\n",
" <th>23_PM10</th>\n",
" <th>23_SO2</th>\n",
" <th>23_NO2</th>\n",
" <th>...</th>\n",
" <th>NH3_resdient</th>\n",
" <th>NH3_agricultural</th>\n",
" <th>VOC_industrial</th>\n",
" <th>VOC_transportation</th>\n",
" <th>VOC_resdient</th>\n",
" <th>VOC_power</th>\n",
" <th>PM2.5_industrial</th>\n",
" <th>PM2.5_transportation</th>\n",
" <th>PM2.5_resdient</th>\n",
" <th>PM2.5_power</th>\n",
" </tr>\n",
" <tr>\n",
" <th>date</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2015-01-03 01:00:00</th>\n",
" <td>136.0</td>\n",
" <td>214.0</td>\n",
" <td>317.0</td>\n",
" <td>38.0</td>\n",
" <td>8.0</td>\n",
" <td>3.71</td>\n",
" <td>114.0</td>\n",
" <td>176.0</td>\n",
" <td>305.0</td>\n",
" <td>38.0</td>\n",
" <td>...</td>\n",
" <td>0.033910</td>\n",
" <td>0.359273</td>\n",
" <td>1.177423</td>\n",
" <td>1.084925</td>\n",
" <td>0.937173</td>\n",
" <td>0.037724</td>\n",
" <td>0.926851</td>\n",
" <td>0.077715</td>\n",
" <td>0.827110</td>\n",
" <td>0.436028</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2015-01-03 02:00:00</th>\n",
" <td>114.0</td>\n",
" <td>176.0</td>\n",
" <td>305.0</td>\n",
" <td>38.0</td>\n",
" <td>8.0</td>\n",
" <td>3.55</td>\n",
" <td>97.0</td>\n",
" <td>154.0</td>\n",
" <td>306.0</td>\n",
" <td>37.0</td>\n",
" <td>...</td>\n",
" <td>0.033910</td>\n",
" <td>0.359273</td>\n",
" <td>1.177423</td>\n",
" <td>1.134240</td>\n",
" <td>0.937173</td>\n",
" <td>0.036215</td>\n",
" <td>0.926851</td>\n",
" <td>0.081248</td>\n",
" <td>0.827110</td>\n",
" <td>0.418587</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2015-01-03 03:00:00</th>\n",
" <td>97.0</td>\n",
" <td>154.0</td>\n",
" <td>306.0</td>\n",
" <td>37.0</td>\n",
" <td>7.0</td>\n",
" <td>3.51</td>\n",
" <td>87.0</td>\n",
" <td>141.0</td>\n",
" <td>316.0</td>\n",
" <td>38.0</td>\n",
" <td>...</td>\n",
" <td>0.033910</td>\n",
" <td>0.327791</td>\n",
" <td>1.177423</td>\n",
" <td>1.232869</td>\n",
" <td>0.937173</td>\n",
" <td>0.035712</td>\n",
" <td>0.926851</td>\n",
" <td>0.088313</td>\n",
" <td>0.827110</td>\n",
" <td>0.412773</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2015-01-03 04:00:00</th>\n",
" <td>87.0</td>\n",
" <td>141.0</td>\n",
" <td>316.0</td>\n",
" <td>38.0</td>\n",
" <td>7.0</td>\n",
" <td>3.55</td>\n",
" <td>85.0</td>\n",
" <td>139.0</td>\n",
" <td>292.0</td>\n",
" <td>37.0</td>\n",
" <td>...</td>\n",
" <td>0.033910</td>\n",
" <td>0.350014</td>\n",
" <td>1.177423</td>\n",
" <td>1.273965</td>\n",
" <td>0.937173</td>\n",
" <td>0.036718</td>\n",
" <td>0.926851</td>\n",
" <td>0.091256</td>\n",
" <td>0.827110</td>\n",
" <td>0.424400</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2015-01-03 05:00:00</th>\n",
" <td>85.0</td>\n",
" <td>139.0</td>\n",
" <td>292.0</td>\n",
" <td>37.0</td>\n",
" <td>7.0</td>\n",
" <td>3.62</td>\n",
" <td>106.0</td>\n",
" <td>167.0</td>\n",
" <td>316.0</td>\n",
" <td>37.0</td>\n",
" <td>...</td>\n",
" <td>0.071588</td>\n",
" <td>0.388904</td>\n",
" <td>1.177423</td>\n",
" <td>1.290403</td>\n",
" <td>1.978475</td>\n",
" <td>0.039736</td>\n",
" <td>0.926851</td>\n",
" <td>0.092434</td>\n",
" <td>1.746121</td>\n",
" <td>0.459282</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 187 columns</p>\n",
"</div>"
],
"text/plain": [
" 24_PM2.5 24_PM10 24_SO2 24_NO2 24_O3 24_CO \\\n",
"date \n",
"2015-01-03 01:00:00 136.0 214.0 317.0 38.0 8.0 3.71 \n",
"2015-01-03 02:00:00 114.0 176.0 305.0 38.0 8.0 3.55 \n",
"2015-01-03 03:00:00 97.0 154.0 306.0 37.0 7.0 3.51 \n",
"2015-01-03 04:00:00 87.0 141.0 316.0 38.0 7.0 3.55 \n",
"2015-01-03 05:00:00 85.0 139.0 292.0 37.0 7.0 3.62 \n",
"\n",
" 23_PM2.5 23_PM10 23_SO2 23_NO2 ... NH3_resdient \\\n",
"date ... \n",
"2015-01-03 01:00:00 114.0 176.0 305.0 38.0 ... 0.033910 \n",
"2015-01-03 02:00:00 97.0 154.0 306.0 37.0 ... 0.033910 \n",
"2015-01-03 03:00:00 87.0 141.0 316.0 38.0 ... 0.033910 \n",
"2015-01-03 04:00:00 85.0 139.0 292.0 37.0 ... 0.033910 \n",
"2015-01-03 05:00:00 106.0 167.0 316.0 37.0 ... 0.071588 \n",
"\n",
" NH3_agricultural VOC_industrial VOC_transportation \\\n",
"date \n",
"2015-01-03 01:00:00 0.359273 1.177423 1.084925 \n",
"2015-01-03 02:00:00 0.359273 1.177423 1.134240 \n",
"2015-01-03 03:00:00 0.327791 1.177423 1.232869 \n",
"2015-01-03 04:00:00 0.350014 1.177423 1.273965 \n",
"2015-01-03 05:00:00 0.388904 1.177423 1.290403 \n",
"\n",
" VOC_resdient VOC_power PM2.5_industrial \\\n",
"date \n",
"2015-01-03 01:00:00 0.937173 0.037724 0.926851 \n",
"2015-01-03 02:00:00 0.937173 0.036215 0.926851 \n",
"2015-01-03 03:00:00 0.937173 0.035712 0.926851 \n",
"2015-01-03 04:00:00 0.937173 0.036718 0.926851 \n",
"2015-01-03 05:00:00 1.978475 0.039736 0.926851 \n",
"\n",
" PM2.5_transportation PM2.5_resdient PM2.5_power \n",
"date \n",
"2015-01-03 01:00:00 0.077715 0.827110 0.436028 \n",
"2015-01-03 02:00:00 0.081248 0.827110 0.418587 \n",
"2015-01-03 03:00:00 0.088313 0.827110 0.412773 \n",
"2015-01-03 04:00:00 0.091256 0.827110 0.424400 \n",
"2015-01-03 05:00:00 0.092434 1.746121 0.459282 \n",
"\n",
"[5 rows x 187 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv('./new_train_data.csv', index_col='date')\n",
"# data.drop(columns=['wd'], inplace=True) # 风向还没想好怎么处理\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(37, 6)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out_cols = ['PM2.5', 'PM10', 'SO2', 'NO2', 'O3', 'CO']\n",
"feature_cols = [x for x in data.columns if x not in out_cols and x != 'date']\n",
"feature_cols = feature_cols[144:]\n",
"len(feature_cols), len(out_cols)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-03-30 09:43:47.876739: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import tensorflow.keras.backend as K"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class TransformerBlock(layers.Layer):\n",
" def __init__(self, embed_dim, num_heads, ff_dim, name, rate=0.1):\n",
" super().__init__()\n",
" self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, name=name)\n",
" self.ffn = keras.Sequential(\n",
" [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
" )\n",
" self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
" self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
" self.dropout1 = layers.Dropout(rate)\n",
" self.dropout2 = layers.Dropout(rate)\n",
"\n",
" def call(self, inputs, training):\n",
" attn_output = self.att(inputs, inputs)\n",
" attn_output = self.dropout1(attn_output, training=training)\n",
" out1 = self.layernorm1(inputs + attn_output)\n",
" ffn_output = self.ffn(out1)\n",
" ffn_output = self.dropout2(ffn_output, training=training)\n",
" return self.layernorm2(out1 + ffn_output)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import Model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.initializers import Constant"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Custom loss layer\n",
"class CustomMultiLossLayer(layers.Layer):\n",
" def __init__(self, nb_outputs=2, **kwargs):\n",
" self.nb_outputs = nb_outputs\n",
" self.is_placeholder = True\n",
" super(CustomMultiLossLayer, self).__init__(**kwargs)\n",
" \n",
" def build(self, input_shape=None):\n",
" # initialise log_vars\n",
" self.log_vars = []\n",
" for i in range(self.nb_outputs):\n",
" self.log_vars += [self.add_weight(name='log_var' + str(i), shape=(1,),\n",
" initializer=tf.initializers.he_normal(), trainable=True)]\n",
" super(CustomMultiLossLayer, self).build(input_shape)\n",
"\n",
" def multi_loss(self, ys_true, ys_pred):\n",
" assert len(ys_true) == self.nb_outputs and len(ys_pred) == self.nb_outputs\n",
" loss = 0\n",
" for y_true, y_pred, log_var in zip(ys_true, ys_pred, self.log_vars):\n",
" mse = (y_true - y_pred) ** 2.\n",
" pre = K.exp(-log_var[0])\n",
" loss += tf.abs(tf.reduce_logsumexp(pre * mse + log_var[0], axis=-1))\n",
" return K.mean(loss)\n",
"\n",
" def call(self, inputs):\n",
" ys_true = inputs[:self.nb_outputs]\n",
" ys_pred = inputs[self.nb_outputs:]\n",
" loss = self.multi_loss(ys_true, ys_pred)\n",
" self.add_loss(loss, inputs=inputs)\n",
" # We won't actually use the output.\n",
" return K.concatenate(inputs, -1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"num_heads, ff_dim = 3, 16"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def get_prediction_model():\n",
" def build_output(out, out_name):\n",
" self_block = TransformerBlock(64, num_heads, ff_dim, name=f'{out_name}_attn')\n",
" out = self_block(out)\n",
" out = layers.GlobalAveragePooling1D()(out)\n",
" out = layers.Dropout(0.1)(out)\n",
" out = layers.Dense(32, activation=\"relu\")(out)\n",
" # out = layers.Dense(1, name=out_name, activation=\"sigmoid\")(out)\n",
" return out\n",
" inputs = layers.Input(shape=(1,len(feature_cols)), name='input')\n",
" x = layers.Conv1D(filters=64, kernel_size=1, activation='relu')(inputs)\n",
" # x = layers.Dropout(rate=0.1)(x)\n",
" lstm_out = layers.Bidirectional(layers.LSTM(units=64, return_sequences=True))(x)\n",
" lstm_out = layers.Dense(128, activation='relu')(lstm_out)\n",
" transformer_block = TransformerBlock(128, num_heads, ff_dim, name='first_attn')\n",
" out = transformer_block(lstm_out)\n",
" out = layers.GlobalAveragePooling1D()(out)\n",
" out = layers.Dropout(0.1)(out)\n",
" out = layers.Dense(64, activation='relu')(out)\n",
" out = K.expand_dims(out, axis=1)\n",
"\n",
" pm25 = build_output(out, 'pm25')\n",
" pm10 = build_output(out, 'pm10')\n",
" so2 = build_output(out, 'so2')\n",
" no2 = build_output(out, 'no2')\n",
" o3 = build_output(out, 'o3')\n",
" co = build_output(out, 'co')\n",
"\n",
" merge = layers.Concatenate(axis=1)([pm25, pm10, so2, no2, o3, co])\n",
" merge = K.expand_dims(merge, axis=1)\n",
" merge_attn = TransformerBlock(32*6, 3, 16, name='last_attn')\n",
"\n",
" out = merge_attn(merge)\n",
" out = layers.GlobalAveragePooling1D()(out)\n",
" out = layers.Dropout(0.1)(out)\n",
"\n",
" pm25 = layers.Dense(32, activation='relu')(out)\n",
" pm10 = layers.Dense(32, activation='relu')(out)\n",
" so2 = layers.Dense(32, activation='relu')(out)\n",
" no2 = layers.Dense(32, activation='relu')(out)\n",
" o3 = layers.Dense(32, activation='relu')(out)\n",
" co = layers.Dense(32, activation='relu')(out)\n",
"\n",
" pm25 = layers.Dense(1, activation='sigmoid', name='pm25')(pm25)\n",
" pm10 = layers.Dense(1, activation='sigmoid', name='pm10')(pm10)\n",
" so2 = layers.Dense(1, activation='sigmoid', name='so2')(so2)\n",
" no2 = layers.Dense(1, activation='sigmoid', name='no2')(no2)\n",
" o3 = layers.Dense(1, activation='sigmoid', name='o3')(o3)\n",
" co = layers.Dense(1, activation='sigmoid', name='co')(co)\n",
"\n",
" model = Model(inputs=[inputs], outputs=[pm25, pm10, so2, no2, o3, co])\n",
" return model\n",
"\n",
"def get_trainable_model(prediction_model):\n",
" inputs = layers.Input(shape=(1,len(feature_cols)), name='input')\n",
" pm25, pm10, so2, no2, o3, co = prediction_model(inputs)\n",
" pm25_real = layers.Input(shape=(1,), name='pm25_real')\n",
" pm10_real = layers.Input(shape=(1,), name='pm10_real')\n",
" so2_real = layers.Input(shape=(1,), name='so2_real')\n",
" no2_real = layers.Input(shape=(1,), name='no2_real')\n",
" o3_real = layers.Input(shape=(1,), name='o3_real')\n",
" co_real = layers.Input(shape=(1,), name='co_real')\n",
" out = CustomMultiLossLayer(nb_outputs=6)([pm25_real, pm10_real, so2_real, no2_real, o3_real, co_real, pm25, pm10, so2, no2, o3, co])\n",
" return Model([inputs, pm25_real, pm10_real, so2_real, no2_real, o3_real, co_real], out)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"use_cols = feature_cols + out_cols\n",
"use_data = data[use_cols].dropna()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"for col in use_cols:\n",
" use_data[col] = use_data[col].astype('float32')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"maxs = use_data.max()\n",
"mins = use_data.min()\n",
"use_cols = use_data.columns"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"for col in use_cols:\n",
" # use_data[col] = use_data[col].apply(lambda x: 0 if x < 0 else x)\n",
" # use_data[col] = np.log1p(use_data[col])\n",
" use_data[col] = (use_data[col] - mins[col]) / (maxs[col] - mins[col])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"train_data, valid = train_test_split(use_data[use_cols], test_size=0.1, random_state=42, shuffle=True)\n",
"valid_data, test_data = train_test_split(valid, test_size=0.5, random_state=42, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-03-30 09:43:54.452903: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1\n",
"2023-03-30 09:43:54.512341: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal\n",
"2023-03-30 09:43:54.512378: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: ubuntu-NF5468M6\n",
"2023-03-30 09:43:54.512384: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: ubuntu-NF5468M6\n",
"2023-03-30 09:43:54.512489: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 510.47.3\n",
"2023-03-30 09:43:54.512510: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 510.47.3\n",
"2023-03-30 09:43:54.512515: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 510.47.3\n",
"2023-03-30 09:43:54.512816: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"prediction_model = get_prediction_model()\n",
"trainable_model = get_trainable_model(prediction_model)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input (InputLayer) [(None, 1, 37)] 0 \n",
"__________________________________________________________________________________________________\n",
"conv1d (Conv1D) (None, 1, 64) 2432 input[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional (Bidirectional) (None, 1, 128) 66048 conv1d[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense (Dense) (None, 1, 128) 16512 bidirectional[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block (TransformerB (None, 1, 128) 202640 dense[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d (Globa (None, 128) 0 transformer_block[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_2 (Dropout) (None, 128) 0 global_average_pooling1d[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_3 (Dense) (None, 64) 8256 dropout_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf.expand_dims (TFOpLambda) (None, 1, 64) 0 dense_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_1 (Transforme (None, 1, 64) 52176 tf.expand_dims[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_2 (Transforme (None, 1, 64) 52176 tf.expand_dims[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_3 (Transforme (None, 1, 64) 52176 tf.expand_dims[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_4 (Transforme (None, 1, 64) 52176 tf.expand_dims[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_5 (Transforme (None, 1, 64) 52176 tf.expand_dims[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_6 (Transforme (None, 1, 64) 52176 tf.expand_dims[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_1 (Glo (None, 64) 0 transformer_block_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_2 (Glo (None, 64) 0 transformer_block_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_3 (Glo (None, 64) 0 transformer_block_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_4 (Glo (None, 64) 0 transformer_block_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_5 (Glo (None, 64) 0 transformer_block_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_6 (Glo (None, 64) 0 transformer_block_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_5 (Dropout) (None, 64) 0 global_average_pooling1d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_8 (Dropout) (None, 64) 0 global_average_pooling1d_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_11 (Dropout) (None, 64) 0 global_average_pooling1d_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_14 (Dropout) (None, 64) 0 global_average_pooling1d_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_17 (Dropout) (None, 64) 0 global_average_pooling1d_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_20 (Dropout) (None, 64) 0 global_average_pooling1d_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_6 (Dense) (None, 32) 2080 dropout_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_9 (Dense) (None, 32) 2080 dropout_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_12 (Dense) (None, 32) 2080 dropout_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_15 (Dense) (None, 32) 2080 dropout_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_18 (Dense) (None, 32) 2080 dropout_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_21 (Dense) (None, 32) 2080 dropout_20[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate (Concatenate) (None, 192) 0 dense_6[0][0] \n",
" dense_9[0][0] \n",
" dense_12[0][0] \n",
" dense_15[0][0] \n",
" dense_18[0][0] \n",
" dense_21[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf.expand_dims_1 (TFOpLambda) (None, 1, 192) 0 concatenate[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_7 (Transforme (None, 1, 192) 451408 tf.expand_dims_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_7 (Glo (None, 192) 0 transformer_block_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_23 (Dropout) (None, 192) 0 global_average_pooling1d_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_24 (Dense) (None, 32) 6176 dropout_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_25 (Dense) (None, 32) 6176 dropout_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_26 (Dense) (None, 32) 6176 dropout_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_27 (Dense) (None, 32) 6176 dropout_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_28 (Dense) (None, 32) 6176 dropout_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_29 (Dense) (None, 32) 6176 dropout_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"pm25 (Dense) (None, 1) 33 dense_24[0][0] \n",
"__________________________________________________________________________________________________\n",
"pm10 (Dense) (None, 1) 33 dense_25[0][0] \n",
"__________________________________________________________________________________________________\n",
"so2 (Dense) (None, 1) 33 dense_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"no2 (Dense) (None, 1) 33 dense_27[0][0] \n",
"__________________________________________________________________________________________________\n",
"o3 (Dense) (None, 1) 33 dense_28[0][0] \n",
"__________________________________________________________________________________________________\n",
"co (Dense) (None, 1) 33 dense_29[0][0] \n",
"==================================================================================================\n",
"Total params: 1,110,086\n",
"Trainable params: 1,110,086\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"prediction_model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import optimizers"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"trainable_model.compile(optimizer='adam', loss=None)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"X = np.expand_dims(train_data[feature_cols].values, axis=1)\n",
"Y = [x for x in train_data[out_cols].values.T]\n",
"Y_valid = [x for x in valid_data[out_cols].values.T]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.callbacks import ReduceLROnPlateau\n",
"reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=10, mode='auto')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-03-30 09:44:00.730399: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)\n",
"2023-03-30 09:44:00.750292: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 2200000000 Hz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"690/690 [==============================] - 22s 20ms/step - loss: 2.0729 - val_loss: 1.2554\n",
"Epoch 2/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.7311 - val_loss: 0.3496\n",
"Epoch 3/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.1033 - val_loss: 0.0213\n",
"Epoch 4/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0228 - val_loss: 0.0188\n",
"Epoch 5/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0224 - val_loss: 0.0237\n",
"Epoch 6/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0217 - val_loss: 0.0196\n",
"Epoch 7/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0210 - val_loss: 0.0249\n",
"Epoch 8/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0214 - val_loss: 0.0212\n",
"Epoch 9/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0202 - val_loss: 0.0181\n",
"Epoch 10/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0197 - val_loss: 0.0221\n",
"Epoch 11/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0197 - val_loss: 0.0189\n",
"Epoch 12/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0191 - val_loss: 0.0176\n",
"Epoch 13/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0188 - val_loss: 0.0180\n",
"Epoch 14/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0185 - val_loss: 0.0176\n",
"Epoch 15/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0183 - val_loss: 0.0177\n",
"Epoch 16/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0181 - val_loss: 0.0167\n",
"Epoch 17/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0182 - val_loss: 0.0171\n",
"Epoch 18/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0177 - val_loss: 0.0169\n",
"Epoch 19/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0180 - val_loss: 0.0164\n",
"Epoch 20/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0173 - val_loss: 0.0163\n",
"Epoch 21/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0168 - val_loss: 0.0167\n",
"Epoch 22/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0172 - val_loss: 0.0167\n",
"Epoch 23/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0169 - val_loss: 0.0151\n",
"Epoch 24/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0169 - val_loss: 0.0178\n",
"Epoch 25/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0166 - val_loss: 0.0156\n",
"Epoch 26/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0163 - val_loss: 0.0170\n",
"Epoch 27/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0162 - val_loss: 0.0151\n",
"Epoch 28/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0167 - val_loss: 0.0147\n",
"Epoch 29/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0165 - val_loss: 0.0155\n",
"Epoch 30/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0159 - val_loss: 0.0152\n",
"Epoch 31/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0156 - val_loss: 0.0149\n",
"Epoch 32/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0161 - val_loss: 0.0150\n",
"Epoch 33/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0153 - val_loss: 0.0148\n",
"Epoch 34/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0156 - val_loss: 0.0160\n",
"Epoch 35/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0153 - val_loss: 0.0164\n",
"Epoch 36/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0151 - val_loss: 0.0147\n",
"Epoch 37/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0152 - val_loss: 0.0146\n",
"Epoch 38/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0152 - val_loss: 0.0144\n",
"Epoch 39/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0147 - val_loss: 0.0142\n",
"Epoch 40/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0149 - val_loss: 0.0141\n",
"Epoch 41/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0147 - val_loss: 0.0156\n",
"Epoch 42/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0145 - val_loss: 0.0144\n",
"Epoch 43/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0146 - val_loss: 0.0135\n",
"Epoch 44/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0143 - val_loss: 0.0184\n",
"Epoch 45/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0147 - val_loss: 0.0142\n",
"Epoch 46/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0144 - val_loss: 0.0136\n",
"Epoch 47/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0143 - val_loss: 0.0134\n",
"Epoch 48/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0139 - val_loss: 0.0136\n",
"Epoch 49/100\n",
"690/690 [==============================] - 12s 17ms/step - loss: 0.0143 - val_loss: 0.0151\n",
"Epoch 50/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0139 - val_loss: 0.0144\n",
"Epoch 51/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0141 - val_loss: 0.0144\n",
"Epoch 52/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0135 - val_loss: 0.0136\n",
"Epoch 53/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0136 - val_loss: 0.0137\n",
"Epoch 54/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0118 - val_loss: 0.0119\n",
"Epoch 55/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0115 - val_loss: 0.0118\n",
"Epoch 56/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0114 - val_loss: 0.0118\n",
"Epoch 57/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0114 - val_loss: 0.0117\n",
"Epoch 58/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0113 - val_loss: 0.0118\n",
"Epoch 59/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0112 - val_loss: 0.0118\n",
"Epoch 60/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0112 - val_loss: 0.0117\n",
"Epoch 61/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0112 - val_loss: 0.0118\n",
"Epoch 62/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0111 - val_loss: 0.0117\n",
"Epoch 63/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0111 - val_loss: 0.0117\n",
"Epoch 64/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0111 - val_loss: 0.0116\n",
"Epoch 65/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0110 - val_loss: 0.0115\n",
"Epoch 66/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0110 - val_loss: 0.0117\n",
"Epoch 67/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0110 - val_loss: 0.0116\n",
"Epoch 68/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0110 - val_loss: 0.0116\n",
"Epoch 69/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0109 - val_loss: 0.0116\n",
"Epoch 70/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0110 - val_loss: 0.0114\n",
"Epoch 71/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0109 - val_loss: 0.0115\n",
"Epoch 72/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0109 - val_loss: 0.0116\n",
"Epoch 73/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0108 - val_loss: 0.0115\n",
"Epoch 74/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0108 - val_loss: 0.0115\n",
"Epoch 75/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0108 - val_loss: 0.0116\n",
"Epoch 76/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0108 - val_loss: 0.0114\n",
"Epoch 77/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0107 - val_loss: 0.0115\n",
"Epoch 78/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0107 - val_loss: 0.0115\n",
"Epoch 79/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0107 - val_loss: 0.0115\n",
"Epoch 80/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0107 - val_loss: 0.0115\n",
"Epoch 81/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0105 - val_loss: 0.0113\n",
"Epoch 82/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 83/100\n",
"690/690 [==============================] - 14s 20ms/step - loss: 0.0104 - val_loss: 0.0114\n",
"Epoch 84/100\n",
"690/690 [==============================] - 14s 21ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 85/100\n",
"690/690 [==============================] - 13s 19ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 86/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 87/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 88/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 89/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 90/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 91/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 92/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0103 - val_loss: 0.0113\n",
"Epoch 93/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0103 - val_loss: 0.0113\n",
"Epoch 94/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n",
"Epoch 95/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0103 - val_loss: 0.0113\n",
"Epoch 96/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0103 - val_loss: 0.0113\n",
"Epoch 97/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0103 - val_loss: 0.0113\n",
"Epoch 98/100\n",
"690/690 [==============================] - 12s 17ms/step - loss: 0.0103 - val_loss: 0.0113\n",
"Epoch 99/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0103 - val_loss: 0.0113\n",
"Epoch 100/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0104 - val_loss: 0.0113\n"
]
}
],
"source": [
"trainable_model.compile(optimizer='adam', loss=None)\n",
"hist = trainable_model.fit([X, Y[0], Y[1], Y[2], Y[3], Y[4], Y[5]], epochs=100, batch_size=64, verbose=1, \n",
" validation_data=[np.expand_dims(valid_data[feature_cols].values, axis=1), Y_valid[0], Y_valid[1], Y_valid[2], Y_valid[3], Y_valid[4], Y_valid[5]],\n",
" callbacks=[reduce_lr]\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-03-30 10:05:02.545274: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n",
"WARNING:absl:Found untraced functions such as first_attn_layer_call_fn, first_attn_layer_call_and_return_conditional_losses, layer_normalization_layer_call_fn, layer_normalization_layer_call_and_return_conditional_losses, layer_normalization_1_layer_call_fn while saving (showing 5 of 450). These functions will not be directly callable after loading.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: ./models/uw_loss_looknow/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: ./models/uw_loss_looknow/assets\n"
]
}
],
"source": [
"prediction_model.save(f'./models/uw_loss_looknow/')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([[0.05530462],\n",
" [0.15461421],\n",
" [0.06984487],\n",
" ...,\n",
" [0.11500913],\n",
" [0.0667375 ],\n",
" [0.2275947 ]], dtype=float32),\n",
" array([[0.04914668],\n",
" [0.10565668],\n",
" [0.0540773 ],\n",
" ...,\n",
" [0.12981808],\n",
" [0.06497446],\n",
" [0.13074341]], dtype=float32),\n",
" array([[0.0275037 ],\n",
" [0.03192595],\n",
" [0.02210939],\n",
" ...,\n",
" [0.04022893],\n",
" [0.02630579],\n",
" [0.04132178]], dtype=float32),\n",
" array([[0.06796426],\n",
" [0.20642754],\n",
" [0.28993258],\n",
" ...,\n",
" [0.4715118 ],\n",
" [0.09288657],\n",
" [0.5116445 ]], dtype=float32),\n",
" array([[0.25186226],\n",
" [0.35088968],\n",
" [0.09913486],\n",
" ...,\n",
" [0.03776854],\n",
" [0.47971994],\n",
" [0.05172443]], dtype=float32),\n",
" array([[0.06528944],\n",
" [0.12970403],\n",
" [0.08967546],\n",
" ...,\n",
" [0.10639769],\n",
" [0.0641529 ],\n",
" [0.18867406]], dtype=float32)]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rst = prediction_model.predict(np.expand_dims(test_data[feature_cols], axis=1))\n",
"rst"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.9995770752172775,\n",
" 0.9998096046393907,\n",
" 0.9998504109554951,\n",
" 0.9992896477643191,\n",
" 0.9997416699046467,\n",
" 0.9996395653611235]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[np.exp(K.get_value(log_var[0]))**0.5 for log_var in trainable_model.layers[-1].log_vars]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"pred_rst = pd.DataFrame.from_records(np.squeeze(np.asarray(rst), axis=2).T, columns=out_cols)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"real_rst = test_data[out_cols].copy()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"for col in out_cols:\n",
" pred_rst[col] = pred_rst[col] * (maxs[col] - mins[col]) + mins[col]\n",
" real_rst[col] = real_rst[col] * (maxs[col] - mins[col]) + mins[col]"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"y_pred_pm25 = pred_rst['PM2.5'].values.reshape(-1,)\n",
"y_pred_pm10 = pred_rst['PM10'].values.reshape(-1,)\n",
"y_pred_so2 = pred_rst['SO2'].values.reshape(-1,)\n",
"y_pred_no2 = pred_rst['NO2'].values.reshape(-1,)\n",
"y_pred_o3 = pred_rst['O3'].values.reshape(-1,)\n",
"y_pred_co = pred_rst['CO'].values.reshape(-1,)\n",
"y_true_pm25 = real_rst['PM2.5'].values.reshape(-1,)\n",
"y_true_pm10 = real_rst['PM10'].values.reshape(-1,)\n",
"y_true_so2 = real_rst['SO2'].values.reshape(-1,)\n",
"y_true_no2 = real_rst['NO2'].values.reshape(-1,)\n",
"y_true_o3 = real_rst['O3'].values.reshape(-1,)\n",
"y_true_co = real_rst['CO'].values.reshape(-1,)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, mean_absolute_percentage_error"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"def print_eva(y_true, y_pred, tp):\n",
" MSE = mean_squared_error(y_true, y_pred)\n",
" RMSE = np.sqrt(MSE)\n",
" MAE = mean_absolute_error(y_true, y_pred)\n",
" MAPE = mean_absolute_percentage_error(y_true, y_pred)\n",
" R_2 = r2_score(y_true, y_pred)\n",
" print(f\"COL: {tp}, MSE: {format(MSE, '.2E')}\", end=',')\n",
" print(f'RMSE: {round(RMSE, 4)}', end=',')\n",
    "    print(f'MAPE: {round(MAPE * 100, 2)} %', end=',')\n",
" print(f'MAE: {round(MAE, 4)}', end=',')\n",
" print(f'R_2: {round(R_2, 4)}')\n",
" return [MSE, RMSE, MAE, MAPE, R_2]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"COL: pm25, MSE: 6.07E+02,RMSE: 24.63920021057129,MAPE: 47.06000089645386 %,MAE: 17.84910011291504,R_2: 0.7534\n",
"COL: pm10, MSE: 2.18E+03,RMSE: 46.695701599121094,MAPE: 38.22999894618988 %,MAE: 32.543701171875,R_2: 0.6114\n",
"COL: so2, MSE: 6.40E+02,RMSE: 25.30109977722168,MAPE: 69.72000002861023 %,MAE: 15.265700340270996,R_2: 0.7953\n",
"COL: no2, MSE: 1.25E+02,RMSE: 11.198399543762207,MAPE: 23.669999837875366 %,MAE: 8.1274995803833,R_2: 0.7998\n",
"COL: o3, MSE: 1.39E+02,RMSE: 11.781700134277344,MAPE: 45.170000195503235 %,MAE: 8.724499702453613,R_2: 0.9443\n",
"COL: co, MSE: 9.53E-02,RMSE: 0.30869999527931213,MAPE: 21.150000393390656 %,MAE: 0.22669999301433563,R_2: 0.8096\n"
]
}
],
"source": [
"pm25_eva = print_eva(y_true_pm25, y_pred_pm25, tp='pm25')\n",
"pm10_eva = print_eva(y_true_pm10, y_pred_pm10, tp='pm10')\n",
"so2_eva = print_eva(y_true_so2, y_pred_so2, tp='so2')\n",
"nox_eva = print_eva(y_true_no2, y_pred_no2, tp='no2')\n",
"o3_eva = print_eva(y_true_o3, y_pred_o3, tp='o3')\n",
"co_eva = print_eva(y_true_co, y_pred_co, tp='co')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MSE</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>MAPE</th>\n",
" <th>R_2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>PM25</th>\n",
" <td>607.088562</td>\n",
" <td>24.639168</td>\n",
" <td>17.849148</td>\n",
" <td>0.470577</td>\n",
" <td>0.753417</td>\n",
" </tr>\n",
" <tr>\n",
" <th>PM10</th>\n",
" <td>2180.489258</td>\n",
" <td>46.695709</td>\n",
" <td>32.543747</td>\n",
" <td>0.382252</td>\n",
" <td>0.611418</td>\n",
" </tr>\n",
" <tr>\n",
" <th>SO2</th>\n",
" <td>640.143799</td>\n",
" <td>25.301064</td>\n",
" <td>15.265686</td>\n",
" <td>0.697226</td>\n",
" <td>0.795343</td>\n",
" </tr>\n",
" <tr>\n",
" <th>NO2</th>\n",
" <td>125.403076</td>\n",
" <td>11.198352</td>\n",
" <td>8.127537</td>\n",
" <td>0.236683</td>\n",
" <td>0.799830</td>\n",
" </tr>\n",
" <tr>\n",
" <th>O3</th>\n",
" <td>138.808472</td>\n",
" <td>11.781701</td>\n",
" <td>8.724504</td>\n",
" <td>0.451706</td>\n",
" <td>0.944349</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CO</th>\n",
" <td>0.095273</td>\n",
" <td>0.308662</td>\n",
" <td>0.226732</td>\n",
" <td>0.211459</td>\n",
" <td>0.809550</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MSE RMSE MAE MAPE R_2\n",
"PM25 607.088562 24.639168 17.849148 0.470577 0.753417\n",
"PM10 2180.489258 46.695709 32.543747 0.382252 0.611418\n",
"SO2 640.143799 25.301064 15.265686 0.697226 0.795343\n",
"NO2 125.403076 11.198352 8.127537 0.236683 0.799830\n",
"O3 138.808472 11.781701 8.724504 0.451706 0.944349\n",
"CO 0.095273 0.308662 0.226732 0.211459 0.809550"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame.from_records([pm25_eva, pm10_eva, so2_eva, nox_eva, o3_eva,co_eva], columns=['MSE', 'RMSE', 'MAE', 'MAPE', 'R_2'], index=['PM25', 'PM10', 'SO2', 'NO2', 'O3', 'CO'])"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f9e107bdc90>"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:matplotlib.font_manager:findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.\n",
"WARNING:matplotlib.font_manager:findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6UAAAIICAYAAACW1EjCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9eZhsZ3nejd7vGmqu6mHv7j2jvTVtCQ1oM4ojYWOwLYQxKLKVyLEdLjsx/hxyovg4ctB3nBPHgcOgBPvLd46TDwcwOSZ2iC3LeABhrICDzWCBhAQakIS2pD32nnqoYc3v+eNdq+aqXlW9aujq+3dd+6ru6u6qtbur1nrv976f5xFSShBCCCGEEEIIIZNAm/QBEEIIIYQQQgjZuVCUEkIIIYQQQgiZGBSlhBBCCCGEEEImBkUpIYQQQgghhJCJQVFKCCGEEEIIIWRiUJQSQgghhBBCCJkYxqQPAAB2794tDx8+POnDIIQQQgghhBAyAr75zW+el1IudfvaVIjSw4cP45FHHpn0YRBCCCGEEEIIGQFCiBd7fY3xXUIIIYQQQgghE4OilBBCCCGEEELIxKAoJYQQQgghhBAyMaaiprQbruvixIkTsCxr0ocyUjKZDA4ePAjTNCd9KIQQQgghhBAydqZWlJ44cQLFYhGHDx+GEGLShzMSpJS4cOECTpw4gSNHjkz6cAghhBBCCCFk7ExtfNeyLOzatWtmBSkACCGwa9eumXeDCSGEEEIIIaQXUytKAcy0II3YCf9HQgghhBBCCOnFVIvSWeJLX/oS3vGOd0z6MAghhBBCCCFkqpjamtJBefDRk7j/oWdwarWG/fNZ3HvbUdxx7MDIn9f3fei6PvLnIYQQQgghhJBZJLZTKoTQhRCPCiH+LPx8UQjxl0KIZ8PbhabvvU8I8ZwQ4hkhxG2jOPBmHnz0JO574AmcXK1BAji5WsN9DzyBBx89uaXHPX78OK655hq8+93vxo033oif/MmfRLVaxeHDh/Ebv/EbuPXWW/E//sf/wBe+8AW88Y1vxKtf/WrcddddKJfLAIDPf/7zuOaaa3DrrbfigQceSOB/SgghhBBCCCGzxSBO6T0AngJQCj9/H4C/klJ+SAjxvvDzfyWEeCWAuwFcB2A/gC8KIa6WUvrDHuS//dPv4slT6z2//uhLq3D8oOW+muvjV//wcfz+N17q+jOv3F/Cv/nx6zZ97meeeQYf//jHccstt+Dnf/7n8du//dsA1CiXr3zlKzh//jzuvPNOfPGLX0Q+n8eHP/xhfPSjH8Wv/uqv4hd+4Rfw8MMP48orr8Q/+Af/YID/MSGEEEIIIYTsDGI5pUKIgwB+DMB/abr7XQA+FX78KQB3NN3/B1JKW0r5AoDnALw+kaPtQbsg3ez+QTh06BBuueUWAMDP/MzP4Ctf+QoA1EXm1772NTz55JO45ZZbcNNNN+FTn/oUXnzxRTz99NM4cuQIrrrqKggh8DM/8zNbPhZCCCGEEEIImTXiOqW/BeBXARSb7tsjpTwNAFLK00KI5fD+AwC+1vR9J8L7hmYzR/OWDz2Mk6u1jvsPzGfx33/xjVt56o7uuNHn+XwegJo1+iM/8iP4/d///Zbve+yxx9hZlxBCCCGEEEI2YVOnVAjxDgArUspvxnzMbkpMdnnc9wghHhFCPHLu3LmYD92de287iqzZ2mwoa+q497ajW3pcAHjppZfw1a9+FQDw+7//+7j11ltbvn7zzTfjb/7mb/Dcc88BAKrVKr73ve/hmmuuwQsvvIDnn3++/rOEEEIIIYQQQlqJE9+9BcA7hRDHAfwBgLcIIX4PwFkhxD4ACG9Xwu8/AeBQ088fBHCq/UGllB+TUr5WSvnapaWlLfwXgDuOHcAH77wBB+azEFAO6QfvvCGR7rvXXnstPvWpT+HGG2/ExYsX8Uu/9EstX19aWsLv/u7v4qd+6qdw44034uabb8bTTz+NTCaDj33sY/ixH/sx3H
rrrbjsssu2fCyEEEIIIYQQMmsIKTtMzN7fLMSbAfxLKeU7hBD3A7jQ1OhoUUr5q0KI6wD8N6g60v0A/grAVf0aHb32ta+VjzzySMt9Tz31FK699tpB/z+Jcvz4cbzjHe/Ad77znZE+zzT8XwkhhBBCCCFkVAghvimlfG23r21lTumHAHxGCPGPAbwE4C4AkFJ+VwjxGQBPAvAAvHcrnXcJIYQQQgghhMwuseeUAoCU8ktSyneEH1+QUr5VSnlVeHux6fs+IKW8Qkp5VEr5uaQPelwcPnx45C4pIYQQQgghhPRk4wzwyduBjbOTPpKRMZAoJYQQQgghhBAyRr78EeClrwFf/vCkj2RkbCW+SwghhBBCCCFkFLx/GfDsxuePfFz9M9LAr630/rltCJ1SQgghhBBCCJk27nkcuP4uQDPV50YGuOEu4J4nJntcI4CilBBCCCGEEEKmjeJeIF0EAk997tlAugQU90z2uEYARekIOXz4MM6fPz/pwyCEEEIIIYRsRyorwP5j6uPrfwIoz2azo9kSpSPsTCWlRBAEiT8uIYQQQgghhHTl7k8Dl/+g+vhNv6I+n0FmS5Qm3Jnq+PHjuPbaa/FP/+k/xatf/Wr8u3/37/C6170ON954I/7Nv/k39e+744478JrXvAbXXXcdPvaxjyXy3IQQQgghhBBSb3bkO5M9jhGyPbrvfu59wJk+Bb0v/Q0gZePzqDOVEMArbun+M3tvAG7/0KZP/cwzz+CTn/wk7rjjDvzhH/4hvvGNb0BKiXe+853467/+a/zAD/wAPvGJT2BxcRG1Wg2ve93r8BM/8RPYtWvXgP9JQgghhBBCCGnDs9St7072OEbIbDil+18H5JYAEf53hAbkl4ADr9vyQ1922WW4+eab8YUvfAFf+MIXcOzYMbz61a/G008/jWeffRYA8B//43/Eq171Ktx88814+eWX6/cTQgghhBBCyJaInNJgdkXp9nBKYzia+NNfBr71u6pVsu8A174TeMdHt/zU+XwegKopve+++/CLv/iLLV//0pe+hC9+8Yv46le/ilwuhze/+c2wLGvLz0sIIYQQQgghcGvqdobju7PhlAKqM9Vrfg74J19Utwl3prrtttvwiU98AuVyGQBw8uRJrKysYG1tDQsLC8jlcnj66afxta99LdHnJYQQQgghhOxg6jWl3mSPY4RsD6c0Ds2dqBJwSNv50R/9UTz11FN44xvfCAAoFAr4vd/7PbztbW/Df/7P/xk33ngjjh49iptvvjnx5yaEEEIIIYTsUOo1pbPrlM6OKB0Bhw8fxne+85365/fccw/uueeeju/73Oc+1/Xnjx8/PqpDI4QQQgghhOwEdkBN6ezEdwkhhBBCCCFk1mD3XUIIIYQQQgghE6NeU0pRSgghhBBCCCFk3OyAmtKpFqVSykkfwsjZCf9HQgghhBBCyJBQlE6OTCaDCxcuzLRok1LiwoULyGQykz4UQgghhBBCyDQSidKAI2HGzsGDB3HixAmcO3du0ocyUjKZDA4ePDjpwyCEEEIIIYRMI/Wa0tl1SqdWlJqmiSNHjkz6MAghhBBCCCFkcrD7LiGEEEIIIYSQiRAEDYeUopQQQgghhBBCyFjx7cbHAUUpIYQQQgghhJBxEkV3gZmuKaUoJYQQQgghhJBpxG0WpXRKCSGEEEIIIYSME4+ilBBCCCGEEELIpPCaakoZ3yWEEEIIIYQQMlaandLAm9xxjBiKUkIIIYQQQgiZRuiUEkIIIYQQQgiZGKwpJYQQQgghhBAyMVqcUopSQgghhBBCCCHjxKupWyPD+C4hhBBCCCGEkDETOaXpIhDQKSWEEEIIIYQQMk6imtJ0kfFdQgghhBBCCCFjptkppSglhBBCCCGEEDJWIqc0VWRNKSGEEEIIIYSQMdMc3w28yR7LCKEoJYQQQgghhJBpxLUAoQFmlk4pIYQQQgghhJAx41lqHIyeYk0pIYQQQgghhJAx49mAkQZ0k6KUEEIIIY
QQQsiYqTulJuO7hBBCCCGEEELGjGc34rsBnVJCCCGEEEIIIeMkcko1g/FdQgghhBBCCCFjpl5TykZHhBBCCCGEEEL
"text/plain": [
"<Figure size 1152x648 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(16, 9))\n",
"plt.plot(pred_rst['PM10'].values[50:150], 'o-', label='pred')\n",
"plt.plot(real_rst['PM10'].values[50:150], '*-', label='real')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from statistics import mean\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import explained_variance_score,r2_score, median_absolute_error, mean_squared_error, mean_absolute_error\n",
"from scipy import stats\n",
"import numpy as np\n",
"from matplotlib import rcParams\n",
"config = {\"font.size\": 32,\"mathtext.fontset\":'stix'}\n",
"rcParams.update(config)\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"config = {\"font.size\": 32,\"mathtext.fontset\":'stix'}\n",
"rcParams.update(config)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def scatter_out_1(x, y, label, name): ## x,y为两个需要做对比分析的两个量。\n",
" # ==========计算评价指标==========\n",
" BIAS = mean(x - y)\n",
" MSE = mean_squared_error(x, y)\n",
" RMSE = np.power(MSE, 0.5)\n",
" R2 = r2_score(x, y)\n",
" MAE = mean_absolute_error(x, y)\n",
" EV = explained_variance_score(x, y)\n",
" print('==========算法评价指标==========')\n",
" print('Explained Variance(EV):', '%.3f' % (EV))\n",
" print('Mean Absolute Error(MAE):', '%.3f' % (MAE))\n",
" print('Mean squared error(MSE):', '%.3f' % (MSE))\n",
" print('Root Mean Squard Error(RMSE):', '%.3f' % (RMSE))\n",
" print('R_squared:', '%.3f' % (R2))\n",
" # ===========Calculate the point density==========\n",
" xy = np.vstack([x, y])\n",
" z = stats.gaussian_kde(xy)(xy)\n",
" # ===========Sort the points by density, so that the densest points are plotted last===========\n",
" idx = z.argsort()\n",
" x, y, z = x[idx], y[idx], z[idx]\n",
" def best_fit_slope_and_intercept(xs, ys):\n",
" m = (((mean(xs) * mean(ys)) - mean(xs * ys)) / ((mean(xs) * mean(xs)) - mean(xs * xs)))\n",
" b = mean(ys) - m * mean(xs)\n",
" return m, b\n",
" m, b = best_fit_slope_and_intercept(x, y)\n",
" regression_line = []\n",
" for a in x:\n",
" regression_line.append((m * a) + b)\n",
" fig,ax=plt.subplots(figsize=(12,9),dpi=400)\n",
" scatter=ax.scatter(x,y,marker='o',c=z*100,s=15,label='LST',cmap='Spectral_r')\n",
" cbar=plt.colorbar(scatter,shrink=1,orientation='vertical',extend='both',pad=0.015,aspect=30,label='Frequency')\n",
" min_value = min(min(x), min(y))\n",
" max_value = max(max(x), max(y))\n",
"\n",
" plt.plot([min_value-5,max_value+5],[min_value-5,max_value+5],'black',lw=1.5) # 画的1:1线线的颜色为black线宽为0.8\n",
" plt.plot(x,regression_line,'red',lw=1.5) # 预测与实测数据之间的回归线\n",
" plt.axis([min_value-5,max_value+5,min_value-5,max_value+5]) # 设置线的范围\n",
" plt.xlabel('Measured %s' % label)\n",
" plt.ylabel('Retrived %s' % label)\n",
" # plt.xticks(fontproperties='Times New Roman')\n",
" # plt.yticks(fontproperties='Times New Roman')\n",
"\n",
"\n",
" plt.text(min_value-5 + (max_value-min_value) * 0.05, int(max_value * 0.95), '$N=%.f$' % len(y)) # text的位置需要根据x,y的大小范围进行调整。\n",
" plt.text(min_value-5 + (max_value-min_value) * 0.05, int(max_value * 0.88), '$R^2=%.2f$' % R2)\n",
" plt.text(min_value-5 + (max_value-min_value) * 0.05, int(max_value * 0.81), '$RMSE=%.2f$' % RMSE)\n",
" plt.xlim(min_value-5,max_value+5) # 设置x坐标轴的显示范围\n",
" plt.ylim(min_value-5,max_value+5) # 设置y坐标轴的显示范围\n",
" # file_name = name.split('(')[0].strip()\n",
" plt.savefig(f'./figure/looknow/{name}.png',dpi=800,bbox_inches='tight',pad_inches=0)\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.611\n",
"Mean Absolute Error(MAE): 32.544\n",
"Mean squared error(MSE): 2180.489\n",
"Root Mean Squard Error(RMSE): 46.696\n",
"R_squared: 0.611\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:matplotlib.font_manager:findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.\n",
"WARNING:matplotlib.font_manager:findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEjAAAAyVCAYAAAAv6W6kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzde5iVZbk/8PsdBlAQRjRQMfFQGkZWIBjhVkQBwTzl7rDboKaAINavsyVYuU2wMjtYikBqJlLN5DiKGxhxhKFQ1MQ0pTzkgUzGUkeUg7qGWb8/bFcDw7CGWWu9M/D5XNdcXu9zP/fzfucACLPmfpNsNhsAAAAAAAAAAAAAAAAAAAAAAACtUZJ2AAAAAAAAAAAAAAAAAAAAAAAAoOMxwAgAAAAAAAAAAAAAAAAAAAAAAGg1A4wAAAAAAAAAAAAAAAAAAAAAAIBWM8AIAAAAAAAAAAAAAAAAAAAAAABoNQOMAAAAAAAAAAAAAAAAAAAAAACAVjPACAAAAAAAAAAAAAAAAAAAAAAAaDUDjAAAAAAAAAAAAAAAAAAAAAAAgFYzwAgAAAAAAAAAAAAAAAAAAAAAAGg1A4wAAAAAAAAAAAAAAAAAAAAAAIBWM8AIAAAAAAAAAAAAAAAAAAAAAABoNQOMAAAAAAAAAAAAAAAAAAAAAACAVjPACAAAAAAAAAAAAAAAAAAAAAAAaDUDjAAAAAAAAAAAAAAAAAAAAAAAgFYzwAgAAAAAAAAAAAAAAAAAAAAAAGg1A4wAAAAAAAAAAAAAAAAAAAAAAIBWM8AIAAAAAAAAAAAAAAAAAAAAAABoNQOMAAAAAAAAAAAAAAAAAAAAAACAVjPACAAAAAAAAAAAAAAAAAAAAAAAaDUDjAAAAAAAAAAAAAAAAAAAAAAAgFYzwAgAAAAAAAAAAAAAAAAAAAAAAGg1A4wAAAAAAAAAAAAAAAAAAAAAAIBWM8AIAAAAAAAAAAAAAAAAAAAAAABoNQOMAAAAAAAAAAAAAAAAAAAAAACAVjPACAAAAAAAAAAAAAAAAAAAAAAAaDUDjAAAAAAAAAAAAAAAAAAAAAAAgFYzwAgAAAAAAAAAAAAAAAAAAAAAAGg1A4wAAAAAAAAAAAAAAAAAAAAAAIBWM8AIAAAAAAAAAAAAAAAAAAAAAABoNQOMAAAAAAAAAAAAAAAAAAAAAACAVjPACAAAAAAAAAAAAAAAAAAAAAAAaDUDjAAAAAAAAAAAAAAAAAAAAAAAgFYzwAgAAAAAAAAAAAAAAAAAAIooSZIk7QwAAAD5YIARAAAAAAAAAAAAAAAAAAAU17i0AwAAAORDks1m084AAAAAAAAAAAAAAAAAAAC7hCRJOkfECxFxRDabrUs7DwAAQFuUpB0AAAAAAAAAAAAAAAAAAAB2ISdExDsi4oy0gwAAALSVAUYAAAAAAAAAAAAAAAAAAFA8n/jHfz+eagoAAIA8SLLZbNoZAADYAUmS7BcRJ2+x/HREbEwhDgAAAAAAAAAAAAAA0P51i4hDtli7I5vNrk0jzK4oSZLOEfFiRPSKiMaI2D+bzdalmwoAAGDHlaYdAACAHXZyRMxJOwQAAAAAAAAAAAAAANChnRcRc9MOsQs5Id4eXhQRURIRZ0TENenFAQAAaJuStAMAAAAAAAAAAAAAAAAAAMAu4hNbXH88lRQAAAB5Upp2AAAAAAAAAAAAAAAAAAAAyKNs2gGak8lkolevXlFfX//PtZKSkuPq6uqy++67b4rJWpSkHQAAAGjfStIOAAAAAAAAAAAAAAAAAAAAO7uampomw4siIhobG6OysjKlRAAAAG1XmnYAAAB22NNbLsyePTuOOOKINLIAAAAAAAAAAAAAAAApa2hoiO9973tx++23t6Ztq59PoDDKy8ubXa+oqIipU6cWOQ0AAEB+JNlsNu0MAADsgCRJPhwR9/z72j333BMf/vCHU0oEAAAAAAAAAAAAAACk5fXXX49PfvKTsWjRot
a2Dstms/cWIlOK2t0Pz2Yymdhnn32ivr5+q1pJSUn89a9/jX333TeFZNuVpB0AAABo30rSDgAAAAAAAAAAAAAAAAAAwI5bu3ZtDB8+fEeGF1EkNTU1zQ4viohobGyMysrKIicCAADIDwOMAAAAAAAAAAAAAAAAAAA6qMceeyyGDh0aDz30UNpRaEF5eXmL9YqKiiIlAQAAyC8DjAAAAAAAAAAAAAAAAAAAOqClS5fG0UcfHWvWrEk7Ci3IZDJRVVXV4p7ly5dHXV1dcQIBAADkkQFGAAAAAAAAAAAAAAAAAAAdzM033xwnnnhirFu3Lqf9nTt3jm9+85sFTkVzampqor6+vsnaoOjd5LqxsTEqKyuLGQsAACAvDDACAAAAAAAAAAAAAAAAAOggstlszJw5M8aPHx+ZTCannrKysqiuro4TTzyxwOloTnl5eZPrJCLGxWHRPUqbrFdUVBQxFQAAQH4YYAQAAAAAAAAAAAAAAAAA0AE0NDTE5MmTY/r06Tn39OvXL1asWBEjRowoYDK2JZPJRFVVVZO1/tEreiVdY1D0brK+fPnyqKurK2I6AACAtjPACAAAAAAAAAAAAAAAAACgnXv99dfj1FNPjblz5+bcM3DgwFi5cmUMGDCggMloSU1NTdTX1zdZGxx9IiJiyD/++38aGxujsrKyaNkAAADywQAjAAAAAAAAAAAAAAAAAIB2bO3atTF8+PBYtGhRzj1jx46N5cuXx3777VfAZGxPeXl5k+skIo6M3hER0T96RfcobVKvqKgoVjQAAIC8MMAIAAAAAAAAAAAAAAAAAKCdeuyxx2Lo0KHx0EMP5dwzadKkuP3222OPPfYoYDK2J5PJRFVVVZO1/tEreiZdIiKiNCmJQf8YZvR/li9fHnV1dcWKCAAA0GYGGAEAAAAAAAAAAAAAAAAAtENLly6No48+OtasWZNzz8yZM2P27NlRWlpawGTkoqamJurr65usDY4+Ta6HbHHd2NgYlZWVBc8GAACQLwYYAQAAAAAAAAAAAAAAAAC0MzfffHOceOKJsW7dupz2d+7cOebNmxcXXXRRJElS4HTkory8vMl1EhFHRu8ma/2jV3SPpsOmKioqCh0NAAAgbwwwAgAAAAAAAAAAAAAAAABoJ7LZbMycOTPGjx8fmUwmp56ysrKorq6OcePGFTgducpkMlFVVdVkrX/0ip5JlyZrpUlJDNpiqNHy5cujrq6u0BEBAADywgAjAAAAAAAAAAAAAAAAAIB2oKGhISZPnhzTp0/Puadfv36xYsWKGDFiRAGT0Vo1NTVRX1/fZG1w9Gl275At1hsbG6OysrJg2QAAAPLJACMAAAAAAAAAAAAAAAAAgJS9/vrrceqpp8bcuXNz7hk4cGCsXLkyBgwYUMBk7Ijy8vIm10lEHBm9m93bP3pF9yhtslZRUVGoaAAAAHllgBEAAAAAAAAAAAAAAAAAQIrWrl0bw4cPj0WLFuXcM3bs2Fi+fHnst99+BUzGjshkMlFVVdVkrX/0ip5Jl2b3lyYlMWiL4UbLly+Purq6QkUEAADIGwOMAAAAAAAAAAAAAAAAAABS8thjj8XQoUPjoYceyrln0qRJcfvtt8cee+xRwGTsqJqamqivr2+yNjj6tNgzZIt6Y2NjVFZW5j0bAABAvhlgBAAAAAAAAAAAAAAAAACQgqVLl8bRRx8da9asybln5syZMXv27CgtLS1gMtqivLy8yXUSEUdG7xZ7+kev6B5NP6cVFRX5jgYAAJB3BhgBAAAAAAAAAAAAAAAAABTZzTffHCeeeGKsW7cup/2dO3eOefPmxUUXXRRJkhQ4HTsqk8lEVVVVk7X+0St6Jl1a7CtNSmLQFkOOli9fHnV1dfmOCAAAkFcGGAEAAAAAAAAAAAAAAAAAFEk2m42ZM2fG+PHjI5PJ5NRTVlYW1dXVMW7cuAKno61qamqivr6+ydrg6JNT75At9jU2NkZlZWXesgEAABSCAUYAAAAAAAAAAA
AAAAAAAEXQ0NAQkydPjunTp+fc069fv1ixYkWMGDGigMnIl/Ly8ibXSUQcGb1z6u0fvaJ7lDZZq6ioyFc0AACAgjD
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "scatter_out_1(real_rst['PM10'].values, pred_rst['PM10'].values, label='$PM_{10}\\ (\\mu g/m^3)$', name='PM10')"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.754\n",
"Mean Absolute Error(MAE): 17.849\n",
"Mean squared error(MSE): 607.089\n",
"Root Mean Squard Error(RMSE): 24.639\n",
"R_squared: 0.753\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEb8AAAyVCAYAAAChBVopAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzdeXxU1f3/8fcJCSABYlAQsAJitUFcGhZFrMYoIGkrLl1sv4krIBhqa7XaAra1VqBqta0LCChqBWoTG6O0QAgBEhvBDXfUiohUTVA0RsLmhJzfH/jTTkjCZObceyeT1/PxyKOPueee933PXAoYMp8x1loBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOCnpKALAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADaH4bfAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8x/AbAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIDvGH4DAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAPAdw28AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAL5j+A0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwHcMvwEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA+I7hNwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA3zH8BgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADgO4bfAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8x/AbAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIDvGH4DAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAPAdw28AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAL5j+A0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwHcMvwEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA+I7hNwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA3zH8BgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADgO4bfAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8x/AbAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIDvGH4DAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAPAdw28AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAL5j+A0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwHcMvwEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA+I7hNwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA3zH8BgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADgO4bfAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8x/AbAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIDvGH4DAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAPAdw28AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAL5j+A0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwHcMvwEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA+I7hNwAAAAAAAAAAAAAAAAAAAAAAAAAAAICPjDEm6A4AAABAPGD4DQAAAAAAAAAAAAAAAAAAAAAAAAAAAOCv3KALAAAAAPHAWGuD7gAAAAAAAAAAAAAAAAAAAAAAAAAAAAC0C8aYFEkfSDreWlsddB8AAAAgSElBFwAAAAAAAAAAAAAAAAAAAAAAAAAAAADakbMkHSrpgqCLAAAAAEFj+A0AAAAAAAAAAAAAAAAAAAAAAAAAAADgnx9+8b8/CLQFAAAAEAeMtTboDgAAAIiCMaaPpO82OrxJ0s4A6gAAAAAAAAAAAAAAAAAAAAAAgPjXRdLARsf+aa2tCqJMe2SMSZG0VVK6pAZJh1trq4NtBQAAAAQnOegCAAAAiNp3Jc0LugQAAAAAAA
AAAAAAAAAAAAAAAGjTrpA0P+gS7chZ2jf4RpKSJF0gaXZwdQAAAIBgJQVdAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGgnftjo8Q8CaQEAAADEieSgCwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAO2aALNCUUCik9PV01NTVfHktKSjqjurra9u7dO8BmLTJBFwAAAEBiSwq6AAAAAAAAAAAAAAAAAAAAAAAAAAAAAJDoysrKwgbfSFJDQ4OKiooCagQAAAAELznoAgAAAIjapsYH5s6dq+OPPz6ILgAAAAAAAAAAAAAAAAAAAAAAIGD19fX64x//qCeeeKI12/Z7fwK8UVBQ0OTxwsJC5efn+9wGAAAAiA/GWht0BwAAAETBGHOKpKf+99hTTz2lU045JaBGAAAAAAAAAAAAAAAAAAAAAAAgKNu3b9eFF16oZcuWtXbrSGvtWi86BSju3jwbCoV02GGHqaamZr+1pKQkvf/+++rdu3cAzQ7IBF0AAAAAiS0p6AIAAAAAAAAAAAAAAAAAAAAAAAAAAACIXlVVlbKysqIZfAOflJWVNTn4RpIaGhpUVFTkcyMAAAAgPjD8BgAAAAAAAAAAAAAAAAAAAAAAAAAAoI167bXXNGLECL3wwgtBV0ELCgoKWlwvLCz0qQkAAAAQXxh+AwAAAAAAAAAAAAAAAAAAAAAAAAAA0AatXr1ap556qrZs2RJ0FbQgFAqpuLi4xXMqKipUXV3tTyEAAAAgjjD8BgAAAAAAAAAAAAAAAAAAAAAAAAAAoI1ZtGiRzj77bNXW1kZ0fkpKin7729963ApNKSsrU01NTdixIeoZ9rihoUFFRUV+1gIAAADiAsNvAAAAAAAAAAAAAAAAAAAAAAAAAAAA2ghrrWbOnKm8vDyFQqGI9qSlpamkpERnn322x+3QlIKCgrDHRlKujlGqksOOFxYW+tgKAAAAiA8MvwEAAAAAAAAAAAAAAAAAAAAAAAAAAGgD6uvrNWnSJE2fPj3iPf369VNlZaWys7M9bIbmhEIhFRcXhx3LULrSTScNUc+w4xUVFaqurvaxHQAAABA8ht8AAAAAAAAAAAAAAAAAAAAAAAAAAADEue3bt2vcuHGaP39+xHsyMzO1bt06DR482MNmaElZWZlqamrCjg1TL0nS8C/+9/9raGhQUVGRb90AAACAeMDwGwAAAAAAAAAAAAAAAAAAAAAAAAAAgDhWVVWlrKwsLVu2LOI9OTk5qqioUJ8+fTxshgMpKCgIe2wkDVVPSVKG0pWq5LD1wsJCv6oBAAAAcYHhNwAAAAAAAAAAAAAAAAAAAAAAAAAAAHHqtdde04gRI/TCCy9EvGfixIl64okn1LVrVw+b4UBCoZCKi4vDjmUoXd1NR0lSsknSkC8G4fx/FRUVqq6u9qsiAAAAEDiG3wAAAAAAAAAAAAAAAAAAAAAAAAAAAMSh1atX69RTT9WWLVsi3jNz5kzNnTtXycnJHjZDJMrKylRTUxN2bJh6hT0e3uhxQ0ODioqKPO8GAAAAxAuG3wAAAAAAAAAAAAAAAAAAAAAAAAAAAMSZRYsW6eyzz1ZtbW1E56ekpGjhwoWaOnWqjDEet0MkCgoKwh4bSUPVM+xYhtKVqvBBRYWFhV5XAwAAAOIGw28AAAAAAAAAAAAAAAAAAAAAAAAAAADihLVWM2fOVF5enkKhUER70tLSVFJSotzcXI/bIVKhUEjFxcVhxzKUru6mY9ixZJOkIY0G4lRUVKi6utrrigAAAEBcYPgNAAAAAAAAAAAAAAAAAAAAAAAAAABAHKivr9ekSZM0ffr0iPf069dPlZWVys7O9rAZWqusrEw1NTVhx4apV5PnDm90vKGhQUVFRZ51AwAAAOIJw28AAAAAAAAAAAAAAAAAAAAAAAAAAAACtn37do0bN07z58+PeE9mZqbWrVunwYMHe9gM0SgoKAh7bCQNVc8mz81QulKVHHassLDQq2oAAABAXGH4DQAAAAAAAAAAAA
AAAAAAAAAAAAAAQICqqqqUlZWlZcuWRbwnJydHFRUV6tOnj4fNEI1QKKTi4uKwYxlKV3fTscnzk02ShjQajFNRUaH
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# PM2.5: density scatter of measured vs. predicted values (label '$' was misplaced inside the parenthesis)\n",
"scatter_out_1(real_rst['PM2.5'].values, pred_rst['PM2.5'].values, label='$PM_{2.5}\\ (\\mu g/m^3)$', name='PM25')"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.797\n",
"Mean Absolute Error(MAE): 15.266\n",
"Mean squared error(MSE): 640.144\n",
"Root Mean Squard Error(RMSE): 25.301\n",
"R_squared: 0.795\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEb8AAAyVCAYAAAChBVopAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzde5jVZbk//nsNM6AgjGigYOIpDSMtEI1wK6KAYJ62u8PuO6ipKDjWt7MlWrlNsDLbZSoCqZlCNVPjKAaMOMJgKGpippKnPJA5g6cROcoaZv3+2L9vu+EwrGHWWp8ZeL2uay6v9dzPcz/vGREvmDX3J5XJZAIAAAAAAAAAAAAAAAAAAAAAAAqpKOkAAAAAAAAAAAAAAAAAAAAAAADsegy/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAAAAAACg4Ay/AQAAAAAAAAAAAAAAAACAAkqlUqmkMwAAQEdg+A0AAAAAAAAAAAAAAAAAABRWWdIBAACgI0hlMpmkMwAAAAAAAAAAAAAAAAAAwC4hlUqVRMTrEXFEJpNpSDoPAAAkqSjpAAAAAAAAAAAAAAAAAAAAsAs5KSI+EBFnJR0EAACSZvgNAAAAAAAAAAAAAAAAAAAUzmf//39+JtEUAADQAaQymUzSGQAA2AGpVKpfRJy62fJLEbEugTgAAAAAAAAAAAAAAEDH1z0iDt5s7d5MJlOfRJhdUSqVKomIlRHROyKaI2K/TCbTkGwqAABITnHSAQAA2GGnRsSMpEMAAAAAAAAAAAAAAACd2kURMTPpELuQk+J/Bt9ERBRFxFkRcVNycQAAIFlFSQcAAAAAAAAAAAAAAAAAAIBdxGc3e/2ZRFIAAEAHUZx0AAAAAAAAAAAAAAAAAAAAyKFM0gG2Jp1OR+/evaOxsfGfa0VFRSc0NDRk9t133wSTtSqVdAAAAHZuRUkHAAAAAAAAAAAAAAAAAACAnV1tbW2LwTcREc3NzVFVVZVQIgAASF5x0gEAANhhL22+MH369DjiiCOSyAIAAAAAAAAAAAAAACSsqakpfvzjH8c999zTlmNb/HwC+VFRUbHV9crKyigvLy9wGgAA6BhSmUwm6QwAAOyAVCr1yYh46F/XHnroofjkJz+ZUCIAAAAAAAAAAAAAACApq1evjs997nMxb968th4dnslkHs5HpgR1uB+eTa
fTsc8++0RjY+MWtaKiovjHP/4R++67bwLJtiuVdAAAAHZuRUkHAAAAAAAAAAAAAAAAAABgx9XX18eIESN2ZPANBVJbW7vVwTcREc3NzVFVVVXgRAAA0DEYfgMAAAAAAAAAAAAAAAAA0Ek988wzMWzYsHjiiSeSjkIrKioqWq1XVlYWKAkAAHQsht8AAAAAAAAAAAAAAAAAAHRCCxcujGOPPTZWrFiRdBRakU6no7q6utU9ixcvjoaGhsIEAgCADsTwGwAAAAAAAAAAAAAAAACATmbWrFlx8sknx6pVq7LaX1JSEt/73vfynIqtqa2tjcbGxhZrQ6JPi9fNzc1RVVVVyFgAANAhGH4DAAAAAAAAAAAAAAAAANBJZDKZmDp1aowfPz7S6XRWZ0pLS6OmpiZOPvnkPKdjayoqKlq8TkVEWRwWPaK4xXplZWUBUwEAQMdg+A0AAAAAAAAAAAAAAAAAQCfQ1NQUEydOjMsvvzzrMwMGDIglS5bEyJEj85iMbUmn01FdXd1ibWD0jt6pbjEk+rRYX7x4cTQ0NBQwHQAAJM/wGwAAAAAAAAAAAAAAAACADm716tVx+umnx8yZM7M+M3jw4Fi6dGkMGjQoj8loTW1tbTQ2NrZYGxp9IyLi6P//n/9Pc3NzVFVVFSwbAAB0BIbfAAAAAAAAAAAAAAAAAAB0YPX19TFixIiYN29e1mfGjRsXixcvjn79+uUxGdtTUVHR4nUqIo6KPhERMTB6R48oblGvrKwsVDQAAOgQDL8BAAAAAAAAAAAAAAAAAOignnnmmRg2bFg88cQTWZ+58MIL45577ok99tgjj8nYnnQ6HdXV1S3WBkbv6JXqGhERxamiGPL/D8L5fxYvXhwNDQ2FiggAAIkz/AYAAAAAAAAAAAAAAAAAoANauHBhHHvssbFixYqsz0ydOjWmT58excXFeUxGNmpra6OxsbHF2tDo2+L10Zu9bm5ujqqqqrxnAwCAjsLwGwAAAAAAAAAAAAAAAACADmbWrFlx8sknx6pVq7LaX1JSEnfeeWdcdtllkUql8pyObFRUVLR4nYqIo6JPi7WB0Tt6RMtBRZWVlfmOBgAAHYbhNwAAAAAAAAAAAAAAAAAAHUQmk4mpU6fG+PHjI51OZ3WmtLQ0ampqoqysLM/pyFY6nY7q6uoWawOjd/RKdW2xVpwqiiGbDcRZvHhxNDQ05DsiAAB0CMXb3wIAAAAAAAAAAAAAAAAA7AwymUy88Nc347lnVsYrL70Tb65cE+n0pigp6RJ99tkjDjxkr/jwoH3i0IF9IpVKJR13l9PU1BTl5eUxc+bMrM8MGDAg5s6dG4MGDcpjMtqqtrY2GhsbW6wNjb5b3Xt09I0Ho/6fr5ubm6OqqirKy8vzmhEAADoCw28AAAAAAAAAAAAAAAAAYCe3ceOmqLvvhaid91zU/+O9re559aV34k8Pr4iIiH4f7BUnjf1wnHDyoVFS0qWQUXdZq1evjs997nMxb968rM8MHjw4/vCHP0S/fv3ymIwdUVFR0eJ1KiKOij5b3TswekePKI610fTPtcrKSsNvAADYJRh+AwAAAAAAAAAAAAAAAAA7sb89/2b84vqH4/XXVmV9pv619+LOXzwWD9Q8HxO+NDwOOewDeUxIfX19fOpTn4onnngi6zPjxo2LioqK2GOPPfKYjB2RTqejurq6xdrA6B29Ul23ur84VRRDMn3iwaj/59rixYujoaEh9t1333xGBQCAxBUlHQAAAAAAAAAAAAAAAAAAyI/auc/F979d06bBN//q9b+viqu/PT8emP98jpPx/zzzzDMxbNiwNg2+ufDCC+Oee+4x+KaDqq2tjcbGxhZrQ6Nvq2eO3qze3NwcVVVVOc8GAAAdjeE3AAAAAAAAAAAAAAAAALATWnDvs/GrGY9GpjnTrj7NzZm4/eZHYsG9z+YoGf/PwoUL49hjj40VK1ZkfWbq1Kkxffr0KC4uzmMy2qOioqLF61REHBV9Wj0zMH
pHj2j577SysjLX0QAAoMMx/AYAAAAAAAAAAAAAAAAAdjJP//n1uPMXj+W056xbHotnnqzPac9d2axZs+Lkk0+OVat
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# SO2: density scatter of measured (real_rst) vs. predicted (pred_rst) values\n",
"scatter_out_1(real_rst['SO2'].values, pred_rst['SO2'].values, label='$SO_2\\ (\\mu g/m^3)$', name='SO2')"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.800\n",
"Mean Absolute Error(MAE): 8.128\n",
"Mean squared error(MSE): 125.403\n",
"Root Mean Squard Error(RMSE): 11.198\n",
"R_squared: 0.800\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEb8AAAyVCAYAAAChBVopAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzde5yWdZ0//teFAx4QCAtTykO2tRjbFohGmBHlAdoyvx13f2AnNRRrt8NuJVbbtoGVHbaTBpQdRKuZmkhLRB0FitQOuJ3oZGp0YCx1IsFDM3L9/tCtBjkMzHXf1w08n4/HPNz787k+7+s1N7ezNsy87qIsywAAAAAAAAAAAAAAAAAAAAAAQDMNqTsAAAAAAAAAAAAAAAAAAAAAAAB7HuU3AAAAAAAAAAAAAAAAAAAAAAA0nfIbAAAAAAAAAAAAAAAAAAAAAACaTvkNAAAAAAAAAAAAAAAAAAAAAABNp/wGAAAAAAAAAAAAAAAAAAAAAICmU34DAAAAAAAAAAAAAAAAAAAAAEDTKb8BAAAAAAAAAAAAAAAAAAAAAKDplN8AAAAAAAAAAAAAAAAAAAAAANB0ym8AAAAAAAAAAAAAAAAAAAAAAGg65TcAAAAAAAAAAAAAAAAAAAAAADSd8hsAAAAAAAAAAAAAAAAAAAAAAJpO+Q0AAAAAAAAAAAAAAAAAAAAAAE2n/AYAAAAAAAAAAAAAAAAAAAAAgKZTfgMAAAAAAAAAAAAAAAAAAAAAQNMpvwEAAAAAAAAAAAAAAAAAAAAAoOmU3wAAAAAAAAAAAAAAAAAAAAAA0HTKbwAAAAAAAAAAAAAAAAAAAAAAaDrlNwAAAAAAAAAAAAAAAAAAAAAANJ3yGwAAAAAAAAAAAAAAAAAAAAAAmk75DQAAAAAAAAAAAAAAAAAAAAAATaf8BgAAAAAAAAAAAAAAAAAAAACAplN+AwAAAAAAAAAAAAAAAAAAAABA0ym/AQAAAAAAAAAAAAAAAAAAAACg6ZTfAAAAAAAAAAAAAAAAAAAAAADQdMpvAAAAAAAAAAAAAAAAAAAAAABoOuU3AAAAAAAAAAAAAAAAAAAAAAA0nfIbAAAAAAAAAAAAAAAAAAAAAACaTvkNAAAAAAAAAAAAAAAAAAAAAABNp/wGAAAAAAAAAAAAAAAAAAAAAICmU34DAAAAAAAAAAAAAAAAAAAAAEDTKb8BAAAAAAAAAAAAAAAAAAAAAKDplN8AAAAAAAAAAAAAAAAAAEATFUVR1J0BAABagfIbAAAAAAAAAAAAAAAAAABorpl1BwAAgFZQlGVZdwYAAAAAAAAAAAAAAAAAANgjFEUxNMnvkjy5LMvuuvMAAECdhtQdAAAAAAAAAAAAAAAAAAAA9iDPSfKoJC+sOwgAANRN+Q0AAAAAAAAAAAAAAAAAADTPSx/650tqTQEAAC2gKMuy7gwAAOyEoigOTvK8zZZvSXJPDXEAAAAAAAAAAAAAAIDWt1+SIzZb+1pZluvqCLMnKopiaJLbk4xOsinJY8qy7K43FQAA1Ket7gAAAOy05yVZWHcIAAAAAAAAAAAAAABgl/aaJIvqDrEHeU4eLL5JkiFJXpjkgvriAABAvYbUHQAAAAAAAAAAAAAAAAAAAPYQL93s8UtqSQEAAC2ire4AAAAAAAAAAAAAAAAAAABQobLuAFvS29ub0aNHp6en5y9rQ4YMeVZ3d3d50EEH1Zhsm4q6AwAAsHsbUncAAAAAAAAAAAAAAAAAAADY3XV1dfUrvkmSTZs2pbOzs6ZEAABQv7a6AwAAsNNu2XxhwYIFefKTn1xHFgAAAAAAAAAAAAAAoGZ9fX15//vfn8suu2xHjj3s9xNojPb29i2ud3R0ZM6cOU1OAwAAraEoy7LuDAAA7ISiKJ6e5Ft/u/atb30rT3/602tKBAAAAAAAAAAAAAAA1OXuu+/Oy172sixdunRHj04py/L6RmSqUcv98mxvb28e/ehHp6en52F7Q4
YMyW9/+9scdNBBNSTbrqLuAAAA7N6G1B0AAAAAAAAAAAAAAAAAAICdt27dukydOnVnim9okq6uri0W3yTJpk2b0tnZ2eREAADQGpTfAAAAAAAAAAAAAAAAAADson784x9n8uTJuemmm+qOwja0t7dvc7+jo6NJSQAAoLUovwEAAAAAAAAAAAAAAAAA2AVdd911OfbYY7N27dq6o7ANvb29WbJkyTavWblyZbq7u5sTCAAAWojyGwAAAAAAAAAAAAAAAACAXcwll1ySk046KevXrx/Q9UOHDs1//ud/NjgVW9LV1ZWenp5+axMzpt/jTZs2pbOzs5mxAACgJSi/AQAAAAAAAAAAAAAAAADYRZRlmfnz52fWrFnp7e0d0JlRo0Zl2bJlOemkkxqcji1pb2/v97hIMjNPzPC09Vvv6OhoYioAAGgNym8AAAAAAAAAAAAAAAAAAHYBfX19mT17ds4999wBnzn00EOzatWqTJs2rYHJ2Jre3t4sWbKk39q4jM7oYu9MzJh+6ytXrkx3d3cT0wEAQP2U3wAAAAAAAAAAAAAAAAAAtLi77747J598chYtWjTgMxMmTMgNN9yQ8ePHNzAZ29LV1ZWenp5+a5NyYJLk6If++X82bdqUzs7OpmUDAIBWoPwGAAAAAAAAAAAAAAAAAKCFrVu3LlOnTs3SpUsHfGbGjBlZuXJlDj744AYmY3va29v7PS6SHJUxSZJxGZ3haeu339HR0axoAADQEpTfAAAAAAAAAAAAAAAAAAC0qB//+MeZPHlybrrppgGfOeOMM3LZZZdl//33b2Aytqe3tzdLlizptzYuozOyGJYkaSuGZOJDRTj/Z+XKlenu7m5WRAAAqJ3yGwAAAAAAAAAAAAAAAACAFnTdddfl2GOPzdq1awd8Zv78+VmwYEHa2toamIyB6OrqSk9PT7+1STmw3+OjN3u8adOmdHZ2NjwbAAC0CuU3AAAAAAAAAAAAAAAAAAAt5pJLLslJJ52U9evXD+j6oUOHZvHixTnnnHNSFEWD0zEQ7e3t/R4XSY7KmH5r4zI6w9O/qKijo6PR0QAAoGUovwEAAAAAAAAAAAAAAAAAaBFlWWb+/PmZNWtWent7B3Rm1KhRWbZsWWbOnNngdAxUb29vlixZ0m9tXEZnZDGs31pbMSQTNyvEWblyZbq7uxsdEQAAWoLyGwAAAAAAAAAAAAAAAACAFtDX15fZs2fn3HPPHfCZQw89NKtWrcq0adMamIwd1dXVlZ6enn5rk3LgFq89erP1TZs2pbOzs2HZAACglSi/AQAAAAAAAAAAAAAAAACo2d13352TTz45ixYtGvCZCRMm5IYbbsj48eMbmIyd0d7e3u9xkeSojNniteMyOsPT1m+to6OjUdEAAKClKL8BAAAAAAAAAAAAAAAAAKjRunXrMnXq1CxdunTAZ2bMmJGVK1fm4IMPbmAydkZvb2+WLFnSb21cRmdkMWyL17cVQzJxs2KclStXpru7u1ERAQCgZSi/AQAAAAAAAAAAAAAAAACoyY9//ONMnjw5N91004DPnHHGGbnsssuy//77NzAZO6urqys9PT391iblwG2eOXqz/U2bNqWzs7PybAAA0GqU3wAAAAAAAAAAAAAAAAAA1OC6667Lsccem7Vr1w74zPz587NgwYK0tbU1MBmD0d7e3u9xkeSojNnmmXEZneHp/2fa0dFRdTQAAGg5ym8AAAAAAAAAAAAAAAAAAJrskksuyUknnZT169cP6PqhQ4dm8eLFOeecc1IURYPTsbN6e3uzZMmSfmvjMjoji2HbPNdWDMnEzQpyVq5cme7u7qojAgBAS1F+AwAAAAAAAAAAAAAAAADQJGVZZv78+Zk1a1Z6e3sHdGbUqFFZtmxZZs6c2eB0DFZXV1d6enr6rU3KgQM6e/Rm123atCmdnZ2VZQMAgFak/AYAAAAAAAAAAAAAAAAAoAn6+voye/bsnHvuuQM+c+ihh2bVqlWZNm1aA5NRlf
b29n6PiyRHZcyAzo7L6AxPW7+1jo6OqqIBAEBLUn4DAAAAAAAAAAAAAAAAANBgd999d04++eQsWrRowGcmTJiQG26
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# NO2: density scatter of measured (real_rst) vs. predicted (pred_rst) values\n",
"scatter_out_1(real_rst['NO2'].values, pred_rst['NO2'].values, label='$NO_2\\ (\\mu g/m^3)$', name='NO2')"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.944\n",
"Mean Absolute Error(MAE): 8.725\n",
"Mean squared error(MSE): 138.808\n",
"Root Mean Squard Error(RMSE): 11.782\n",
"R_squared: 0.944\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEb8AAAyVCAYAAAChBVopAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzde5jVZbk//nsNMyAgjIOCiIqndGOkBaIhbkUUEMzTdnfYfUHNREGs3dkSrawELTvsLA9AaqZSe6amUQwYcYTBUNTCMsUTopLGeBxHOckMs35/7N92N8AMa5i11mcNvF7XNVfX57mf537eMxAFs+ZeqXQ6HQAAAAAAAAAAAAAAAAAAAAAAkE9FSQcAAAAAAAAAAAAAAAAAAAAAAGDXY/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAAAAAAB5Z/gNAAAAAAAAAAAAAAAAAADkUSqVSiWdAQAACoHhNwAAAAAAAAAAAAAAAAAAkF8Tkg4AAACFIJVOp5POAAAAAAAAAAAAAAAAAAAAu4RUKlUSEf+IiCPS6XRd0nkAACBJRUkHAAAAAAAAAAAAAAAAAACAXcjJEbFXRJyddBAAAEia4TcAAAAAAAAAAAAAAAAAAJA/n/z///MTiaYAAIACkEqn00lnAABgB6RSqX0i4rQtlldFxPoE4gAAAAAAAAAAAAAAAIWvR0QcvMXaPel0ek0SYXZFqVSqJCJejYiyiGiOiH3T6XRdsqkAACA5xUkHAABgh50WEbOSDgEAAAAAAAAAAAAAAHRqF0XE7KRD7EJOjv8ZfBMRURQRZ0fEDcnFAQCAZBUlHQAAAAAAAAAAAAAAAAAAAHYRn9zi+ROJpAAAgAJRnHQAAAAAAAAAAAAAAAAAAADIonTSAbalsbExysrKor6+/v21oqKiE+vq6tL9+/dPMFmbUkkHAABg51aUdAAAAAAAAAAAAAAAAAAAANjZ1dTUtBh8ExHR3NwclZWVCSUCAIDkFScdAACAHbZqy4WZM2fGEUcckUQWAAAAAAAAAAAAAAAgYU1NTfHDH/4w7r777vYc2+rnE8iN8vLyba5XVFTE1KlT85wGAAAKQyqdTiedAQCAHZBKpY6NiAf/ee3BBx+MY489NqFEAAAAAAAAAAAAAABAUt5999341Kc+FfPnz2/v0RHpdPqhXGRKUMH98GxjY2
PsvffeUV9fv1WtqKgoXnnllejfv38CybYrlXQAAAB2bkVJBwAAAAAAAAAAAAAAAAAAYMetWbMmRo4cuSODb8iTmpqabQ6+iYhobm6OysrKPCcCAIDCYPgNAAAAAAAAAAAAAAAAAEAn9eSTT8bw4cPjscceSzoKbSgvL2+zXlFRkackAABQWAy/AQAAAAAAAAAAAAAAAADohBYtWhTHHXdcrF69OukotKGxsTGqqqra3LNkyZKoq6vLTyAAACgght8AAAAAAAAAAAAAAAAAAHQyd955Z5xyyinR0NCQ0f6SkpL49re/neNUbEtNTU3U19e3WBsafVs8Nzc3R2VlZT5jAQBAQTD8BgAAAAAAAAAAAAAAAACgk0in0zFjxoyYOHFiNDY2ZnSmtLQ0qqur45RTTslxOralvLy8xXMqIibEYdEzilusV1RU5DEVAAAUBsNvAAAAAAAAAAAAAAAAAAA6gaamppg8eXJcfvnlGZ8ZOHBgLF26NEaNGpXDZLSmsbExqqqqWqwNirIoS3WLodG3xfqSJUuirq4uj+kAACB5ht8AAAAAAAAAAAAAAAAAABS4d999N84444yYPXt2xmeGDBkSy5Yti8GDB+cwGW2pqamJ+vr6FmvDol9ERBz9///n/2pubo7Kysq8ZQMAgEJg+A0AAAAAAAAAAAAAAAAAQAFbs2ZNjBw5MubPn5/xmfHjx8eSJUtin332yWEytqe8vLzFcyoijoq+ERExKMqiZxS3qFdUVOQrGgAAFATDbwAAAAAAAAAAAAAAAAAACtSTTz4Zw4cPj8ceeyzjMxdeeGHcfffdsfvuu+cwGdvT2NgYVVVVLdYGRVn0TnWNiIjiVFEM/f8H4fyvJUuWRF1dXb4iAgBA4gy/AQAAAAAAAAAAAAAAAAAoQIsWLYrjjjsuVq9enfGZGTNmxMyZM6O4uDiHychETU1N1NfXt1gbFv1aPB+9xXNzc3NUVlbmPBsAABQKw28AAAAAAAAAAAAAAAAAAArMnXfeGaeccko0NDRktL+kpCTuuOOOuOyyyyKVSuU4HZkoLy9v8ZyKiKOib4u1QVEWPaPloKKKiopcRwMAgIJh+A0AAAAAAAAAAAAAAAAAQIFIp9MxY8aMmDhxYjQ2NmZ0prS0NKqrq2PChAk5TkemGhsbo6qqqsXaoCiL3qmuLdaKU0UxdIuBOEuWLIm6urpcRwQAgIJg+A0AAAAAAAAAAAAAAAAAQAFoamqKyZMnx+WXX57xmYEDB8bSpUtj1KhROUxGe9XU1ER9fX2LtWHRb5t7j95ivbm5OSorK3OWDQAAConhNwAAAAAAAAAAAAAAAAAACXv33XfjjDPOiNmzZ2d8ZsiQIbFs2bIYPHhwDpOxI8rLy1s8pyLiqOi7zb2Doix6RnGLtYqKilxFAwCAgmL4DQAAAAAAAAAAAAAAAABAgtasWRMjR46M+fPnZ3xm/PjxsWTJkthnn31ymIwd0djYGFVVVS3WBkVZ9E513eb+4lRRDN1iMM6SJUuirq4uVxEBAKBgGH4DAAAAAAAAAAAAAAAAAJCQJ598MoYPHx6PPfZYxmcuvPDCuPvuu2P33XfPYTJ2VE1NTdTX17dYGxb92jxz9Bb15ubmqKyszHo2AAAoNIbfAAAAAAAAAAAAAAAAAAAkYNGiRXHcccfF6tWrMz4zY8aMmDlzZhQXF+cwGR1RXl7e4jkVEUdF3zbPDIqy6Bktf00rKiqyHQ0AAAqO4TcAAAAAAAAAAAAAAAAAAHl25513ximnnBINDQ0Z7S8pKYk77rgjLrvsskilUjlOx45qbGyMqqqqFmuDoix6p7q2ea44VRRDtxiQs2TJkqirq8t2RAAAKCiG3wAAAAAAAAAAAAAAAAAA5Ek6nY4ZM2bExIkTo7GxMaMzpaWlUV1dHRMmTMhxOjqqpqYm6uvrW6wNi34ZnT16i33Nzc1RWVmZtWwAAFCIDL8BAAAAAAAAAAAAAAAAAMiDpqammDx5clx++e
UZnxk4cGAsXbo0Ro0alcNkZEt5eXmL51REHBV9Mzo7KMqiZxS3WKuoqMhWNAAAKEiG3wAAAAAAAAAAAAAAAAAA5Ni
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# O3: density scatter of measured vs. predicted values\n",
"# .values added for consistency with the sibling calls (PM2.5/SO2/NO2 all pass ndarrays,\n",
"# and the plot helper indexes positionally, which is only reliable on ndarrays)\n",
"scatter_out_1(real_rst['O3'].values, pred_rst['O3'].values, label='$O_3 \\ (\\mu g/m^3)$', name='O3')"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"def scatter_out_2(x, y, name):\n",
"    \"\"\"Density scatter plot comparing measured (x) against retrieved (y) values.\n",
"\n",
"    Prints evaluation metrics (BIAS/EV/MAE/MSE/RMSE/R2), draws a KDE-coloured\n",
"    scatter with a 1:1 reference line and a least-squares regression line on a\n",
"    fixed 0-6 axis range, and saves the figure under ./figure/looknow/.\n",
"\n",
"    Relies on notebook-level globals imported in earlier cells: mean, np,\n",
"    stats, plt and the sklearn.metrics scorers.\n",
"    \"\"\"\n",
"    # ========== evaluation metrics ==========\n",
"    BIAS = mean(x - y)\n",
"    MSE = mean_squared_error(x, y)\n",
"    RMSE = np.power(MSE, 0.5)\n",
"    R2 = r2_score(x, y)\n",
"    MAE = mean_absolute_error(x, y)\n",
"    EV = explained_variance_score(x, y)\n",
"    print('==========算法评价指标==========')\n",
"    print('Bias(BIAS):', '%.3f' % (BIAS))  # was computed but never reported\n",
"    print('Explained Variance(EV):', '%.3f' % (EV))\n",
"    print('Mean Absolute Error(MAE):', '%.3f' % (MAE))\n",
"    print('Mean squared error(MSE):', '%.3f' % (MSE))\n",
"    print('Root Mean Squard Error(RMSE):', '%.3f' % (RMSE))\n",
"    print('R_squared:', '%.3f' % (R2))\n",
"    # =========== calculate the point density ===========\n",
"    xy = np.vstack([x, y])\n",
"    z = stats.gaussian_kde(xy)(xy)\n",
"    # =========== sort by density so the densest points are drawn last ===========\n",
"    idx = z.argsort()\n",
"    x, y, z = x[idx], y[idx], z[idx]\n",
"\n",
"    def best_fit_slope_and_intercept(xs, ys):\n",
"        # least-squares slope m and intercept b via the normal equations\n",
"        m = (((mean(xs) * mean(ys)) - mean(xs * ys)) / ((mean(xs) * mean(xs)) - mean(xs * xs)))\n",
"        b = mean(ys) - m * mean(xs)\n",
"        return m, b\n",
"\n",
"    m, b = best_fit_slope_and_intercept(x, y)\n",
"    regression_line = [(m * a) + b for a in x]\n",
"\n",
"    fig, ax = plt.subplots(figsize=(12, 9), dpi=400)\n",
"    scatter = ax.scatter(x, y, marker='o', c=z * 100, s=15, label='LST', cmap='Spectral_r')\n",
"    plt.colorbar(scatter, shrink=1, orientation='vertical', extend='both', pad=0.015, aspect=30, label='frequency')\n",
"\n",
"    plt.plot([0, 6], [0, 6], 'black', lw=1.5)  # 1:1 reference line (black, width 1.5)\n",
"    plt.plot(x, regression_line, 'red', lw=1.5)  # regression line: measured vs. retrieved\n",
"    plt.axis([0, 6, 0, 6])  # fixed axis range; adjust if the data exceed 0-6\n",
"    plt.xlabel(f'Measured {name}')\n",
"    plt.ylabel(f'Retrived {name}')  # NOTE(review): 'Retrived' typo kept for consistency with scatter_out_1\n",
"\n",
"    plt.text(0.3, 5.5, '$N=%.f$' % len(y))  # annotation positions tuned to the 0-6 range\n",
"    plt.text(0.3, 5.0, '$R^2=%.2f$' % R2)\n",
"    plt.text(0.3, 4.5, '$RMSE=%.2f$' % RMSE)\n",
"    plt.xlim(0, 6)\n",
"    plt.ylim(0, 6)\n",
"    # filesystem-safe stem from the label, e.g. '$CO \\ (mg/m^3)$' -> 'CO'\n",
"    # (the original hard-coded 'CO.png', so every call overwrote the same file)\n",
"    file_name = ''.join(ch for ch in name.split('(')[0] if ch.isalnum()) or 'scatter'\n",
"    plt.savefig(f'./figure/looknow/{file_name}.png', dpi=800, bbox_inches='tight', pad_inches=0)\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.810\n",
"Mean Absolute Error(MAE): 0.227\n",
"Mean squared error(MSE): 0.095\n",
"Root Mean Squard Error(RMSE): 0.309\n",
"R_squared: 0.810\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEKQAAAzZCAYAAADHXzXhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOz9e5yXdZ0//j8ucFGcScmKAgXTDrspFuuiu0Nr5qBhtfbzW+AeUCrpaAcNt7a20JVt+1SWlG6pKYWdG6TFLUCztEIJSqcjU5EyaqcVRbN0lHS8fn/oto4M8Ia53nPNDPf77fa+tTyv6/W8Hm9iJxHejynKsgwAAAAAAAAAAAAAAAAAAAAA/K9RdQcAAAAAAAAAAAAAAAAAAAAAYGhRSAEAAAAAAAAAAAAAAAAAAABAHwopAAAAAAAAAAAAAAAAAAAAAOhDIQUAAAAAAAAAAAAAAAAAAAAAfSikAAAAAAAAAAAAAAAAAAAAAKAPhRQAAAAAAAAAAAAAAAAAAAAA9KGQAgAAAAAAAAAAAAAAAAAAAIA+FFIAAAAAAAAAAAAAAAAAAAAA0IdCCgAAAAAAAAAAAAAAAAAAAAD6UEgBAAAAAAAAAAAAAAAAAAAAQB8KKQAAAAAAAAAAAAAAAAAAAADoQyEFAAAAAAAAAAAAAAAAAAAAAH0opAAAAAAAAAAAAAAAAAAAAACgD4UUAAAAAAAAAAAAAAAAAAAAAPShkAIAAAAAAAAAAAAAAAAAAACAPhRSAAAAAAAAAAAAAAAAAAAAANCHQgoAAAAAAAAAAAAAAAAAAAAA+lBIAQAAAAAAAAAAAAAAAAAAAEAfCikAAAAAAAAAAAAAAAAAAAAA6EMhBQAAAAAAAAAAAAAAAAAAAAB9KKQAAAAAAAAAAAAAAAAAAAAAoA+FFAAAAAAAAAAAAAAAAAAAAAD0oZACAAAAAAAAAAAAAAAAAAAAgD4UUgAAAAAAAAAAAAAAAAAAAADQh0IKAAAAAAAAAAAAAAAAAAAAAPpQSAEAAAAAAAAAAAAAAAAAAABAHwopAAAAAAAAAAAAAAAAAAAAAOhDIQUAAAAAAAAAAAAAAAAAAAAAfSikAAAAAAAAAAAAAAAAAAAAAKAPhRQAAAAAAAAAAAAAAAAAAAAA9KGQAgAAAAAAAAAAAAAAAAAAAIA+FFIAAAAAAAAAAAAAAAAAAAAA0IdCCgAAAAAAAAAAAAAAAAAAAAD6UEgBAAAAAAAAAAAAAAAAAAAAQB8KKQAAAAAAAAAAAAAAAAAAAADoQyEFAAAAAAAAAAAAAAAAAAAAAH0opAAAAAAAAAAAAAAAAAAAAACgD4UUAAAAAAAAAAAAAAAAAAAAAPShkAIAAAAAAAAAAAAAAAAAAACAPhRSAAAAAAAAAAAAAAAAAAAAANCHQgoAAAAAAAAAAAAAAAAAAAAA+lBIAQAAAAAAAAAAAAAAAAAAAEAfCikAAAAAAAAAAAAAAAAAAAAA6GOPugMADFRRFC1JpiQ5JMlBSZ6eZGKSJyXZL8k+ScYk2TNJkWTLo6/7k9yVZPOjr18lueXR14YkPy/L8sFBeyMAAAAAAAAAAAAAAAAAAABDRFGWZd0ZAHZKURQHJzkmyVFJpid5Zh4pmqjag3mkmOJHSdYlWZvk+2VZ/rEJzwIAAAAAAAAAAAAAAAAAABgyFFIAw0JRFNOSnJTk75I8p8YoW5J8J8k3Hn19ryzLh2rMAwAAAAAAAAAAAAAAAAAAUDmFFMCQVRTFk5O8Jskrk/xFzXG25XdJViX57yRXlmX5u1rTAAAAAAAAAAAAAAAAAAAAVEAhBTDkFEXx50nenmROkr1qjrMzHkzy9SRfSvJfZVn+vuY8AAAAAAAAAAAAAAAAAAAAu0QhBTBkFEXxjCRnJ/mnJKNrjjNQW5KcWpbl5+sOAgAAAAAAAAAAAAAAAAAAsLP2qDsAQFEUrUnek+RtScbUHKcqeyaZWHcIAAAAAAAAAAAAAA
AAAACAXaGQAqhVURR/l+TiKG8AAAAAAAAAAAAAAAAAAAAYMkbVHQDYPRVFsU9RFJcl+UqUUQAAAAAAAAAAAAAAAAAAAAwpe9QdANj9FEUxNcnSJM+sOQoAAAAAAAAAAAAAAAAAAAD9GFV3AGD3UhTFKUm+E2UUAAAAAAAAAAAAAAAAAAAAQ5ZCCmDQFEVxTpJPJ9mr7iwAAAAAAAAAAAAAAAAAAABs2x51BwBGvqIoRiW5NMmr684CAAAAAAAAAAAAAAAAAADAjimkAJqqKIo9knw2yd/XnQUAAAAAAAAAAAAAAAAAAIDGjKo7ADByFUUxKsooAAAAAAAAAAAAAAAAAAAAhp096g4AjGgXZnDLKP6Y5IYkP01y86OvjUnuSXLfo697k4xOstejr32SPC3JU5NMTPLsJH+R5M+TTBrE7AAAAAAAAAAAAAAAAAAAAEOGQgqgKYqi+LckrxuER92Y5KtJvpVkbVmW9zdwpjePlFf8PsmmJDf1d1NRFE9J8tdJjkzyt0men2RMBZkBAAAAAAAAAAAAAAAAAACGNIUUQOWKovj7JGc38RG/SfK5JJeVZbm+WQ8py/KOPFJ28dUkKYpi7yQvTDIzyf+XZFKzng0AAAAAAAAAAAAAAAAAAFCnoizLujMAI0hRFM9LsjbJXk1Y/9sk70vyibIs/9iE/Q0riqJI8tdJZiX5pyQT+rnt7WVZfmhQgwEAAAAAAAAAAAAAAAAAAFRgVN0BgJGjKIqxSb6Q6sso7kvy9iQHl2X5n3WXUSRJ+Yi1ZVn+c5JJSV6W5IokD9WbDAAAAAAAAAAAAAAAAAAAYOAUUgBV+nCS51S8c02SqWVZfqgsywcq3l2Jsix7y7L8SlmWJyY5OMm5SX5XaygAAAAAAAAAAAAAAAAAAIABKMqyrDsDMAIURXFCkv+ucOXDSRYkeX9Zlg9XuHdQFEXRmuSpZVneXHcWAAAAAAAAAAAAAAAAAACAnaWQAhiwoiieluTHSZ5c0coHkpxSluXlFe0DAAAAAAAAAAAAAAAAAABgJ+xRdwBgRPhYqiujuDvJy8qyvK6ifQAAAAAAAAAAAAAAAAAAAOykoizLujMAw1hRFC9Mcm1F63qStJdlua6ifQAAAAAAAAAAAAAAAAAAAOyCUXUHAIavoihGJflIRet6k/yDMgoAAAAAAAAAAAAAAAAAAID6KaQABmJekudVtOutZVl+paJdAAAAAAAAAAAAAAAAAAAADEBRlmXdGYBhqCiKfZL8Isn4CtZ9uSzLV1SwBwAAAAAAAAAAAAAAAAAAgAqMqjsAMGz9c6opo/htktdVsAcAAAAAAAAAAAAAAAAAAICKFGVZ1p0BGGaKotg7yW1JnlTBuheXZXllBXsAAAAAAAAAAAAAAAAAAACoyKi6AwDD0qtTTRnFFcooAAAAAAAAAAAAAAAAAAAAhp6iLMu6MwDDSFEUo5L8PMkzB7jqoSSHlmW5YeCpAAAAAAAAAAAAAAAAAAAAqNKougMAw86JGXgZRZJcpIwCAAAAAAAAAAAAAAAAAABgaCrKsqw7AzCMFEVxfZLpA1yzJcmBZVneXkEkAAAAAAAAAAAAAAAAAAAAKjaq7gDA8FEUxZ9n4GUUSfIFZRQAAAAAAAAAAAAAAAAAAABDl0IKYGf8Y0V7PlLRHgAAAAAAAAAAAAAAAAAAAJpAIQWwM/6+gh3fLMvyhxXsAQAAAAAAAAAAAAAAAAAAoEkUUgANKYpiapK/qGDVJRXsAAAAAAAAAAAAAAAAAAAAoIkUUgCN+ocKdjyQ5CsV7AEAAAAAAAAAAAAAAAAAAKCJFFIAjTqpgh0ry7L8QwV7AAAAAAAAAAAAAAAAAAAAaCKFFMAOFUXx50kOqmDVlyrYAQAAAAAAAAAAAAAAAAAAQJMppAAacUwFOx5MsqKCPQAAAAAAAAAAAAAAAAAAADSZQgqgES+sYMfasizvq2APAAAAAAAAAAAAAAAAAAAATaaQAmjECyvYcU0FOw
AAAAAAAAAAAAAAAAAAABgECimA7SqK4jlJnlrBqm9UsAMAAAAAAAAAAAAAAAAAAIBBoJAC2JEXVrCjJ8m6CvYAAAA
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# CO: density scatter of measured (real_rst) vs. predicted (pred_rst) values;\n",
"# note scatter_out_2 takes the full LaTeX label as `name` (unlike scatter_out_1's label/name split)\n",
"scatter_out_2(real_rst['CO'].values, pred_rst['CO'].values, name='$CO \\ (mg/m^3)$')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py37",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "e35e91facd2b4cfa08991d112893a00c4d14d1c91c990d1b62f3056d14d2f283"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}