22-T67/keras_multi-attention_multi...

1646 lines
6.3 MiB
Plaintext
Raw Normal View History

2023-03-30 10:25:44 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Enumerate GPUs in PCI bus order so device indices match `nvidia-smi`.\n",
"# Valid values are 'PCI_BUS_ID' and 'FASTEST_FIRST' ('PCB_BUS_ID' was a typo).\n",
"os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'\n",
"# Expose only GPU 0; only takes effect if set before CUDA is initialised.\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Use a CJK-capable font so Chinese text in figures renders correctly\n",
"# (matplotlib is imported directly rather than through the discouraged `pylab`).\n",
"mpl.rcParams['font.sans-serif'] = ['SimHei']\n",
"# CJK fonts such as SimHei often lack the Unicode minus glyph; use ASCII '-'.\n",
"mpl.rcParams['axes.unicode_minus'] = False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>24_PM2.5</th>\n",
" <th>24_PM10</th>\n",
" <th>24_SO2</th>\n",
" <th>24_NO2</th>\n",
" <th>24_O3</th>\n",
" <th>24_CO</th>\n",
" <th>23_PM2.5</th>\n",
" <th>23_PM10</th>\n",
" <th>23_SO2</th>\n",
" <th>23_NO2</th>\n",
" <th>...</th>\n",
" <th>NH3_resdient</th>\n",
" <th>NH3_agricultural</th>\n",
" <th>VOC_industrial</th>\n",
" <th>VOC_transportation</th>\n",
" <th>VOC_resdient</th>\n",
" <th>VOC_power</th>\n",
" <th>PM2.5_industrial</th>\n",
" <th>PM2.5_transportation</th>\n",
" <th>PM2.5_resdient</th>\n",
" <th>PM2.5_power</th>\n",
" </tr>\n",
" <tr>\n",
" <th>date</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2015-01-03 01:00:00</th>\n",
" <td>136.0</td>\n",
" <td>214.0</td>\n",
" <td>317.0</td>\n",
" <td>38.0</td>\n",
" <td>8.0</td>\n",
" <td>3.71</td>\n",
" <td>114.0</td>\n",
" <td>176.0</td>\n",
" <td>305.0</td>\n",
" <td>38.0</td>\n",
" <td>...</td>\n",
" <td>0.033910</td>\n",
" <td>0.359273</td>\n",
" <td>1.177423</td>\n",
" <td>1.084925</td>\n",
" <td>0.937173</td>\n",
" <td>0.037724</td>\n",
" <td>0.926851</td>\n",
" <td>0.077715</td>\n",
" <td>0.827110</td>\n",
" <td>0.436028</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2015-01-03 02:00:00</th>\n",
" <td>114.0</td>\n",
" <td>176.0</td>\n",
" <td>305.0</td>\n",
" <td>38.0</td>\n",
" <td>8.0</td>\n",
" <td>3.55</td>\n",
" <td>97.0</td>\n",
" <td>154.0</td>\n",
" <td>306.0</td>\n",
" <td>37.0</td>\n",
" <td>...</td>\n",
" <td>0.033910</td>\n",
" <td>0.359273</td>\n",
" <td>1.177423</td>\n",
" <td>1.134240</td>\n",
" <td>0.937173</td>\n",
" <td>0.036215</td>\n",
" <td>0.926851</td>\n",
" <td>0.081248</td>\n",
" <td>0.827110</td>\n",
" <td>0.418587</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2015-01-03 03:00:00</th>\n",
" <td>97.0</td>\n",
" <td>154.0</td>\n",
" <td>306.0</td>\n",
" <td>37.0</td>\n",
" <td>7.0</td>\n",
" <td>3.51</td>\n",
" <td>87.0</td>\n",
" <td>141.0</td>\n",
" <td>316.0</td>\n",
" <td>38.0</td>\n",
" <td>...</td>\n",
" <td>0.033910</td>\n",
" <td>0.327791</td>\n",
" <td>1.177423</td>\n",
" <td>1.232869</td>\n",
" <td>0.937173</td>\n",
" <td>0.035712</td>\n",
" <td>0.926851</td>\n",
" <td>0.088313</td>\n",
" <td>0.827110</td>\n",
" <td>0.412773</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2015-01-03 04:00:00</th>\n",
" <td>87.0</td>\n",
" <td>141.0</td>\n",
" <td>316.0</td>\n",
" <td>38.0</td>\n",
" <td>7.0</td>\n",
" <td>3.55</td>\n",
" <td>85.0</td>\n",
" <td>139.0</td>\n",
" <td>292.0</td>\n",
" <td>37.0</td>\n",
" <td>...</td>\n",
" <td>0.033910</td>\n",
" <td>0.350014</td>\n",
" <td>1.177423</td>\n",
" <td>1.273965</td>\n",
" <td>0.937173</td>\n",
" <td>0.036718</td>\n",
" <td>0.926851</td>\n",
" <td>0.091256</td>\n",
" <td>0.827110</td>\n",
" <td>0.424400</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2015-01-03 05:00:00</th>\n",
" <td>85.0</td>\n",
" <td>139.0</td>\n",
" <td>292.0</td>\n",
" <td>37.0</td>\n",
" <td>7.0</td>\n",
" <td>3.62</td>\n",
" <td>106.0</td>\n",
" <td>167.0</td>\n",
" <td>316.0</td>\n",
" <td>37.0</td>\n",
" <td>...</td>\n",
" <td>0.071588</td>\n",
" <td>0.388904</td>\n",
" <td>1.177423</td>\n",
" <td>1.290403</td>\n",
" <td>1.978475</td>\n",
" <td>0.039736</td>\n",
" <td>0.926851</td>\n",
" <td>0.092434</td>\n",
" <td>1.746121</td>\n",
" <td>0.459282</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 187 columns</p>\n",
"</div>"
],
"text/plain": [
" 24_PM2.5 24_PM10 24_SO2 24_NO2 24_O3 24_CO \\\n",
"date \n",
"2015-01-03 01:00:00 136.0 214.0 317.0 38.0 8.0 3.71 \n",
"2015-01-03 02:00:00 114.0 176.0 305.0 38.0 8.0 3.55 \n",
"2015-01-03 03:00:00 97.0 154.0 306.0 37.0 7.0 3.51 \n",
"2015-01-03 04:00:00 87.0 141.0 316.0 38.0 7.0 3.55 \n",
"2015-01-03 05:00:00 85.0 139.0 292.0 37.0 7.0 3.62 \n",
"\n",
" 23_PM2.5 23_PM10 23_SO2 23_NO2 ... NH3_resdient \\\n",
"date ... \n",
"2015-01-03 01:00:00 114.0 176.0 305.0 38.0 ... 0.033910 \n",
"2015-01-03 02:00:00 97.0 154.0 306.0 37.0 ... 0.033910 \n",
"2015-01-03 03:00:00 87.0 141.0 316.0 38.0 ... 0.033910 \n",
"2015-01-03 04:00:00 85.0 139.0 292.0 37.0 ... 0.033910 \n",
"2015-01-03 05:00:00 106.0 167.0 316.0 37.0 ... 0.071588 \n",
"\n",
" NH3_agricultural VOC_industrial VOC_transportation \\\n",
"date \n",
"2015-01-03 01:00:00 0.359273 1.177423 1.084925 \n",
"2015-01-03 02:00:00 0.359273 1.177423 1.134240 \n",
"2015-01-03 03:00:00 0.327791 1.177423 1.232869 \n",
"2015-01-03 04:00:00 0.350014 1.177423 1.273965 \n",
"2015-01-03 05:00:00 0.388904 1.177423 1.290403 \n",
"\n",
" VOC_resdient VOC_power PM2.5_industrial \\\n",
"date \n",
"2015-01-03 01:00:00 0.937173 0.037724 0.926851 \n",
"2015-01-03 02:00:00 0.937173 0.036215 0.926851 \n",
"2015-01-03 03:00:00 0.937173 0.035712 0.926851 \n",
"2015-01-03 04:00:00 0.937173 0.036718 0.926851 \n",
"2015-01-03 05:00:00 1.978475 0.039736 0.926851 \n",
"\n",
" PM2.5_transportation PM2.5_resdient PM2.5_power \n",
"date \n",
"2015-01-03 01:00:00 0.077715 0.827110 0.436028 \n",
"2015-01-03 02:00:00 0.081248 0.827110 0.418587 \n",
"2015-01-03 03:00:00 0.088313 0.827110 0.412773 \n",
"2015-01-03 04:00:00 0.091256 0.827110 0.424400 \n",
"2015-01-03 05:00:00 0.092434 1.746121 0.459282 \n",
"\n",
"[5 rows x 187 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the prepared training table, using its 'date' column as the index.\n",
"data = pd.read_csv(\n",
"    './new_train_data.csv',\n",
"    index_col='date',\n",
")\n",
"# Wind direction ('wd') is not dropped yet: how to encode it is still undecided.\n",
"# data.drop(columns=['wd'], inplace=True)\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(181, 6)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The six pollutant targets; every remaining column is an input feature.\n",
"out_cols = ['PM2.5', 'PM10', 'SO2', 'NO2', 'O3', 'CO']\n",
"# 'date' is already the index, so the first check is only a safeguard.\n",
"feature_cols = [\n",
"    name for name in data.columns\n",
"    if name != 'date' and name not in out_cols\n",
"]\n",
"len(feature_cols), len(out_cols)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-03-30 08:54:01.213692: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import tensorflow.keras.backend as K"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class TransformerBlock(layers.Layer):\n",
"    \"\"\"Transformer encoder block (post-LayerNorm) assembled from Keras layers.\n",
"\n",
"    Multi-head self-attention and a two-layer feed-forward network (FFN) are\n",
"    each followed by dropout, a residual connection and LayerNormalization.\n",
"\n",
"    Args:\n",
"        embed_dim: Attention ``key_dim`` and output width of the FFN. Because\n",
"            the FFN output is added back to its input, this must equal the\n",
"            size of the last axis of the block's input.\n",
"        num_heads: Number of attention heads.\n",
"        ff_dim: Hidden width of the FFN.\n",
"        name: Name given to the inner MultiHeadAttention layer (not to this block).\n",
"        rate: Dropout rate applied after attention and after the FFN.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, embed_dim, num_heads, ff_dim, name, rate=0.1):\n",
"        super().__init__()\n",
"        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, name=name)\n",
"        self.ffn = keras.Sequential(\n",
"            [layers.Dense(ff_dim, activation='relu'), layers.Dense(embed_dim)]\n",
"        )\n",
"        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
"        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
"        self.dropout1 = layers.Dropout(rate)\n",
"        self.dropout2 = layers.Dropout(rate)\n",
"\n",
"    def call(self, inputs, training=None):\n",
"        \"\"\"Apply the block; the output has the same shape as ``inputs``.\n",
"\n",
"        ``training`` is forwarded to the Dropout layers. It defaults to ``None``\n",
"        (the Keras convention) so Keras can supply the current learning phase.\n",
"        \"\"\"\n",
"        # Self-attention: the input serves as both query and value.\n",
"        attn_output = self.att(inputs, inputs)\n",
"        attn_output = self.dropout1(attn_output, training=training)\n",
"        out1 = self.layernorm1(inputs + attn_output)\n",
"        ffn_output = self.ffn(out1)\n",
"        ffn_output = self.dropout2(ffn_output, training=training)\n",
"        return self.layernorm2(out1 + ffn_output)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import Model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.initializers import Constant"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Custom loss layer\n",
"class CustomMultiLossLayer(layers.Layer):\n",
"    \"\"\"Combine several regression losses using learnable per-task log-variances.\n",
"\n",
"    Called on a list holding the ``nb_outputs`` target tensors followed by the\n",
"    ``nb_outputs`` prediction tensors. The combined loss is registered with\n",
"    ``add_loss``; the returned concatenation exists only so the layer has an\n",
"    output and is not meant to be used.\n",
"\n",
"    Args:\n",
"        nb_outputs: Number of (target, prediction) pairs.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, nb_outputs=2, **kwargs):\n",
"        self.nb_outputs = nb_outputs\n",
"        self.is_placeholder = True\n",
"        super(CustomMultiLossLayer, self).__init__(**kwargs)\n",
"\n",
"    def build(self, input_shape=None):\n",
"        # One trainable log-variance of shape (1,) per task, created in task order.\n",
"        self.log_vars = [\n",
"            self.add_weight(name=f'log_var{idx}', shape=(1,),\n",
"                            initializer=tf.initializers.he_normal(), trainable=True)\n",
"            for idx in range(self.nb_outputs)\n",
"        ]\n",
"        super(CustomMultiLossLayer, self).build(input_shape)\n",
"\n",
"    def multi_loss(self, ys_true, ys_pred):\n",
"        \"\"\"Return the batch mean of the per-sample sum of weighted task losses.\"\"\"\n",
"        assert len(ys_true) == self.nb_outputs and len(ys_pred) == self.nb_outputs\n",
"        total = 0\n",
"        for target, prediction, log_var in zip(ys_true, ys_pred, self.log_vars):\n",
"            squared_err = (target - prediction) ** 2.\n",
"            precision = K.exp(-log_var[0])\n",
"            # NOTE(review): Kendall et al. (2018) weighting is\n",
"            # precision * mse + log_var. For (batch, 1) tensors the logsumexp\n",
"            # over the last axis is an identity, while tf.abs flips the sign\n",
"            # (and the gradient) whenever log_var makes the term negative.\n",
"            # Kept unchanged to preserve training behaviour; confirm intent.\n",
"            total += tf.abs(tf.reduce_logsumexp(precision * squared_err + log_var[0], axis=-1))\n",
"        return K.mean(total)\n",
"\n",
"    def call(self, inputs):\n",
"        # First half of the list are targets, second half are predictions.\n",
"        targets = inputs[:self.nb_outputs]\n",
"        predictions = inputs[self.nb_outputs:]\n",
"        self.add_loss(self.multi_loss(targets, predictions), inputs=inputs)\n",
"        # The output itself is unused; it only gives the layer a tensor to return.\n",
"        return K.concatenate(inputs, -1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Shared hyper-parameters for the TransformerBlocks: attention heads and FFN width.\n",
"num_heads = 3\n",
"ff_dim = 16"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def get_prediction_model(output_names=('pm25', 'pm10', 'so2', 'no2', 'o3', 'co')):\n",
"    \"\"\"Build the multi-output inference model.\n",
"\n",
"    Conv1D -> BiLSTM -> Dense -> shared TransformerBlock, then one\n",
"    TransformerBlock branch per output, a TransformerBlock over the\n",
"    concatenated branch features, and one sigmoid regression head per output\n",
"    (targets are min-max scaled to [0, 1]).\n",
"\n",
"    Reads the notebook globals ``feature_cols``, ``num_heads`` and ``ff_dim``.\n",
"    The final 'last_attn' block also uses ``num_heads``/``ff_dim`` instead of a\n",
"    hard-coded copy of their values, so changing them affects every block.\n",
"\n",
"    Args:\n",
"        output_names: Output-layer names, in output order. The default keeps\n",
"            the original six pollutant heads.\n",
"\n",
"    Returns:\n",
"        A ``Model`` taking (batch, 1, len(feature_cols)) and returning one\n",
"        (batch, 1) tensor per entry of ``output_names``.\n",
"    \"\"\"\n",
"    branch_units = 32  # width of each per-output branch and head\n",
"\n",
"    def build_output(out, out_name):\n",
"        # Per-output branch: attention over the shared features, pooled to a vector.\n",
"        self_block = TransformerBlock(64, num_heads, ff_dim, name=f'{out_name}_attn')\n",
"        out = self_block(out)\n",
"        out = layers.GlobalAveragePooling1D()(out)\n",
"        out = layers.Dropout(0.1)(out)\n",
"        return layers.Dense(branch_units, activation='relu')(out)\n",
"\n",
"    inputs = layers.Input(shape=(1, len(feature_cols)), name='input')\n",
"    x = layers.Conv1D(filters=64, kernel_size=1, activation='relu')(inputs)\n",
"    lstm_out = layers.Bidirectional(layers.LSTM(units=64, return_sequences=True))(x)\n",
"    lstm_out = layers.Dense(128, activation='relu')(lstm_out)\n",
"    out = TransformerBlock(128, num_heads, ff_dim, name='first_attn')(lstm_out)\n",
"    out = layers.GlobalAveragePooling1D()(out)\n",
"    out = layers.Dropout(0.1)(out)\n",
"    out = layers.Dense(64, activation='relu')(out)\n",
"    # Re-insert a length-1 sequence axis so the branch blocks receive 3-D input.\n",
"    out = K.expand_dims(out, axis=1)\n",
"\n",
"    branches = [build_output(out, name) for name in output_names]\n",
"\n",
"    # Concatenate needs at least two tensors, so skip it for a single output.\n",
"    merge = branches[0] if len(branches) == 1 else layers.Concatenate(axis=1)(branches)\n",
"    merge = K.expand_dims(merge, axis=1)\n",
"    merge_attn = TransformerBlock(branch_units * len(output_names), num_heads, ff_dim, name='last_attn')\n",
"    out = merge_attn(merge)\n",
"    out = layers.GlobalAveragePooling1D()(out)\n",
"    out = layers.Dropout(0.1)(out)\n",
"\n",
"    # Per-output heads. All hidden Dense layers are created before the Dense(1)\n",
"    # output layers, matching the original construction order.\n",
"    hidden = [layers.Dense(branch_units, activation='relu')(out) for _ in output_names]\n",
"    outputs = [\n",
"        layers.Dense(1, activation='sigmoid', name=name)(h)\n",
"        for name, h in zip(output_names, hidden)\n",
"    ]\n",
"    return Model(inputs=[inputs], outputs=outputs)\n",
"\n",
"\n",
"def get_trainable_model(prediction_model, output_names=('pm25', 'pm10', 'so2', 'no2', 'o3', 'co')):\n",
"    \"\"\"Wrap ``prediction_model`` in a training-only model whose loss is a layer.\n",
"\n",
"    Inputs are the features plus one ``{name}_real`` target input per entry of\n",
"    ``output_names``. The output is that of ``CustomMultiLossLayer``, which\n",
"    registers the multi-task loss itself (so compile with ``loss=None``).\n",
"\n",
"    Raises:\n",
"        ValueError: if ``prediction_model`` does not return one tensor per name.\n",
"    \"\"\"\n",
"    inputs = layers.Input(shape=(1, len(feature_cols)), name='input')\n",
"    preds = prediction_model(inputs)\n",
"    if not isinstance(preds, (list, tuple)):\n",
"        preds = [preds]  # single-output Keras models return a bare tensor\n",
"    if len(preds) != len(output_names):\n",
"        raise ValueError(f'expected {len(output_names)} outputs, got {len(preds)}')\n",
"    reals = [layers.Input(shape=(1,), name=f'{name}_real') for name in output_names]\n",
"    out = CustomMultiLossLayer(nb_outputs=len(output_names))([*reals, *preds])\n",
"    return Model([inputs, *reals], out)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Restrict to feature + target columns and drop rows with any missing value.\n",
"use_cols = feature_cols + out_cols\n",
"use_data = data.loc[:, use_cols].dropna()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Cast every model column to float32, TensorFlow's default float dtype.\n",
"use_data = use_data.astype({col: 'float32' for col in use_cols})"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Per-column statistics for min-max scaling.\n",
"# NOTE(review): computed on the full dataset before the train/valid/test split,\n",
"# so held-out rows influence the scaling (a mild form of leakage).\n",
"mins, maxs = use_data.min(), use_data.max()\n",
"# From here on use_cols is a pandas Index rather than a list.\n",
"use_cols = use_data.columns"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Min-max scale every model column to [0, 1] using the statistics above.\n",
"# A constant column (max == min) would divide 0 by 0 and become NaN, which\n",
"# propagates into the loss; give such columns a unit span so they scale to 0.\n",
"# NOTE: not idempotent -- re-running this cell rescales already-scaled data.\n",
"col_span = (maxs - mins).replace(0, 1)\n",
"use_data = (use_data[use_cols] - mins[use_cols]) / col_span[use_cols]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# 90% train; the remaining 10% is split evenly into validation and test sets.\n",
"# NOTE(review): rows appear to be consecutive hours with lagged columns\n",
"# (e.g. 23_PM2.5 at one hour equals 24_PM2.5 at the next), so a shuffled split\n",
"# puts near-duplicate neighbours in different sets and can inflate held-out\n",
"# scores. Consider a chronological split.\n",
"train_data, valid = train_test_split(\n",
"    use_data[use_cols], test_size=0.1, shuffle=True, random_state=42\n",
")\n",
"valid_data, test_data = train_test_split(valid, test_size=0.5, shuffle=True, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-03-30 08:54:16.887997: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1\n",
"2023-03-30 08:54:16.964653: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal\n",
"2023-03-30 08:54:16.964687: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: ubuntu-NF5468M6\n",
"2023-03-30 08:54:16.964693: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: ubuntu-NF5468M6\n",
"2023-03-30 08:54:16.964801: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 510.47.3\n",
"2023-03-30 08:54:16.964823: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 510.47.3\n",
"2023-03-30 08:54:16.964828: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 510.47.3\n",
"2023-03-30 08:54:16.965147: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"# Inference model, plus a training wrapper that adds target inputs and the loss layer.\n",
"prediction_model = get_prediction_model()\n",
"trainable_model = get_trainable_model(prediction_model=prediction_model)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import optimizers\n",
"# Public tf.keras API; `tensorflow.python.*` is private and, in newer TensorFlow\n",
"# releases, may not match the Keras implementation that built the model.\n",
"from tensorflow.keras.utils import plot_model"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAADRsAAAd1CAYAAACxGen2AAAABmJLR0QA/wD/AP+gvaeTAAAgAElEQVR4nOzdebiVVaE/8O97OIdZAY1BLjgkKGlizhzMRMXpSo7MMnlzKCsFQrPSm9c0GyDAyqFrhaUGQob3ZtLVMk2UckjrZ5pKUaGIIOGsHDj79wcbBEQFhLNBPp/n2c8j613v2t+1ePxrP19WUSqVSgEAAAAAAAAAAAAAAAC2dlOrKp0AAAAAAAAAAAAAAAAA2DwoGwEAAAAAAAAAAAAAAABJlI0AAAAAAAAAAAAAAACAMmUjAAAAAAAAAAAAAAAAIElSXekAAAAAAABA5fXr16/SEaDB1dbWZvTo0ZWOAQAAAAAAsFlRNgIAAAAAADJt2rT06NEjnTp1qnQUaBCzZs2qdAQAAAAAAIDNkrIRAAAAAACQJBk1alT69+9f6RjQINzmBQAAAAAAsHZVlQ4AAAAAAAAAAAAAAAAAbB6UjQAAAAAAAAAAAAAAAIAkykYAAAAAAAAAAAAAAABAmbIRAAAAAAAAAAAAAAAAkETZCAAAAAAAAAAAAAAAAChTNgIAAAAAAAAAAAAAAACSKBsBAAAAAAAAAAAAAAAAZcpGAAAAAAAAAAAAAAAAQBJlIwAAAAAAAAAAAAAAAKBM2QgAAAAAAAAAAAAAAABIomwEAAAAAAAAAAAAAAAAlCkbAQAAAAAAAAAAAAAAAEmUjQAAAAAAAAAAAAAAAIAyZSMAAAAAAIB1MiMjWhYpitU/PcbOqXSwDfB+2gsAAAAAAAAbk7IRAAAAAACw/l6+PZ/uum069Lk2/6h0lgZzTCa9XErpD1/Krklywo9TVypl1pidK5xrQ7yf9gIAAAAAAMDGpGwEAAAAAACsv1J96utLKdXXp1TpLO+ofIPPRyfk2UpHaXBb894BAAAAAADYUNWVDgAAAAAAAGyBtjk6V81+KVdVOgcAAAAAAACwUbnZCAAAAAAAAAAAAAAAAEiibAQAAAAAAKyv6UNSXRQpiiJF0SfXv7728R/N+VUuG3BgdmzdLM233yl79zkv02cvWbnMnLE9ynOLFJ1G5rb7r8yZR+yRDts0TePm7dLtsE/kipkLVs5//NKPvDn/oxPy7IoHM05PyxXjH/hkfrPa+sfmuleSzByVHVbMqR6Y6Zv4TN6Xe1+6IA9OuTQjjtw3XTq0TJNmrdNpr6Ny5sTfZkF9ec7ia9N75Tks/3zk0sfL70/Oias+6zt55dKlBbNy1TnHZ7+dt0/zxk3Sou2uOejkz+W6h1942zP+4V9uzZf7H5Rdtm9eHu+Vqxdu6OYAAAAAAABYQdkIAAAAAABYPyden6Wluvz4hHcafzzjR16fDiN/mj89/WwenTI0zX49Nv0Hjc9T5ek7j5mVUunJfGXvJIun5DPnPpiel/4ijz07P0/d/a30WjQl5x7+sXzurpeSJN0ufDil0m0Z3mKN7z3m2rxc+lsu32/14eXrl+cfPD7zSqWUSqWUlk7OiStnPZYJh3dIi+33z3/NWpINtkXufT3NOC8HD7wi8w//RmY89lwW/fOB/OjMbXP76MPzsc/PzOtJ0vr03FFakP8+ullStW++8VQpD1/Ybfn71QMzvfT3jK1tlZNuWJTStIHLx+fdnCEHHJKzb3olJ185M39ftDBP/WZCjl7044yo/VguvO/VtZzxQ7nkrB+l/dk35oF/Lsozsy7K/o02dGMAAAAAAACsStkIAAAAAADYBBZl79O/m0/Udk6rFq2yS++LMuq4Jqm7/7bcsbbbZ17ZNoOu/F5G1O6cNi1aZcf9h+Sq68fkQ0sez/hzJ+aJTZazPvX1pZRK9SmVNtaaW8re11+jXl/OdV/onS5tmqfFB7rk8M9enwmDt8njEy/PtBdXzPpAho4eknb1D+WKb92ZulXeXzpzfCb+49SM6temPPJafv6Fs3Lj35vlpG9NyZf+vVvattwmO+z58Vwy+a
s5ovTHXP7ZKzL7LUmW5PDP/zBn99o12zdvmnYHXZL7l/4mn/zAJj4AAAAAAACArYCyEQAAAAAAsAnsmQMOaL7Kn5ukc+d2SZ7JM8+sZXqL2nz0I6tfTVPsdXSO7JiUHrktv5y36XKO/s38vLrooVxc23ijrbll7H099ZmUV+78dNqtNtg0e++9W1L3h/zh0TdHmxw1Op/aq8jcSWMz+fkVo4sz+Zs/SPvPjswhNSvGZmb69IVJVc8c32eNplCHI3LEnkn9g9Pzv3PXDLNXDjyw+ZqDAAAAAAAAbATKRgAAAAAAwCbQKq1arT7SuHHjLL9JaC3TW7dO67cMtku7dknyXJ57bhNE3GTep3t/4eFc/5/Dc/heO6dDm2ZpVBQpiiK7nPe7JK/m1VdXndwtnx55TJq+elvGXfnn5UNPXJ1xv+6V0Wd2fXPaG/Mz/4Uk9b/Maa2Wr/fm54P54kNJ8mSefHLNMC3SosWm2yoAAAAAAMDWTNkIAAAAAACovOcXZmFpzcEVRZsVxZskqUpVVZIlS7JktbmLs3jx2hYuUhQbN+pGt0XsfU6u+PjBGfqVX6fNGZPy2ycW5PX6UkqlUv45/uAkpZTW2EPbU0fn1PalPPKdcbn9jSW5fdwVeX746PRrs8qkJh3SoXWS6lMytW75em/9PJ/vHrax9gEAAAAAAMC7UTYCAAAAAAAq7/V7c/f9S1cbKv3pl7n9maTY+9gcvcOK0R2yww5Jnp6buatOfvZ3+d0/1rZw8zRvnlUKOn/Kl3YvctT3Fm3sHWy4zXrv0zOw+sO59NH7c9fMV5MOJ2b0Ob3StW3L1JSLTK+99traX23SO6PP7p7iuRsydty4jL1hh3z23ENTvdqkg3Pyye2Spb/LXTNff8sST379wDTa8Zzcu/QtjwAAAAAAANhElI0AAAAAAIDKa9U4d3zxjPzgvjn51ysv5J8P3JBPDRmbxxp3y+iJ52a3lRN3z1FH7ZQ8MyVf+87vM+/lV7No9q/y9XOnZUG7tS384ey7b03yxMzc+c9X8vx9N+bnf/1wDj1ku/LzxzLh8A5psf3+uWTWkrUtsOlVbO/roVGP9O7VInl2Wr72zTvyxMJXs+S1RXnizgk5/+pH3va1Pc4elaOavZH/u/BLubf3qJzRZc0ZTXPs5ddmxK7zc/V/DMq3bns0z7zwel5bNDt3XzMiJ1wyL/3Gnp+e1WtbHQAAAAAAgE1B2QgAAAAAAFg/04ekuqjJ0FuS5NYMbVak6ZDpyawx6bTGeLcLH04yLX2LIgd8fXaS2blsnyJFn0lZ7R6blsdk7Ld755H/+nj23KFddv3YyNzZZkAm/vrufPPQbVaZ2DiHXnpzrjr9Q/nLV4/MB9t1yoEjbsy2543NkA5Jnr8mhxVFulzwQHl+mwydMClndn80Yz7UNl0G3JG9Jl6XkR9asd6yLF1an1KpPvWld9v4jIxoWaTY57LMTpJbhqamKNJj7JwtcO9r30vxls9JmbIsSTrnU1NuzxVndc+cbw/MPh1bZ/tdeuasHy1Nv8EHJnkh1xxZpNj/a5mz6t4+MDijh3RISh1z2ugBab22Y2338fzg9/fl6hOLTPnMoenatlXadTsi59zcJGfcMjPX9++0fN5azrgo+mbau/21AQAAAAAAsF6KUqn0rj+dAQAAAAAA729FUWTKlCnp379/A3/zU7n0I11z0cJzc9/cCenRwN9eWVvH3v/1w+PS4buH5M8PXJBdKx1mFf369UuSTJ06tcJJAAAAAAAANitTqyudAAAAAAAAgPerObn26nty3OjrN6uiEQAAAAAAAG+vqtIBAAAAAAAAeP9YfO0xaXPSpDzz8vP549Wj8s1/nZkv9m9T6VgAAAAAAACsI2UjAAAAAACgIuaM7ZGi6JqLHkny9MTUFkW6XfhwpWM1iPf73hdPPy3/1mbHHH1Vdf5z8sXZv7rSiQAAAAAAAFhXftoBAAAAAAAqYucxs1IaU+kUlfF+3nvr02ekdH
qlUwAAAAAAALCh3GwEAAAAAAAAAAAAAAAAJFE2AgAAAAAAAAAAAAAAAMqUjQAAAAAAAAAAAAAAAIAkykYAAAAAAAA
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Draw the inference-model graph (needs pydot and graphviz installed).\n",
"plot_model(prediction_model)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Add a length-1 time axis: (n_samples, n_features) -> (n_samples, 1, n_features).\n",
"X = train_data[feature_cols].values[:, np.newaxis, :]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# One 1-D target array per pollutant, ordered as out_cols.\n",
"Y = list(train_data[out_cols].values.T)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# Validation targets: one 1-D array per pollutant, ordered as out_cols.\n",
"Y_valid = list(valid_data[out_cols].values.T)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Import from tf.keras (not the standalone `keras` package) so the callback\n",
"# belongs to the same Keras implementation as the model it is passed to.\n",
"from tensorflow.keras.callbacks import ReduceLROnPlateau\n",
"\n",
"# Lower the learning rate after 10 epochs without val_loss improvement.\n",
"reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=10, mode='auto')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-03-30 08:54:43.160048: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)\n",
"2023-03-30 08:54:43.178160: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 2200000000 Hz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"690/690 [==============================] - 22s 20ms/step - loss: 5.0242 - val_loss: 4.0851\n",
"Epoch 2/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 3.1973 - val_loss: 2.2916\n",
"Epoch 3/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 1.4457 - val_loss: 0.7350\n",
"Epoch 4/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.2939 - val_loss: 0.0348\n",
"Epoch 5/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0098 - val_loss: 0.0061\n",
"Epoch 6/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0066 - val_loss: 0.0052\n",
"Epoch 7/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0063 - val_loss: 0.0050\n",
"Epoch 8/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0059 - val_loss: 0.0052\n",
"Epoch 9/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0059 - val_loss: 0.0049\n",
"Epoch 10/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0056 - val_loss: 0.0053\n",
"Epoch 11/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0055 - val_loss: 0.0045\n",
"Epoch 12/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0052 - val_loss: 0.0039\n",
"Epoch 13/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0053 - val_loss: 0.0047\n",
"Epoch 14/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0049 - val_loss: 0.0045\n",
"Epoch 15/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0046 - val_loss: 0.0033\n",
"Epoch 16/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0046 - val_loss: 0.0040\n",
"Epoch 17/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0047 - val_loss: 0.0048\n",
"Epoch 18/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0044 - val_loss: 0.0039\n",
"Epoch 19/100\n",
"690/690 [==============================] - 12s 18ms/step - loss: 0.0043 - val_loss: 0.0032\n",
"Epoch 20/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0043 - val_loss: 0.0033\n",
"Epoch 21/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0041 - val_loss: 0.0047\n",
"Epoch 22/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0041 - val_loss: 0.0038\n",
"Epoch 23/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0040 - val_loss: 0.0038\n",
"Epoch 24/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0041 - val_loss: 0.0032\n",
"Epoch 25/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0038 - val_loss: 0.0032\n",
"Epoch 26/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0039 - val_loss: 0.0035\n",
"Epoch 27/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0038 - val_loss: 0.0036\n",
"Epoch 28/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0037 - val_loss: 0.0034\n",
"Epoch 29/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0038 - val_loss: 0.0035\n",
"Epoch 30/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0028 - val_loss: 0.0021\n",
"Epoch 31/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0027 - val_loss: 0.0022\n",
"Epoch 32/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0027 - val_loss: 0.0022\n",
"Epoch 33/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0027 - val_loss: 0.0022\n",
"Epoch 34/100\n",
"690/690 [==============================] - 13s 19ms/step - loss: 0.0027 - val_loss: 0.0021\n",
"Epoch 35/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0027 - val_loss: 0.0021\n",
"Epoch 36/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0027 - val_loss: 0.0021\n",
"Epoch 37/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0026 - val_loss: 0.0021\n",
"Epoch 38/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0026 - val_loss: 0.0021\n",
"Epoch 39/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0026 - val_loss: 0.0021\n",
"Epoch 40/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0026 - val_loss: 0.0022\n",
"Epoch 41/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 42/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 43/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 44/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 45/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 46/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 47/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 48/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 49/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 50/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 51/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 52/100\n",
"690/690 [==============================] - 13s 19ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 53/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 54/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 55/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 56/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 57/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 58/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 59/100\n",
"690/690 [==============================] - 13s 19ms/step - loss: 0.0024 - val_loss: 0.0020\n",
"Epoch 60/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 61/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 62/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 63/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 64/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 65/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 66/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 67/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 68/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 69/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 70/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 71/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 72/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 73/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 74/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 75/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 76/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 77/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 78/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 79/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 80/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 81/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 82/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 83/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 84/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 85/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 86/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 87/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 88/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 89/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 90/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 91/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 92/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 93/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 94/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 95/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 96/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 97/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 98/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 99/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n",
"Epoch 100/100\n",
"690/690 [==============================] - 13s 18ms/step - loss: 0.0025 - val_loss: 0.0020\n"
]
}
],
"source": [
"# trainable_model is compiled with loss=None and receives the targets Y as extra inputs,\n",
"# so validation data is passed as a 1-tuple `(inputs,)`. A bare 7-element list only works\n",
"# where tf.keras treats any non-tuple as `x`; some tf.keras versions convert lists to tuples\n",
"# and unpack them as (x, y, sample_weight), which fails for more than 3 elements.\n",
"trainable_model.compile(optimizer='adam', loss=None)\n",
"hist = trainable_model.fit([X, Y[0], Y[1], Y[2], Y[3], Y[4], Y[5]], epochs=100, batch_size=64, verbose=1,\n",
"                           validation_data=([np.expand_dims(valid_data[feature_cols].values, axis=1),\n",
"                                             Y_valid[0], Y_valid[1], Y_valid[2], Y_valid[3], Y_valid[4], Y_valid[5]],),\n",
"                           callbacks=[reduce_lr])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([[0.02649942],\n",
" [0.17850098],\n",
" [0.05590522],\n",
" ...,\n",
" [0.06559962],\n",
" [0.07360396],\n",
" [0.11683977]], dtype=float32),\n",
" array([[0.0409514 ],\n",
" [0.10872066],\n",
" [0.04259145],\n",
" ...,\n",
" [0.08358508],\n",
" [0.07330614],\n",
" [0.07695329]], dtype=float32),\n",
" array([[0.0078212 ],\n",
" [0.03656307],\n",
" [0.01281381],\n",
" ...,\n",
" [0.02448112],\n",
" [0.01346153],\n",
" [0.01845062]], dtype=float32),\n",
" array([[0.03263965],\n",
" [0.2611128 ],\n",
" [0.28920233],\n",
" ...,\n",
" [0.45292783],\n",
" [0.07745293],\n",
" [0.49551144]], dtype=float32),\n",
" array([[0.26260266],\n",
" [0.30972052],\n",
" [0.09751886],\n",
" ...,\n",
" [0.02907404],\n",
" [0.49323225],\n",
" [0.04708111]], dtype=float32),\n",
" array([[0.02804664],\n",
" [0.18090692],\n",
" [0.08362207],\n",
" ...,\n",
" [0.07223988],\n",
" [0.05220559],\n",
" [0.1401439 ]], dtype=float32)]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same input transform as the validation inputs: add a length-1 axis at position 1.\n",
"rst = prediction_model.predict(\n",
"    np.expand_dims(test_data[feature_cols].to_numpy(), axis=1)\n",
")\n",
"rst"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.9999489175147174,\n",
" 0.9999628656168237,\n",
" 0.999970644281569,\n",
" 0.9998601875319162,\n",
" 0.9999327040916789,\n",
" 0.9999352970648614]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Convert each learned log-variance to a standard deviation: exp(log_var) ** 0.5.\n",
"[np.exp(K.get_value(lv[0])) ** 0.5\n",
" for lv in trainable_model.layers[-1].log_vars]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# rst holds one (n_samples, 1) array per output; join them side by side -> (n_samples, n_outputs).\n",
"pred_rst = pd.DataFrame.from_records(np.concatenate(rst, axis=1), columns=out_cols)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# Ground-truth test targets, copied so the in-place rescaling below leaves test_data untouched.\n",
"real_rst = test_data[out_cols].copy(deep=True)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Undo min-max scaling: value * (max - min) + min. This mutates pred_rst / real_rst in place,\n",
"# so re-running this cell without re-running the two cells above would scale twice.\n",
"for col in out_cols:\n",
"    lo, hi = mins[col], maxs[col]\n",
"    pred_rst[col] = pred_rst[col] * (hi - lo) + lo\n",
"    real_rst[col] = real_rst[col] * (hi - lo) + lo"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Flatten each de-normalised target column into a 1-D array for the metric functions.\n",
"y_pred_pm25, y_pred_pm10, y_pred_so2, y_pred_no2, y_pred_o3, y_pred_co = (\n",
"    pred_rst[c].values.reshape(-1) for c in ['PM2.5', 'PM10', 'SO2', 'NO2', 'O3', 'CO'])\n",
"y_true_pm25, y_true_pm10, y_true_so2, y_true_no2, y_true_o3, y_true_co = (\n",
"    real_rst[c].values.reshape(-1) for c in ['PM2.5', 'PM10', 'SO2', 'NO2', 'O3', 'CO'])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, mean_absolute_percentage_error"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"def print_eva(y_true, y_pred, tp):\n",
"    \"\"\"Compute, print and return regression metrics for one target column.\n",
"\n",
"    Args:\n",
"        y_true: ground-truth values (called here with 1-D arrays).\n",
"        y_pred: predictions aligned element-wise with y_true.\n",
"        tp: label printed in front of the metrics, e.g. 'pm25'.\n",
"\n",
"    Returns:\n",
"        [MSE, RMSE, MAE, MAPE, R_2]; MAPE is a fraction (0.1247 == 12.47 %).\n",
"    \"\"\"\n",
"    MSE = mean_squared_error(y_true, y_pred)\n",
"    RMSE = np.sqrt(MSE)\n",
"    MAE = mean_absolute_error(y_true, y_pred)\n",
"    # sklearn's MAPE can become very large when y_true contains values near zero.\n",
"    MAPE = mean_absolute_percentage_error(y_true, y_pred)\n",
"    R_2 = r2_score(y_true, y_pred)\n",
"    # Use fixed-precision format specs rather than round(): round() on float32 results,\n",
"    # and round(MAPE, 4) * 100, both print binary floating-point noise.\n",
"    print(f'COL: {tp}, MSE: {MSE:.2E}', end=',')\n",
"    print(f'RMSE: {RMSE:.4f}', end=',')\n",
"    print(f'MAPE: {MAPE * 100:.2f} %', end=',')\n",
"    print(f'MAE: {MAE:.4f}', end=',')\n",
"    print(f'R_2: {R_2:.4f}')\n",
"    return [MSE, RMSE, MAE, MAPE, R_2]"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"COL: pm25, MSE: 8.40E+01,RMSE: 9.166899681091309,MAPE: 12.470000237226486 %,MAE: 5.874800205230713,R_2: 0.9659\n",
"COL: pm10, MSE: 3.62E+02,RMSE: 19.023300170898438,MAPE: 12.839999794960022 %,MAE: 12.942299842834473,R_2: 0.9355\n",
"COL: so2, MSE: 1.01E+02,RMSE: 10.043499946594238,MAPE: 24.879999458789825 %,MAE: 6.045100212097168,R_2: 0.9678\n",
"COL: no2, MSE: 2.25E+01,RMSE: 4.739699840545654,MAPE: 9.269999712705612 %,MAE: 3.3459999561309814,R_2: 0.9641\n",
"COL: o3, MSE: 2.76E+01,RMSE: 5.253300189971924,MAPE: 17.21999943256378 %,MAE: 3.7822999954223633,R_2: 0.9889\n",
"COL: co, MSE: 1.58E-02,RMSE: 0.125900000333786,MAPE: 8.269999921321869 %,MAE: 0.09000000357627869,R_2: 0.9683\n"
]
}
],
"source": [
"# Evaluate every pollutant. The NO2 result is stored as `nox_eva`, the name the metrics-table cell expects.\n",
"pm25_eva, pm10_eva, so2_eva, nox_eva, o3_eva, co_eva = [\n",
"    print_eva(y_true, y_pred, tp=tp)\n",
"    for y_true, y_pred, tp in [\n",
"        (y_true_pm25, y_pred_pm25, 'pm25'),\n",
"        (y_true_pm10, y_pred_pm10, 'pm10'),\n",
"        (y_true_so2, y_pred_so2, 'so2'),\n",
"        (y_true_no2, y_pred_no2, 'no2'),\n",
"        (y_true_o3, y_pred_o3, 'o3'),\n",
"        (y_true_co, y_pred_co, 'co'),\n",
"    ]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MSE</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>MAPE</th>\n",
" <th>R_2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>PM25</th>\n",
" <td>84.032173</td>\n",
" <td>9.166906</td>\n",
" <td>5.874771</td>\n",
" <td>0.124713</td>\n",
" <td>0.965868</td>\n",
" </tr>\n",
" <tr>\n",
" <th>PM10</th>\n",
" <td>361.884674</td>\n",
" <td>19.023266</td>\n",
" <td>12.942302</td>\n",
" <td>0.128373</td>\n",
" <td>0.935509</td>\n",
" </tr>\n",
" <tr>\n",
" <th>SO2</th>\n",
" <td>100.872444</td>\n",
" <td>10.043528</td>\n",
" <td>6.045063</td>\n",
" <td>0.248850</td>\n",
" <td>0.967751</td>\n",
" </tr>\n",
" <tr>\n",
" <th>NO2</th>\n",
" <td>22.465204</td>\n",
" <td>4.739747</td>\n",
" <td>3.346048</td>\n",
" <td>0.092662</td>\n",
" <td>0.964141</td>\n",
" </tr>\n",
" <tr>\n",
" <th>O3</th>\n",
" <td>27.596682</td>\n",
" <td>5.253254</td>\n",
" <td>3.782317</td>\n",
" <td>0.172198</td>\n",
" <td>0.988936</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CO</th>\n",
" <td>0.015848</td>\n",
" <td>0.125889</td>\n",
" <td>0.090029</td>\n",
" <td>0.082749</td>\n",
" <td>0.968320</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MSE RMSE MAE MAPE R_2\n",
"PM25 84.032173 9.166906 5.874771 0.124713 0.965868\n",
"PM10 361.884674 19.023266 12.942302 0.128373 0.935509\n",
"SO2 100.872444 10.043528 6.045063 0.248850 0.967751\n",
"NO2 22.465204 4.739747 3.346048 0.092662 0.964141\n",
"O3 27.596682 5.253254 3.782317 0.172198 0.988936\n",
"CO 0.015848 0.125889 0.090029 0.082749 0.968320"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One row per pollutant; columns follow the order of print_eva's return list.\n",
"pd.DataFrame.from_records(\n",
"    [pm25_eva, pm10_eva, so2_eva, nox_eva, o3_eva, co_eva],\n",
"    columns=['MSE', 'RMSE', 'MAE', 'MAPE', 'R_2'],\n",
"    index=['PM25', 'PM10', 'SO2', 'NO2', 'O3', 'CO'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fe480840ed0>"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.\n",
"findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6UAAAIICAYAAACW1EjCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9eZxsd33YeX9+Z62119v3Xt2r5WpDCCGQ2CwssB2wEWANyBhiJTEvB2fiOHHGGscjB2U8jx0HD4syJDPzZHnwA4aM8UKMLDuxjbAg4OCwRFggsUgg0Hb3e/v2WlXn1Fl+88fvnOrqvaq7qqu6+vt+vfTq29Xd1adbXVXne76b0lojhBBCCCGEEEIMgjXoAxBCCCGEEEIIcXBJUCqEEEIIIYQQYmAkKBVCCCGEEEIIMTASlAohhBBCCCGEGBgJSoUQQgghhBBCDIwEpUIIIYQQQgghBsYZ9AEAHDp0SJ84cWLQhyGEEEIIIYQQog+++tWvXtRaz2z0saEISk+cOMEjjzwy6MMQQgghhBBCCNEHSqlnN/uYlO8KIYQQQgghhBgYCUqFEEIIIYQQQgyMBKVCCCGEEEIIIQZmKHpKNxJFESdPniQIgkEfSl8VCgUuv/xyXNcd9KEIIYQQQgghxJ4b2qD05MmTVKtVTpw4gVJq0IfTF1prZmdnOXnyJFdfffWgD0cIIYQQQggh9tzQlu8GQcD09PTIBqQASimmp6dHPhsshBBCCCGEEJsZ2qAUGOmANHcQfkYhhBBCCCGE2MxQB6Wj5HOf+xx33nnnoA9DCCGEEEIIIYbK0PaUduvBR09x/0NPcnq+wbGJIvfecQN33Xq87983SRJs2+779xFCCCGEEEKIUdRxplQpZSulHlVK/efs/Sml1F8opb6bvZ1s+9z7lFJPKaWeVErd0Y8Db/fgo6e474HHOTXfQAOn5hvc98DjPPjoqV3d7zPPPMMLX/hCfuZnfoaXvOQlvP3tb6der3PixAl+4zd+g9e85jX8x//4H/n0pz/Nq1/9al72spfxjne8g+XlZQA+9alP8cIXvpDXvOY1PPDAAz34SYUQQgghhBBitHSTKb0H+DYwlr3/buAzWuv3KaXenb3/T5VSLwLuBm4CjgEPK6VeoLVOdnqQ//w/fZNvnV7c9OOPPjdPM0lX3daIEn7lDx/j977y3IZf86JjY/za/3DTtt/7ySef5MMf/jC33347P/uzP8u//bf/FjCrXL7whS9w8eJF3va2t/Hwww9TLpd5//vfzwc/+EF+5Vd+hb//9/8+n/3sZ7nuuuv4qZ/6qS5+YiGEEEIIIYQ4GDrKlCqlLgd+HPj/t938VuBj2b8/BtzVdvvva61DrfXTwFPAq3pytJtYG5Bud3s3rrjiCm6//XYAfvqnf5ovfOELAK0g80tf+hLf+ta3uP3227nlllv42Mc+xrPPPssTTzzB1VdfzfXXX49Sip/+6Z/e9bEIIYQQQgghxKjpNFP6r4FfAapttx3RWp8B0FqfUUodzm4/Dnyp7fNOZrft2HYZzdvf91lOzTfW3X58osgf/INX7+Zbr5uOm79fLpcBs2v0x37sx/i93/u9VZ/3ta99TSbrCiGEEEIIIcQ2ts2UKqXuBM5rrb/a4X1uFInpDe7355RSjyilHrlw4UKHd72xe++4gaK7ethQ0bW5944bdnW/AM899xxf/OIXAfi93/s9XvOa16z6+G233cZf/dVf8dRTTwFQr9f5zne+wwtf+EKefvppvve977W+VgghhBBCCCHEap2U794OvEUp9Qzw+8DrlFK/A5xTSl0GkL09n33+SeCKtq+/HDi99k611h/SWr9Ca/2KmZmZXfwIcNetx3nv227m+EQRhcmQvvdtN/dk+u6NN97Ixz72MV7ykpdw6dIl/uE//IerPj4zM8NHP/pR/tbf+lu85CUv4bbbbuOJJ56gUCjwoQ99iB//8R/nNa95DVddddWuj0UIIYQQQg
ghRo3Sel0Sc/NPVupHgP9Fa32nUup+YLZt0NGU1vpXlFI3Ab+L6SM9BnwGuH6rQUeveMUr9COPPLLqtm9/+9vceOON3f48PfXMM89w55138o1vfKOv32cYflYhhBBCCCGE6Bel1Fe11q/Y6GO72VP6PuATSqm/BzwHvANAa/1NpdQngG8BMfALu5m8K4QQQgghhBBidHW8pxRAa/05rfWd2b9ntdav11pfn7291PZ5v6m1vlZrfYPW+s97fdB75cSJE33PkgohhBBCCCHEppbOwm+/CZbODfpI+qaroFQIIYQQQgghxB76/AfguS/B598/6CPpm92U7wohhBBCCCGE6If3HIY4XHn/kQ+b/xwffvX85l+3D0mmVAghhBBCCCGGzT2PwYvfAZZr3ncKcPM74J7HB3tcfSBBqRBCCCGEEEIMm+pR8KuQxub9OAR/DKpHBntcfSBBaR+dOHGCixcvDvowhBBCCCGEEPtR7Twcu9X8+8U/CcujOexotILSPk6m0lqTpmnP71cIIYQQQgghNnT3x+GaHzb/fu0vm/dH0GgFpT2eTPXMM89w44038o/+0T/iZS97Gf/iX/wLXvnKV/KSl7yEX/u1X2t93l133cXLX/5ybrrpJj70oQ/15HsLIYQQQgghRGvYUdIc7HH00f6Yvvvn74azWzT0PvdXoPXK+/lkKqXgyts3/pqjN8Ob3rftt37yySf57d/+be666y7+8A//kK985StorXnLW97CX/7lX/JDP/RDfOQjH2FqaopGo8ErX/lKfvInf5Lp6ekuf0ghhBBCCCGEWCMOzNskGuxx9NFoZEqPvRJKM6CyH0dZUJ6B46/c9V1fddVV3HbbbXz605/m05/+NLfeeisve9nLeOKJJ/jud78LwP/1f/1fvPSlL+W2227j+eefb90uhBBCCCGEELuSZ0rT0Q1K90emtIOMJv/pl+CvP2pGJSdNuPEtcOcHd/2ty+UyYHpK77vvPv7BP/gHqz7+uc99jocffpgvfvGLlEolfuRHfoQgCHb9fYUQQgghhBCCqGHejnD57mhkSsFMpnr5u+B/fNi87fFkqjvuuIOPfOQjLC8vA3Dq1CnOnz/PwsICk5OTlEolnnjiCb70pS/19PsKIYQQQgghDrBWT2k82OPoo/2RKe1E+ySqHmRI13rDG97At7/9bV796lcDUKlU+J3f+R3e+MY38u///b/nJS95CTfccAO33XZbz7+3EEIIIYQQ4oBq9ZSObqZ0dILSPjhx4gTf+MY3Wu/fc8893HPPPes+78///M83/PpnnnmmX4cmhBBCCCGEOAgOQE/p6JTvCiGEEEIIIcSokem7QgghhBBCCCEGptVTKkGpEEIIIYQQQoi9dgB6Soc6KNVaD/oQ+u4g/IxCCCGEEEKIHZKgdHAKhQKzs7MjHbRprZmdnaVQKAz6UIQQQgghhBDDKA9KU1kJs+cuv/xyTp48yYULFwZ9KH1VKBS4/PLLB30YQgghhBBCiGHU6ikd3Uzp0Aalruty9dVXD/owhBBCCCGEEGJwZPquEEIIIYQQQoiBSNOVDKkEpUIIIYQQQggh9lQSrvw7laBUCCGEEEIIIcReykt3YaR7SiUoFUIIIYQQQohhFLUHpZIpFUIIIYQQQgixl2IJSoUQQgghhBBCDErc1lMq5btCCCGEEEIIIfZUe6Y0jQd3HH0mQakQQgghhBBCDCPJlAohhBBCCCGEGBjpKRVCCCGEEEIIMTCrMqUSlAohhBBCCCGE2Etxw7x1ClK+K4QQQgghhBBij+WZUr8KqWRKhRBCCCGEEELspbyn1K9K+a4QQgghhBBCiD3WnimVoFQIIYQQQgghxJ7KM6VeVXpKhRBCCCGEEELssfby3TQe7LH0kQSlQgghhBBCCDGMogCUBW5RMqVCCCGEEEIIIfZYHJh1MLYnPaVCCCGEEEIIIfZYHILjg+1KUCqEEEIIIYQQYo+1MqWulO8KIYQQQgghhNhjcbhSvp
tKplQIIYQQQgghxF7KM6WWI+W7QgghhBBCCCH2WKunVAYdCSGEEEIIIYTYa3Fjpac0jUDrQR9RX0hQKoQQQgghhBD
"text/plain": [
"<Figure size 1152x648 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Visual spot-check: predicted vs. measured PM10 on a 100-sample slice of the test set.\n",
"fig, ax = plt.subplots(figsize=(16, 9))\n",
"ax.plot(pred_rst['PM10'].values[50:150], 'o-', label='pred')\n",
"ax.plot(real_rst['PM10'].values[50:150], '*-', label='real')\n",
"ax.set(title='PM10: predicted vs. measured (test samples 50-149)',\n",
"       xlabel='Sample offset within slice', ylabel=r'$PM_{10}\\ (\\mu g/m^3)$')\n",
"ax.legend(loc='best');"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-03-30 09:42:42.029138: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n",
"WARNING:absl:Found untraced functions such as first_attn_layer_call_fn, first_attn_layer_call_and_return_conditional_losses, layer_normalization_layer_call_fn, layer_normalization_layer_call_and_return_conditional_losses, layer_normalization_1_layer_call_fn while saving (showing 5 of 450). These functions will not be directly callable after loading.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: ./models/uw_loss_lookback/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: ./models/uw_loss_lookback/assets\n"
]
}
],
"source": [
"# Persist the prediction model; the output below shows it is written as a SavedModel directory.\n",
"prediction_model.save('./models/uw_loss_lookback/')"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from statistics import mean\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import explained_variance_score,r2_score, median_absolute_error, mean_squared_error, mean_absolute_error\n",
"from scipy import stats\n",
"import numpy as np\n",
"from matplotlib import rcParams\n",
"# Large base font and STIX math fonts for the scatter plots below.\n",
"config = {\"font.size\": 32,\"mathtext.fontset\":'stix'}\n",
"rcParams.update(config)\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"# Same settings as the previous cell; re-applying them is redundant but harmless.\n",
"config = {\"font.size\": 32,\"mathtext.fontset\":'stix'}\n",
"rcParams.update(config)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def scatter_out_1(x, y, label, name, pad=None):\n",
"    \"\"\"Density-coloured scatter of measured vs. predicted values, plus metrics.\n",
"\n",
"    Prints BIAS/EV/MAE/MSE/RMSE/R^2, draws the 1:1 line and the least-squares\n",
"    regression line, annotates N, R^2 and RMSE, saves the figure to\n",
"    ./figure/lookback/{name}.png and shows it.\n",
"\n",
"    Args:\n",
"        x: measured (reference) values, 1-D array-like.\n",
"        y: predicted/retrieved values, same length as x.\n",
"        label: quantity/unit text used in both axis labels (mathtext allowed).\n",
"        name: file stem of the saved PNG.\n",
"        pad: margin added around the data range on both axes. Defaults to 5 %\n",
"            of the data range, so small-valued targets (e.g. CO) are not\n",
"            swamped by a fixed margin.\n",
"    \"\"\"\n",
"    x = np.asarray(x, dtype=float)\n",
"    y = np.asarray(y, dtype=float)\n",
"\n",
"    # ========== evaluation metrics (x is the reference, y the prediction) ==========\n",
"    BIAS = np.mean(x - y)\n",
"    MSE = mean_squared_error(x, y)\n",
"    RMSE = np.sqrt(MSE)\n",
"    R2 = r2_score(x, y)\n",
"    MAE = mean_absolute_error(x, y)\n",
"    EV = explained_variance_score(x, y)\n",
"    print('==========算法评价指标==========')\n",
"    print('Mean Bias(BIAS):', '%.3f' % (BIAS))\n",
"    print('Explained Variance(EV):', '%.3f' % (EV))\n",
"    print('Mean Absolute Error(MAE):', '%.3f' % (MAE))\n",
"    print('Mean squared error(MSE):', '%.3f' % (MSE))\n",
"    print('Root Mean Squared Error(RMSE):', '%.3f' % (RMSE))\n",
"    print('R_squared:', '%.3f' % (R2))\n",
"\n",
"    # ========== point density via Gaussian KDE ==========\n",
"    xy = np.vstack([x, y])\n",
"    z = stats.gaussian_kde(xy)(xy)\n",
"    # Sort by density so the densest points are drawn last (on top).\n",
"    idx = z.argsort()\n",
"    x, y, z = x[idx], y[idx], z[idx]\n",
"\n",
"    # Ordinary least-squares fit y = m * x + b.\n",
"    m, b = np.polyfit(x, y, 1)\n",
"\n",
"    min_value = min(x.min(), y.min())\n",
"    max_value = max(x.max(), y.max())\n",
"    if pad is None:\n",
"        span = max_value - min_value\n",
"        pad = 0.05 * span if span > 0 else 1.0\n",
"    lo, hi = min_value - pad, max_value + pad\n",
"\n",
"    fig, ax = plt.subplots(figsize=(12, 9), dpi=400)\n",
"    scatter = ax.scatter(x, y, marker='o', c=z * 100, s=15, cmap='Spectral_r')\n",
"    fig.colorbar(scatter, ax=ax, shrink=1, orientation='vertical', extend='both',\n",
"                 pad=0.015, aspect=30, label='Frequency')\n",
"    ax.plot([lo, hi], [lo, hi], 'black', lw=1.5)                # 1:1 reference line\n",
"    ax.plot([lo, hi], [m * lo + b, m * hi + b], 'red', lw=1.5)  # regression line\n",
"    ax.set_xlim(lo, hi)\n",
"    ax.set_ylim(lo, hi)\n",
"    ax.set_xlabel('Measured %s' % label)\n",
"    ax.set_ylabel('Retrieved %s' % label)\n",
"\n",
"    # Annotations are placed in axes coordinates. Data positions such as int(max_value * k)\n",
"    # truncate to the same y for small-valued targets (e.g. CO), making the labels overlap.\n",
"    annotations = ['$N=%.f$' % len(y), '$R^2=%.2f$' % R2, '$RMSE=%.2f$' % RMSE]\n",
"    for row, text in enumerate(annotations):\n",
"        ax.text(0.05, 0.95 - 0.07 * row, text, transform=ax.transAxes, va='top')\n",
"\n",
"    # os is imported in the notebook's first cell.\n",
"    os.makedirs('./figure/lookback', exist_ok=True)\n",
"    fig.savefig(f'./figure/lookback/{name}.png', dpi=800, bbox_inches='tight', pad_inches=0)\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.939\n",
"Mean Absolute Error(MAE): 12.942\n",
"Mean squared error(MSE): 361.885\n",
"Root Mean Squard Error(RMSE): 19.023\n",
"R_squared: 0.936\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.\n",
"findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEjAAAAyVCAYAAAAv6W6kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzdeVjVZd7H8c+NBzcUwtLUSsuyMKcFt8ymiHKjxRqn5ekB29Q0qmeammkmraZpUttn2jSlrCa1BoooS0VEhTJts902M6MFLI1IXA9wP3/ozHQU8MD5LQd4v66La65zL9/7c44/ceIcvj9jrRUAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEBDxPgdAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAND00MAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1GAyMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANBgNDACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANRgMjAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADQYDQwAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADUYDIwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA0GA0MAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1GAyMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANBgNDACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANRgMjAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADQYDQwAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADUYDIwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA0GA0MAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1GAyMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANBgNDACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANRgMjAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADQYDQwAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADUYDIwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA0GA0MAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1GAyMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANBgNDACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANRgMjAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADQYDQwAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADUYDIwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA0GA0MAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA1GAyMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANBgNDACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANRgMjAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADQYDQwAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADUYDIwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA0GA0MAIAAAAAAAAAAAAAAAAAAAAAAAAAAAA8ZIwxfmcAAAAAACfQwAgAAAAAAAAAAAAAAAAAAAAAAAAAAADwVrrfAQAAAADACcZa63cGAAAAAAAAAAAAAAAAAAAAAAAAAAAAoEUwxsRK+k7SMdbaMr/zAAAAAEAkYvwOAAAAAAAAAAAAAAAAAAAAAAAAAAAAALQgp0s6QNJov4MAAAAAQKRoYAQAAAAAAAAAAAAAAAAAAAAAAAAAAAB454Ld/3u+rykAAAAAwAHGWut3BgAAADSCMaabpLP2GF4naasPcQAAAAAAAAAAAAAAAAAAAAAAQPRrL6nXHmMvWWtL/QjTEhljYiVtkJQoqUbSQdbaMn9TAQ
AAAEDjBfwOAAAAgEY7S9Isv0MAAAAAAAAAAAAAAAAAAAAAAIAm7QpJWX6HaEFO167mRZIUI2m0pOn+xQEAAACAyMT4HQAAAAAAAAAAAAAAAAAAAAAAAAAAAABoIS7Y4/H5vqQAAAAAAIcE/A4AAAAAAAAAAAAAAAAAAAAAAAAAAAAAOMj6HaA2wWBQiYmJKi8v/89YTEzMqWVlZbZr164+JquX8TsAAAAAgOgW43cAAAAAAAAAAAAAAAAAAAAAAAAAAAAAoLkrLCwMaV4kSTU1NcrNzfUpEQAAAABELuB3AAAAADTauj0HZs6cqWOOOcaPLAAAAAAAAAAAAAAAAAAAAAAAwGdVVVW655579OKLLzZk216/nwB3ZGdn1zqek5OjzMxMj9MAAAAAgDOMtdbvDAAAAGgEY8yJkl775dhrr72mE0880adEAAAAAAAAAAAAAAAAAAAAAADAL5s3b9aFF16ohQsXNnTrEGvtSjcy+Sjqfnk2GAzqwAMPVHl5+V5zMTEx+vbbb9W1a1cfku2T8TsAAAAAgOgW43cAAAAAAAAAAAAAAAAAAAAAAAAAAAAANF5paalSUlIa07wIHiksLKy1eZEk1dTUKDc31+NEAAAAAOCMgN8BAAAAAAAAAAAAAAAAAAAAAAAAAPjDWqt1n2/UmvfKtP6LH1X6XYV2bK9SINBKnfZvr56Hd9IRR3XWsf0PUuvWrfyOCwCoxUcffaQzzjhDJSUlfkdBPbKzs+udz8nJUWZmpkdpAAAAAMA5NDACAAAAAAAAAAAAAAAAAAAAAAAAWpia6hq9svQLFbz8qb5eX17rmrLvftaaD8okSR06ttEpQ49Q2jl9FL9fOy+jAgDqsWzZMv3mN79RRUWF31FQj2AwqLy8vHrXFBcXq6ysTF27dvUmFAAAAAA4JMbvAAAAAAAAAAAAAAAAAAAAAAAAAAC88903Fbr9xnzNfnhVnc2L9lS5eYcWPP+Rbrxmvla98qWstS6nBADsy9y5czVixIiwmxfFxsbqL3/5i8upUJvCwkKVl4f+m9tPnUMe19TUKDc318tYAAAAAOAIGhgBAAAAAAAAAAAAAAAAAAAAAAAALcTbq0p0y3Uv64vPNjZqf+XmHZpx76t6cuYbqqmucTgdACAc1lpNnTpVGRkZCgaDYe1JSEhQfn6+RowY4XI61CY7OzvksZGUriMVp0DIeE5OjoepAAAAAMAZNDACAAAAAAAAAAAAAAAAAAAAAAAAWoDVr3+th+4qVnBndcS1li36TI89vEo1NdaBZACAcFVVVWnChAmaPHly2Ht69OihFStWKDU11cVkqEswGFReXl7IWJISlWjaqJ86h4wXFxerrKzMw3QAAAAAEDkaGAEAAAAAAAAAAAAAAAAAAAAAAADN3HdfV2jGva842nDo1aVfaPH8jx2rBwCo3+bNmzVq1ChlZWWFvSc5OVmrVq1S3759XUyG+hQWFqq8vDxkbIC6SJIG7v7ff6upqVFubq5n2QAAAADACTQwAgAAAAAAAAAAAAAAAAAAAAAAAJqxmuoaPfrga9q5s9rx2s/OfVel31Y4XhcAEKq0tFQpKSlauHBh2HvS0tJUXFysbt26uZgM+5KdnR3y2Ejqr86SpCQlKk6BkPmcnByvogEAAACAI2hgBAAAAAAAAAAAAAAAAAAAAAAAADRjry5bpy8+2+hK7eDOas177C1XagMAdvnoo480ePBgvfPOO2HvGT9+vF588UV16NDBxWTYl2AwqLy8vJCxJCUq3rSWJAVMjPrtbmb0b8XFxSorK/MqIgAAAABEjAZGAAAAAAAAAAAAAAAAAAAAAAAAQDNlrdXilz5x9Yz3V3+nsu9+dvUMAGipli1bppNOOkklJSVh75k6dapmzpypQCDgYjKEo7CwUOXl5SFjA9Ql5PHAPR7X1NQoNzfX9WwAAAAA4BQaGAEAAAAAAAAAAAAAAAAAAAAAAADN1LrPN+nr9eX7XhihosWfu34GALQ0c+
fO1YgRI1RRURHW+tjYWM2ZM0c33nijjDEup0M4srOzQx4bSf3VOWQsSYmKU2izqZycHLejAQAAAIBjaGAEAAAAAAA
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"scatter_out_1(real_rst['PM10'].values, pred_rst['PM10'].values, label='$PM_{10}\\ (\\mu g/m^3$)', name='PM10')"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.967\n",
"Mean Absolute Error(MAE): 5.875\n",
"Mean squared error(MSE): 84.032\n",
"Root Mean Squard Error(RMSE): 9.167\n",
"R_squared: 0.966\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEb8AAAyVCAYAAAChBVopAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzdeXhU5f3+8fsTErYAISgIqKBYbZSqZbOIVUQBwSq11mr9BrGoiKK2LtUKaGutoNalrQsIKGpFahONUSwQMUCiCG7gihtuVE1wi5GdCXl+f+BPO5CESeY5c2aG9+u6uLzmnPPc554cGCGZ+RxzzgkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgETKCLsAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGDXw/AbAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEDCMfwGAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJBwDL8BAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACQcw28AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAnH8BsAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQMIx/AYAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAkHAMvwEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJBzDbwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACcfwGwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAwjH8BgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACQcAy/AQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAkHMNvAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJx/AbAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEDCMfwGAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJBwDL8BAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACQcw28AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAnH8BsAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQMIx/AYAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAkHAMvwEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJBzDbwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACcfwGwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAwjH8BgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACQcAy/AQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAkHMNvAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJx/AbAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEDCMfwGAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJBwDL8BAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACQcw28AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAnH8BsAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQMIx/AYAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAkHAMvwEAAAAAAAAAAAAAAAAAAAAAAAAAAAASyMws7A4AAABAMmD4DQAAAAAAAAAAAAAAAAAAAAAAAAAAAJBY+WEXAAAAAJKBOefC7gAAAAAAAAAAAAAAAAAAAAAAAAAAAADsEswsS9Knkg52zlWG3QcAAAAIU0bYBQAAAAAAAAAAAAAAAAAAAAAAAAAAAIBdyLGSdpd0cthFAAAAgLAx/AYAAAAAAAAAAAAAAAAAAAAAAAAAAABInFO//e+vQm0BAAAAJAFzzoXdAQAAAE1gZl0knbDd5vclbQihDgAAAAAAAAAAAAAAAAAAAAAASH6tJfXYbtsTzrmKMMrsiswsS9IaSbmSaiXt6ZyrDLcVAAAAEJ7MsAsAAACgyU6QND3sEgAAAAAAAA
AAAAAAAAAAAAAAIKWdK2lG2CV2Icdq2+AbScqQdLKkKeHVAQAAAMKVEXYBAAAAAAAAAAAAAAAAAAAAAAAAAAAAYBdx6naPfxVKCwAAACBJZIZdAAAAAAAAAAAAAAAAAAAAAAAAAAAAAPDIhV2gLpFIRLm5uaqqqvpuW0ZGxtGVlZWuc+fOITZrkIVdAAAAAOktI+wCAAAAAAAAAAAAAAAAAAAAAAAAAAAAQLorLS2NGnwjSbW1tSoqKgqpEQAAABC+zLALAAAAoMne337DtGnTdPDBB4fRBQAAAAAAAAAAAAAAAAAAAAAAhKympkY333yzHn/88cYs2+HzCQhGQUFBndsLCws1bty4BLcBAAAAkoM558LuAAAAgCYws8MlPfu/25599lkdfvjhITUCAAAAAAAAAAAAAAAAAAAAAABhWbt2rU477TTNmzevsUsHOOeWBtEpREn34dlIJKI99thDVVVVO+zLyMjQJ598os6dO4fQbKcs7AIAAABIbxlhFwAAAAAAAAAAAAAAAAAAAAAAAAAAAEDTVVRUaODAgU0ZfIMEKS0trXPwjSTV1taqqKgowY0AAACA5MDwGwAAAAAAAAAAAAAAAAAAAAAAAAAAgBT1xhtvqH///lqxYkXYVdCAgoKCBvcXFhYmqAkAAACQXBh+AwAAAAAAAAAAAAAAAAAAAAAAAAAAkIIWLVqkI444QqtXrw67ChoQiURUXFzc4DHl5eWqrKxMTCEAAAAgiTD8BgAAAAAAAAAAAAAAAAAAAAAAAAAAIMU8+OCDOu6441RdXR3T8VlZWfrTn/4UcCvUpbS0VFVVVVHbeqtj1OPa2loVFRUlshYAAACQFBh+AwAAAAAAAAAAAAAAAAAAAAAAAAAAkCKcc5o8ebJGjhypSCQS05qcnByVlJTouOOOC7gd6lJQUBD12CTl6wBlKzNqe2FhYQJbAQAAAMmB4TcAAAAAAAAAAAAAAAAAAAAAAAAAAAApoKamRmPHjtXEiRNjXtOtWzctWbJEgwYNCrAZ6hOJRFRcXBy1LU+5yrUW6q2OUdvLy8tVWVmZwHYAAABA+Bh+AwAAAAAAAAAAAAAAAAAAAAAAAAAAkOTWrl2rESNGaMaMGTGv6dWrl5YtW6aePXsG2AwNKS0tVVVVVdS2vuokSer37X//v9raWhUVFSWsGwAAAJAMGH4DAAAAAAAAAAAAAAAAAAAAAAAAAACQxCoqKjRw4EDNmzcv5jXDhw9XeXm5unTpEmAz7ExBQUHUY5PURx0lSXnKVbYyo/YXFhYmqhoAAACQFBh+AwAAAAAAAAAAAAAAAAAAAAAAAAAAkKTeeOMN9e/fXytWrIh5zZgxY/T444+rTZs2ATbDzkQiERUXF0dty1Ou2llzSVKmZaj3t4Nw/r/y8nJVVlYmqiIAAAAQOobfAAAAAAAAAAAAAAAAAAAAAAAAAAAAJKFFixbpiCOO0OrVq2NeM3nyZE2bNk2ZmZkBNkMsSktLVVVVFbWtrzpFPe633ePa2loVFRUF3g0AAABIFgy/AQAAAAAAAAAAAAAAAAAAAAAAAAAASDIPPvigjjvuOFVXV8d0fFZWlmbNmqXx48fLzAJuh1gUFBREPTZJfdQxaluecpWt6EFFhYWFQVcDAAAAkgbDbwAAAAAAAAAAAAAAAAAAAAAAAAAAAJKEc06TJ0/WyJEjFYlEYlqTk5OjkpIS5efnB9wOsYpEIiouLo7alqdctbPmUdsyLUO9txuIU15ersrKyqArAgAAAEmB4TcAAAAAAAAAAAAAAAAAAAAAAAAAAABJoKamRmPHjtXEiRNjXtOtWzctWbJEgwYNCrAZGqu0tFRVVVVR2/qqU53H9ttue21trYqKigLrBgAAACQTht8AAAAAAAAAAAAAAAAAAAAAAAAAAACEbO3atRoxYoRmzJgR85pevXpp2bJl6tmzZ4DN0BQFBQVRj01SH3Ws89g85SpbmVHbCgsLg6oGAAAAJBWG3wAAAAAAAA
AAAAAAAAAAAAAAAAAAAISooqJCAwcO1Lx582JeM3z4cJWXl6tLly4BNkNTRCIRFRcXR23LU67aWfM6j8+0DPXebjB
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"scatter_out_1(real_rst['PM2.5'].values, pred_rst['PM2.5'].values, label='$PM_{2.5} (\\mu g/m^3$)', name='PM25')"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.968\n",
"Mean Absolute Error(MAE): 6.045\n",
"Mean squared error(MSE): 100.872\n",
"Root Mean Squard Error(RMSE): 10.044\n",
"R_squared: 0.968\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEb8AAAyVCAYAAAChBVopAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzdeZjVdfk//vuMAygIOBqomKiYhpEVuCGmiAKCuWWpnwI1FQSxvvVpsUQtM0HLbLFcgNwSqGZyGsWAAUYYFHEpzFLU3MlkXMeRnRnm/P7o86uGZTjDnHPeZ+DxuK65vN6v+3W/Xs/BES+YM/dJpdPpAAAAAAAAAAAAAAAAAAAAAACAfCpKOgAAAAAAAAAAAAAAAAAAAAAAADsew28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAAAAAAMg7w28AAAAAAAAAAAAAAAAAACCPUqlUKukMAABQCAy/AQAAAAAAAAAAAAAAAACA/BqRdAAAACgEqXQ6nXQGAAAAAAAAAAAAAAAAAADYIaRSqXYR8UZEHJpOp2uSzgMAAEkqSjoAAAAAAAAAAAAAAAAAAADsQE6MiA9FxJlJBwEAgKQZfgMAAAAAAAAAAAAAAAAAAPlz9v/986xEUwAAQAFIpdPppDMAALANUqnU3hFxykbLL0fE6gTiAAAAAAAAAAAAAAAAha9jRPTaaO2BdDq9PIkwO6JUKtUuIt6MiJKIaIyIfdLpdE2yqQAAIDnFSQcAAGCbnRIRk5MOAQAAAAAAAAAAAAAAtGkXR8SUpEPsQE6Mfw2+iYgoiogzI+KW5OIAAECyipIOAAAAAAAAAAAAAAAAAAAAO4izN3o+K5EUAABQIIqTDgAAAAAAAAAAAAAAAAAAAFmUTjrA5tTX10dJSUnU1tb+e62oqOj4mpqa9F577ZVgsmalkg4AAMD2rSjpAAAAAAAAAAAAAAAAAAAAsL2rqqpqMvgmIqKxsTHKy8sTSgQAAMkrTjoAAADb7OWNFyZNmhSHHnpoElkAAAAAAAAAAAAAAICENTQ0xI9//OO4//77W9K2yc8nkBulpaWbXS8rK4tx48blOQ0AABSGVDqdTjoDAADbIJVKHR0Rj/z32iOPPBJHH310QokAAAAAAAAAAAAAAICkrFixIs4555yYNWtWS1sHpNPpxbnIlKCC++HZ+vr62HPPPa
O2tnaTWlFRUfzzn/+MvfbaK4FkW5VKOgAAANu3oqQDAAAAAAAAAAAAAAAAAACw7ZYvXx4DBw7clsE35ElVVdVmB99ERDQ2NkZ5eXmeEwEAQGEw/AYAAAAAAAAAAAAAAAAAoI165plnon///vHkk08mHYVmlJaWNlsvKyvLUxIAACgsht8AAAAAAAAAAAAAAAAAALRB8+fPj2OOOSaWLVuWdBSaUV9fHxUVFc3uWbhwYdTU1OQnEAAAFBDDbwAAAAAAAAAAAAAAAAAA2php06bFSSedFHV1dRntb9euXXzve9/LcSo2p6qqKmpra5us9YtuTZ4bGxujvLw8n7EAAKAgGH4DAAAAAAAAAAAAAAAAANBGpNPpmDhxYowcOTLq6+sz6unatWtUVlbGSSedlON0bE5paWmT51REjIiDo1MUN1kvKyvLYyoAACgMht8AAAAAAAAAAAAAAAAAALQBDQ0NMWbMmLjiiisy7unZs2csWrQoBg0alMNkbEl9fX1UVFQ0WesdJVGS6hD9oluT9YULF0ZNTU0e0wEAQPIMvwEAAAAAAAAAAAAAAAAAKHArVqyI0047LaZMmZJxT9++fePRRx+NPn365DAZzamqqora2toma4dH94iIOOL//vn/a2xsjPLy8rxlAwCAQmD4DQAAAAAAAAAAAAAAAABAAVu+fHkMHDgwZs2alXHP8OHDY+HChbH33nvnMBlbU1pa2uQ5FRGHRbeIiOgdJdEpipvUy8rK8hUNAAAKguE3AAAAAAAAAAAAAAAAAAAF6plnnon+/fvHk08+mXHP6NGj4/77749dd901h8nYmvr6+qioqGiy1jtKokuqfUREFKeKot//DcL5/y1cuDBqamryFREAABJn+A0AAAAAAAAAAAAAAAAAQAGaP39+HHPMMbFs2bKMeyZOnBiTJk2K4uLiHCYjE1VVVVFbW9tk7fDo3uT5iI2eGxsbo7y8POfZAACgUBh+AwAAAAAAAAAAAAAAAABQYKZNmxYnnXRS1NXVZbS/Xbt2MXXq1Lj88ssjlUrlOB2ZKC0tbfKciojDoluTtd5REp2i6aCisrKyXEcDAICCYfgNAAAAAAAAAAAAAAAAAECBSKfTMXHixBg5cmTU19dn1NO1a9eorKyMESNG5Dgdmaqvr4+Kiooma72jJLqk2jdZK04VRb+NBuIsXLgwampqch0RAAAKguE3AAAAAAAAAAAAAAAAAAAFoKGhIcaMGRNXXHFFxj09e/aMRYsWxaBBg3KYjJaqqqqK2traJmuHR/fN7j1io/XGxsYoLy/PWTYAACgkht8AAAAAAAAAAAAAAAAAACRsxYoVcdppp8WUKVMy7unbt288+uij0adPnxwmY1uUlpY2eU5FxGHRbbN7e0dJdIriJmtlZWW5igYAAAWleOtbAAAAAAAAAAAAAAAAAGiJdWvr47VXauON1+ti3ZqGKNopFbuV7BL79do9uu25a6RSqaQjAgVk+fLl8ZnPfCaefPLJjHuGDx8epaWlseuuu+YwGduivr4+Kioqmqz1jpLokmq/2f3FqaLol+4WD8Xyf68tXLgwampqYq+99splVAAASJzhNwAAAAAAAAAAAAAAAABZ0FC/IZ5YvCzmV/49/v7s25FuTG92X5euO0f/Y/ePE4d/NPbap0ueUwKF5plnnomTTz45li1blnHP6NGj45ZbboniYj8mWoiqqqqitra2ydrh0b3ZniOie5PhN42NjVFeXh7jxo3LSUYAACgURUkHAAAAAAAAAAAAAAAAAGjrnnz8H/GtsRVx208ejuefeWuLg28iIj6oWxtzHnguvn3pfTHpZ4ti5Qfr8pgUKCTz58+PY445pkWDbyZOnBiTJk0y+KaAlZaWNnlORcRh0a3Znt5REp2i6b/TsrKybEcDAICCY/gNAAAAAAAAAAAAAAAAwDZav64hpvx8Ufxs4oJ4793VLe5/ZMHLMf7/3R/PPLU8B+mAQjZt2rQ46aSToq6uLqP97d
q1i6lTp8bll18eqVQqx+nYVvX19VFRUdFkrXeURJdU+2b7ilNF0W+jATkLFy6MmpqabEcEAICCYvgNAAAAAAAAAAA
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"scatter_out_1(real_rst['SO2'].values, pred_rst['SO2'].values, label=r'$SO_2\\ (\\mu g/m^3)$', name='SO2')"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.965\n",
"Mean Absolute Error(MAE): 3.346\n",
"Mean squared error(MSE): 22.465\n",
"Root Mean Squard Error(RMSE): 4.740\n",
"R_squared: 0.964\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEjAAAAyVCAYAAAAv6W6kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzdf3zWdb0//uc1N1AQ5jRQMDEtPRhZAf4gPDqXgNAP9divT2dTM0EQ69SpUydHv4+gZXVOmSKQ9kOg2mouKGDihK1I/BH2S8o0NbI2TV3Loeg1dn3/6Hw7bsC4xq7rem9wv99u163er+fr+Xo9NhFvwMXzSmUymQAAAAAAAAAAAAAAAAAAAAAAAOiLoqQDAAAAAAAAAAAAAAAAAAAAAAAAg48BRgAAAAAAAAAAAAAAAAAAAAAAQJ8ZYAQAAAAAAAAAAAAAAAAAAAAAAPSZAUYAAAAAAAAAAAAAAAAAAAAAAECfGWAEAAAAAAAAAAAAAAAAAAAAAAD0mQFGAAAAAAAAAAAAAAAAAAAAAABAnxlgBAAAAAAAAAAAAAAAAAAAAAAA9JkBRgAAAAAAAAAAAAAAAAAAAAAAQJ8ZYAQAAAAAAAAAAAAAAAAAAAAAAPSZAUYAAAAAAAAAAAAAAAAAAAAAAECfGWAEAAAAAAAAAAAAAAAAAAAAAAD0mQFGAAAAAAAAAAAAAAAAAAAAAABAnxlgBAAAAAAAAAAAAAAAAAAAAAAA9JkBRgAAAAAAAAAAAAAAAAAAAAAAQJ8ZYAQAAAAAAAAAAAAAAAAAAAAAAPSZAUYAAAAAAAAAAAAAAAAAAAAAAECfGWAEAAAAAAAAAAAAAAAAAAAAAAD0mQFGAAAAAAAAAAAAAAAAAAAAAABAnxlgBAAAAAAAAAAAAAAAAAAAAAAA9JkBRgAAAAAAAAAAAAAAAAAAAAAAQJ8ZYAQAAAAAAAAAAAAAAAAAAAAAAPSZAUYAAAAAAAAAAAAAAAAAAAAAAECfGWAEAAAAAAAAAAAAAAAAAAAAAAD0mQFGAAAAAAAAAAAAAAAAAAAAAABAnxlgBAAAAAAAAAAAAAAAAAAAAAAA9JkBRgAAAAAAAAAAAAAAAAAAAAAAQJ8ZYAQAAAAAAAAAAAAAAAAAAAAAAPSZAUYAAAAAAAAAAAAAAAAAAAAAAECfGWAEAAAAAAAAAAAAAAAAAAAAAAD0mQFGAAAAAAAAAAAAAAAAAAAAAABAnxlgBAAAAAAAAAAAAAAAAAAAAAAA9JkBRgAAAAAAAAAAAAAAAAAAUECpVCqVdAYAAIBcMMAIAAAAAAAAAAAAAAAAAAAKqzLpAAAAALmQymQySWcAAAAAAAAAAAAAAAAAAIADQiqVKomIP0fEyZlMpjXpPAAAAP1RlHQAAAAAAAAAAAAAAAAAAAA4gJwTES+LiAuTDgIAANBfBhgBAAAAAAAAAAAAAAAAAEDhvPN///cdiaYAAADIgVQmk0k6AwAA+yCVSo2JiLf0WH4kIp5LIA4AAAAAAAAAAAAAADDwDYuI43us/TCTybQkEeZAlEqlSiLiiYgoi4iuiDg6k8m0JpsKAABg3xUnHQAAgH32lohYmnQIAAAAAAAAAAAAAABgULs8IpYlHeIAck78fXhRRERRRFwYETcmFwcAAKB/ipIOAAAAAAAAAAAAAAAAAAAAB4h39nh+RyIpAAAAcqQ46QAAAAAAAAAAAAAAAAAAAJBDmaQD7E46nY6ysrJoa2v7x1pRUdHZra2tmaOOOirBZL1KJR0AAAAY2IqSDgAAAAAAAAAAAAAAAAAAAPu7xsbGbsOLIiK6urqirq4uoUQAAAD9V5x0AAAA9tkjPReWLFkSJ598chJZAAAAAAAAAAAAAACAhHV2dsYXvvCFWLVqVV/advn7CeRHTU3Nbtdra2tj/vz5BU4DAACQG6lMJpN0BgAA9kEqlXpDRPz0pWs//elP4w1veENCiQAAAAAAAAAAAAAAgKQ8++yz8a
53vSvWrl3b19apmUzmrnxkStCA+8uz6XQ6jjzyyGhra9ulVlRUFH/605/iqKOOSiDZXqWSDgAAAAxsRUkHAAAAAAAAAAAAAAAAAABg37W0tER5efm+DC+iQBobG3c7vCgioqurK+rq6gqcCAAAIDcMMAIAAAAAAAAAAAAAAAAAGKQeeOCBmDJlStx///1JR6EXNTU1vdZra2sLlAQAACC3DDACAAAAAAAAAAAAAAAAABiENmzYEGeccUZs27Yt6Sj0Ip1OR319fa97mpubo7W1tTCBAAAAcsgAIwAAAAAAAAAAAAAAAACAQWbFihVx7rnnRnt7e1b7S0pK4lOf+lSeU7E7jY2N0dbW1m1tUozq9tzV1RV1dXWFjAUAAJATBhgBAAAAAAAAAAAAAAAAAAwSmUwmFi1aFFVVVZFOp7PqKS0tjYaGhjj33HPznI7dqamp6faciojKODGGR3G39dra2gKmAgAAyA0DjAAAAAAAAAAAAAAAAAAABoHOzs6YO3duLFiwIOuecePGxaZNm6KioiKPydiTdDod9fX13dbGR1mUpYbGpBjVbb25uTlaW1sLmA4AAKD/DDACAAAAAAAAAAAAAAAAABjgnn322TjvvPNi2bJlWfdMnDgxNm/eHBMmTMhjMnrT2NgYbW1t3dZOidEREXHq//7v/6+rqyvq6uoKlg0AACAXDDACAAAAAAAAAAAAAAAAABjAWlpaory8PNauXZt1z6xZs6K5uTnGjBmTx2TsTU1NTbfnVERMjlERETE+ymJ4FHer19bWFioaAABAThhgBAAAAAAAAAAAAAAAAAAwQD3wwAMxZcqUuP/++7PumTNnTqxatSoOPfTQPCZjb9LpdNTX13dbGx9lMTI1JCIiilNFMel/hxn9/5qbm6O1tbVQEQEAAPrNACMAAAAAAAAAAAAAAAAAgAFow4YNccYZZ8S2bduy7lm0aFEsWbIkiouL85iMbDQ2NkZbW1u3tVNidLfnU3s8d3V1RV1dXd6zAQAA5IoBRgAAAAAAAAAAAAAAAAAAA8yKFSvi3HPPjfb29qz2l5SUxPLly+Oqq66KVCqV53Rko6ampttzKiImx6hua+OjLIZH92FTtbW1+Y4GAACQMwYYAQAAAAAAAAAAAAAAAAAMEJlMJhYtWhRVVVWRTqez6iktLY2GhoaorKzMczqylU6no76+vtva+CiLkakh3daKU0UxqcdQo+bm5mhtbc13RAAAgJwwwAgAAAAAAAAAAAAAAAAAYADo7OyMuXPnxoIFC7LuGTduXGzatCkqKirymIy+amxsjLa2tm5rp8To3e49tcd6V1dX1NXV5S0bAABALhlgBAAAAAAAAAAAAAAAAACQsGeffTbOO++8WLZsWdY9EydOjM2bN8eECRPymIx9UVNT0+05FRGTY9Ru946Pshgexd3Wamtr8xUNAAAgpwwwAgAAAAAAAAAAAAAAAABIUEtLS5SXl8fatWuz7pk1a1Y0NzfHmDFj8piMfZFOp6O+vr7b2vgoi5GpIbvdX5wqikk9hhs1NzdHa2trviICAADkjAFGAAAAAAAAAAAAAAAAAAAJeeCBB2LKlClx//33Z90zZ86cWLVqVRx66KF5TMa+amxsjLa2tm5rp8ToXntO7VHv6uqKurq6nGcDAADINQOMAAAAAAAAAAAAAAAAAAASsGHDhjjjjDNi27ZtWfcsWrQolixZEsXFxXlMRn/U1NR0e05FxOQY1WvP+CiL4dH9n2ltbW2uowEAAOScAUYAAAAAAAAAAAAAAAAAAAW2YsWKOPfcc6O9vT2r/SUlJbF8+fK46qqrIpVK5Tkd+yqdTkd9fX23tfFRFiNTQ3rtK04VxaQeQ46am5ujtbU11xEBAAByygAjAAAAAAAAAAAAAAAAAIACyWQysWjRoqiqqop0Op1VT2lpaTQ0NERlZWWe09FfjY2N0dbW1m3tlBidVe+pPfZ1dXVFXV1dzrIBAADkgwFGAAAAAAAAAAAAAA
AAAAAF0NnZGXPnzo0FCxZk3TNu3LjYtGlTVFRU5DEZuVJTU9PtORURk2NUVr3joyyGR3G3tdra2lxFAwAAyAsDjAA
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"scatter_out_1(real_rst['NO2'].values, pred_rst['NO2'].values, label=r'$NO_2\\ (\\mu g/m^3)$', name='NO2')"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.989\n",
"Mean Absolute Error(MAE): 3.782\n",
"Mean squared error(MSE): 27.597\n",
"Root Mean Squard Error(RMSE): 5.253\n",
"R_squared: 0.989\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEb8AAAyVCAYAAAChBVopAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOzdfXzWdb0/8Pc1NlAQ5jQQMPGmNIyswJsIj+IUEMq74+nm9AM1EwSxTp1uPDm6P4KW3ZwsRSA1E6i2nFM8wMQJw1DUxLQkTfOGzE1T5xQEvcau3x/nnGrAxjV2Xdd3g+fz8dijx/fz/rw/n9dwYrBr7yuVyWQCAAAAAAAAAAAAAAAAAAAAAAAKqSjpAAAAAAAAAAAAAAAAAAAAAAAA7HkMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAAAAAoOAMvwEAAAAAAAAAAAAAAAAAgAJKpVKppDMAAEB3YPgNAAAAAAAAAAAAAAAAAAAU1uSkAwAAQHeQymQySWcAAAAAAAAAAAAAAAAAAIA9QiqVKomI5yPiqEwm05h0HgAASFJR0gEAAAAAAAAAAAAAAAAAAGAPckpEvC0izk46CAAAJM3wGwAAAAAAAAAAAAAAAAAAKJyP/e//fjTRFAAA0A2kMplM0hkAANgFqVRqSEScts3yUxHxRgJxAAAAAAAAAAAAAACA7q9vRBy2zdrtmUymIYkwe6JUKlUSES9ERFlEtEbEgZlMpjHZVAAAkJzipAMAALDLTouI+UmHAAAAAAAAAAAAAAAAerQLI2JB0iH2IKfE/wy+iYgoioizI+Ka5OIAAECyipIOAAAAAAAAAAAAAAAAAAAAe4iPbfP80URSAABAN1GcdAAAAAAAAAAAAAAAAAAAAMihTNIBdiSdTkdZWVk0NTX9ba2oqOikxsbGzODBgxNM1qFU0gEAANi9FSUdAAAAAAAAAAAAAAAAAAAAdnd1dXVtBt9ERLS2tkZ1dXVCiQAAIHnFSQcAAGCXPbXtwrx58+Koo45KIgsAAAAAAAAAAAAAAJCwlpaW+O53vxu33XZbZ9q2+/kE8qOysnKH61VVVTFz5swCpwEAgO4hlclkks4AAMAuSKVSH4yIe/5x7Z577okPfvCDCSUCAAAAAAAAAAAAAACS8vrrr8fHP/7xWLZsWWdbx2QymXvzkSlB3e6HZ9
PpdBxwwAHR1NS0Xa2oqCj+8pe/xODBgxNItlOppAMAALB7K0o6AAAAAAAAAAAAAAAAAAAAu66hoSHGjh27K4NvKJC6urodDr6JiGhtbY3q6uoCJwIAgO7B8BsAAAAAAAAAAAAAAAAAgB7q0UcfjdGjR8dDDz2UdBQ6UFlZ2WG9qqqqQEkAAKB7MfwGAAAAAAAAAAAAAAAAAKAHWrlyZRx//PGxYcOGpKPQgXQ6HTU1NR3uWb16dTQ2NhYmEAAAdCOG3wAAAAAAAAAAAAAAAAAA9DCLFi2KU089NZqbm7PaX1JSEl//+tfznIodqauri6ampjZro2Jgm+fW1taorq4uZCwAAOgWDL8BAAAAAAAAAAAAAAAAAOghMplMzJkzJ6ZMmRLpdDqrntLS0qitrY1TTz01z+nYkcrKyjbPqYiYHEdEvyhus15VVVXAVAAA0D0YfgMAAAAAAAAAAAAAAAAA0AO0tLTE9OnTY9asWVn3DBs2LNasWRPl5eV5TEZ70ul01NTUtFkbHmVRluoTo2Jgm/XVq1dHY2NjAdMBAEDyDL8BAAAAAAAAAAAAAAAAAOjmXn/99TjjjDNiwYIFWfeMHDky1q5dGyNGjMhjMjpSV1cXTU1NbdaOiUEREXHs//7v/2ltbY3q6uqCZQMAgO7A8BsAAAAAAAAAAAAAAAAAgG6soaEhxo4dG8uWLcu6Z9KkSbF69eoYMmRIHpOxM5WVlW2eUxFxdAyMiIjhURb9orhNvaqqqlDRAACgWzD8BgAAAAAAAAAAAAAAAACgm3r00Udj9OjR8dBDD2XdM23atLjttttin332yWMydiadTkdNTU2bteFRFgNSvSMiojhVFKP+dxDO/1m9enU0NjYWKiIAACTO8BsAAAAAAAAAAAAAAAAAgG5o5cqVcfzxx8eGDRuy7pkzZ07MmzcviouL85iMbNTV1UVTU1ObtWNiUJvnY7d5bm1tjerq6rxnAwCA7sLwGwAAAAAAAAAAAAAAAACAbmbRokVx6qmnRnNzc1b7S0pKYuHChXHppZdGKpXKczqyUVlZ2eY5FRFHx8A2a8OjLPpF20FFVVVV+Y4GAADdhuE3AAAAAAAAAAAAAAAAAADdRCaTiTlz5sSUKVMinU5n1VNaWhq1tbUxefLkPKcjW+l0OmpqatqsDY+yGJDq3WatOFUUo7YZiLN69epobGzMd0QAAOgWDL8BAAAAAAAAAAAAAAAAAOgGWlpaYvr06TFr1qyse4YNGxZr1qyJ8vLyPCajs+rq6qKpqanN2jExaId7j91mvbW1Naqrq/OWDQAAuhPDbwAAAAAAAAAAAAAAAAAAEvb666/HGWecEQsWLMi6Z+TIkbF27doYMWJEHpOxKyorK9s8pyLi6Bi4w73Doyz6RXGbtaqqqnxFAwCAbsXwGwAAAAAAAAAAAAAAAACABDU0NMTYsWNj2bJlWfdMmjQpVq9eHUOGDMljMnZFOp2OmpqaNmvDoywGpHrvcH9xqihGbTMYZ/Xq1dHY2JiviAAA0G0YfgMAAAAAAAAAAAAAAAAAkJBHH300Ro8eHQ899FDWPdOmTYvbbrst9tlnnzwmY1fV1dVFU1NTm7VjYlCHPcduU29tbY3q6uqcZwMAgO7G8BsAAAAAAAAAAAAAAAAAgASsXLkyjj/++NiwYUPWPXPmzIl58+ZFcXFxHpPRFZWVlW2eUxFxdAzssGd4lEW/aPvPtKqqKtfRAACg2zH8BgAAAAAAAAAAAAAAAACgwBYtWhSnnnpqNDc3Z7W/pKQkFi5cGJdeemmkUqk8p2NXpdPpqKmpabM2PMpiQKp3h33FqaIYtc2AnNWrV0djY2OuIwIAQLdi+A0AAAAAAAAAAAAAAAAAQIFkMpmYM2dOTJkyJdLpdFY9paWlUVtbG5MnT85zOrqqrq4umpqa2qwdE4Oy6j12m32tra1RXV2ds2wAANAdGX4DAAAAAAAAAAAAAAAAAFAALS0tMX369Jg1a1bWPcOGDY
s1a9ZEeXl5HpORK5WVlW2eUxFxdAzMqnd4lEW/KG6zVlVVlatoAADQLRl+AwAAAAAAAAAAAAAAAACQZ6+//nqcccY
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"scatter_out_1(real_rst['O3'].values, pred_rst['O3'].values, label=r'$O_3 \\ (\\mu g/m^3)$', name='O3')"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"def scatter_out_2(x, y, name, lim=6.0, save_path='./figure/lookback/CO.png'):\n",
"    \"\"\"Print regression metrics and plot measured vs. retrieved values coloured by point density.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : array-like\n",
"        Measured (reference) values.\n",
"    y : array-like\n",
"        Retrieved (predicted) values; same length as ``x``.\n",
"    name : str\n",
"        Quantity label used in the axis titles (matplotlib mathtext allowed).\n",
"    lim : float, default 6.0\n",
"        Upper limit of both axes; the 1:1 line spans [0, lim]. The default keeps\n",
"        the range that was previously hard-coded for CO.\n",
"    save_path : str or None, default './figure/lookback/CO.png'\n",
"        File the figure is saved to (parent directories are created if missing).\n",
"        Pass None to skip saving.\n",
"    \"\"\"\n",
"    # Plain float arrays: the positional re-indexing below would misbehave on a pandas Series.\n",
"    x = np.asarray(x, dtype=float)\n",
"    y = np.asarray(y, dtype=float)\n",
"\n",
"    # ========== Evaluation metrics (x = reference, y = prediction) ==========\n",
"    MSE = mean_squared_error(x, y)\n",
"    RMSE = np.sqrt(MSE)\n",
"    R2 = r2_score(x, y)\n",
"    MAE = mean_absolute_error(x, y)\n",
"    EV = explained_variance_score(x, y)\n",
"    print('==========算法评价指标==========')\n",
"    print('Explained Variance(EV):', '%.3f' % (EV))\n",
"    print('Mean Absolute Error(MAE):', '%.3f' % (MAE))\n",
"    print('Mean squared error(MSE):', '%.3f' % (MSE))\n",
"    print('Root Mean Squard Error(RMSE):', '%.3f' % (RMSE))\n",
"    print('R_squared:', '%.3f' % (R2))\n",
"\n",
"    # ========== Point density via a Gaussian KDE ==========\n",
"    xy = np.vstack([x, y])\n",
"    z = stats.gaussian_kde(xy)(xy)\n",
"    # Sort by density so the densest points are drawn last (on top).\n",
"    idx = z.argsort()\n",
"    x, y, z = x[idx], y[idx], z[idx]\n",
"\n",
"    # Ordinary least-squares fit y = m * x + b.\n",
"    x_mean, y_mean = np.mean(x), np.mean(y)\n",
"    m = (x_mean * y_mean - np.mean(x * y)) / (x_mean * x_mean - np.mean(x * x))\n",
"    b = y_mean - m * x_mean\n",
"\n",
"    fig, ax = plt.subplots(figsize=(12, 9), dpi=400)\n",
"    # Colour is the KDE density scaled by 100 (labelled 'frequency' on the colour bar).\n",
"    scatter = ax.scatter(x, y, marker='o', c=z * 100, s=15, cmap='Spectral_r')\n",
"    fig.colorbar(scatter, ax=ax, shrink=1, orientation='vertical', extend='both',\n",
"                 pad=0.015, aspect=30, label='frequency')\n",
"\n",
"    ax.plot([0, lim], [0, lim], 'black', lw=1.5)  # 1:1 reference line\n",
"    x_fit = np.array([x.min(), x.max()])\n",
"    ax.plot(x_fit, m * x_fit + b, 'red', lw=1.5)  # regression line of retrieved on measured\n",
"    ax.set_xlim(0, lim)\n",
"    ax.set_ylim(0, lim)\n",
"    ax.set_xlabel(f'Measured {name}')\n",
"    ax.set_ylabel(f'Retrieved {name}')\n",
"\n",
"    # Annotations in axes-fraction coordinates so they stay top-left for any lim\n",
"    # (equivalent to the old data positions x=0.3, y=5.5/5.0/4.5 when lim = 6).\n",
"    ax.text(0.05, 5.5 / 6, '$N=%.f$' % len(y), transform=ax.transAxes)\n",
"    ax.text(0.05, 5.0 / 6, '$R^2=%.2f$' % R2, transform=ax.transAxes)\n",
"    ax.text(0.05, 4.5 / 6, '$RMSE=%.2f$' % RMSE, transform=ax.transAxes)\n",
"\n",
"    if save_path:\n",
"        save_dir = os.path.dirname(save_path)\n",
"        if save_dir:\n",
"            os.makedirs(save_dir, exist_ok=True)\n",
"        fig.savefig(save_path, dpi=800, bbox_inches='tight', pad_inches=0)\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========算法评价指标==========\n",
"Explained Variance(EV): 0.970\n",
"Mean Absolute Error(MAE): 0.090\n",
"Mean squared error(MSE): 0.016\n",
"Root Mean Squard Error(RMSE): 0.126\n",
"R_squared: 0.968\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEKQAAAzZCAYAAADHXzXhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAD2EAAA9hAHVrK90AAEAAElEQVR4nOz9e5yXdZ0//j8ucFGcScmKAgXTDrspFuuiu0Nr5qBhtfbzW+AeUCrpaAcNt7a20JVt+1SWlG6pKYWdG6TFLUCztEIJSqcjU5EyaqcVRbN0lHS8fn/oto4M8Ia53nPNDPf77fa+tTyv6/W8Hm9iJxHejynKsgwAAAAAAAAAAAAAAAAAAAAA/K9RdQcAAAAAAAAAAAAAAAAAAAAAYGhRSAEAAAAAAAAAAAAAAAAAAABAHwopAAAAAAAAAAAAAAAAAAAAAOhDIQUAAAAAAAAAAAAAAAAAAAAAfSikAAAAAAAAAAAAAAAAAAAAAKAPhRQAAAAAAAAAAAAAAAAAAAAA9KGQAgAAAAAAAAAAAAAAAAAAAIA+FFIAAAAAAAAAAAAAAAAAAAAA0IdCCgAAAAAAAAAAAAAAAAAAAAD6UEgBAAAAAAAAAAAAAAAAAAAAQB8KKQAAAAAAAAAAAAAAAAAAAADoQyEFAAAAAAAAAAAAAAAAAAAAAH0opAAAAAAAAAAAAAAAAAAAAACgD4UUAAAAAAAAAAAAAAAAAAAAAPShkAIAAAAAAAAAAAAAAAAAAACAPhRSAAAAAAAAAAAAAAAAAAAAANCHQgoAAAAAAAAAAAAAAAAAAAAA+lBIAQAAAAAAAAAAAAAAAAAAAEAfCikAAAAAAAAAAAAAAAAAAAAA6EMhBQAAAAAAAAAAAAAAAAAAAAB9KKQAAAAAAAAAAAAAAAAAAAAAoA+FFAAAAAAAAAAAAAAAAAAAAAD0oZACAAAAAAAAAAAAAAAAAAAAgD4UUgAAAAAAAAAAAAAAAAAAAADQh0IKAAAAAAAAAAAAAAAAAAAAAPpQSAEAAAAAAAAAAAAAAAAAAABAHwopAAAAAAAAAAAAAAAAAAAAAOhDIQUAAAAAAAAAAAAAAAAAAAAAfSikAAAAAAAAAAAAAAAAAAAAAKAPhRQAAAAAAAAAAAAAAAAAAAAA9KGQAgAAAAAAAAAAAAAAAAAAAIA+FFIAAAAAAAAAAAAAAAAAAAAA0IdCCgAAAAAAAAAAAAAAAAAAAAD6UEgBAAAAAAAAAAAAAAAAAAAAQB8KKQAAAAAAAAAAAAAAAAAAAADoQyEFAAAAAAAAAAAAAAAAAAAAAH0opAAAAAAAAAAAAAAAAAAAAACgD4UUAAAAAAAAAAAAAAAAAAAAAPShkAIAAAAAAAAAAAAAAAAAAACAPhRSAAAAAAAAAAAAAAAAAAAAANCHQgoAAAAAAAAAAAAAAAAAAAAA+lBIAQAAAAAAAAAAAAAAAAAAAEAfCikAAAAAAAAAAAAAAAAAAAAA6GOPugMADFRRFC1JpiQ5JMlBSZ6eZGKSJyXZL8k+ScYk2TNJkWTLo6/7k9yVZPOjr18lueXR14YkPy/L8sFBeyMAAAAAAAAAAAAAAAAAAABDRFGWZd0ZAHZKURQHJzkmyVFJpid5Zh4pmqjag3mkmOJHSdYlWZvk+2VZ/rEJzwIAAAAAAAAAAAAAAAAAABgyFFIAw0JRFNOSnJTk75I8p8YoW5J8J8k3Hn19ryzLh2rMAwAAAAAAAAAAAAAAAAAAUDmFFMCQVRTFk5O8Jskrk/xFzXG25XdJViX57yRXlmX5u1rTAAAAAAAAAAAAAAAAAAAAVEAhBTDkFEXx50nenmROkr1qjrMzHkzy9SRfSvJfZVn+vuY8AAAAAAAAAAAAAAAAAAAAu0QhBTBkFEXxjCRnJ/mnJKNrjjNQW5KcWpbl5+sOAgAAAAAAAAAAAAAAAAAAsLP2qDsAQFEUrUnek+RtScbUHKcqeyaZWHcIAAAAAAAAAAAAAA
AAAACAXaGQAqhVURR/l+TiKG8AAAAAAAAAAAAAAAAAAAAYMkbVHQDYPRVFsU9RFJcl+UqUUQAAAAAAAAAAAAAAAAAAAAwpe9QdANj9FEUxNcnSJM+sOQoAAAAAAAAAAAAAAAAAAAD9GFV3AGD3UhTFKUm+E2UUAAAAAAAAAAAAAAAAAAAAQ5ZCCmDQFEVxTpJPJ9mr7iwAAAAAAAAAAAAAAAAAAABs2x51BwBGvqIoRiW5NMmr684CAAAAAAAAAAAAAAAAAADAjimkAJqqKIo9knw2yd/XnQUAAAAAAAAAAAAAAAAAAIDGjKo7ADByFUUxKsooAAAAAAAAAAAAAAAAAAAAhp096g4AjGgXZnDLKP6Y5IYkP01y86OvjUnuSXLfo697k4xOstejr32SPC3JU5NMTPLsJH+R5M+TTBrE7AAAAAAAAAAAAAAAAAAAAEOGQgqgKYqi+LckrxuER92Y5KtJvpVkbVmW9zdwpjePlFf8PsmmJDf1d1NRFE9J8tdJjkzyt0men2RMBZkBAAAAAAAAAAAAAAAAAACGNIUUQOWKovj7JGc38RG/SfK5JJeVZbm+WQ8py/KOPFJ28dUkKYpi7yQvTDIzyf+XZFKzng0AAAAAAAAAAAAAAAAAAFCnoizLujMAI0hRFM9LsjbJXk1Y/9sk70vyibIs/9iE/Q0riqJI8tdJZiX5pyQT+rnt7WVZfmhQgwEAAAAAAAAAAAAAAAAAAFRgVN0BgJGjKIqxSb6Q6sso7kvy9iQHl2X5n3WXUSRJ+Yi1ZVn+c5JJSV6W5IokD9WbDAAAAAAAAAAAAAAAAAAAYOAUUgBV+nCS51S8c02SqWVZfqgsywcq3l2Jsix7y7L8SlmWJyY5OMm5SX5XaygAAAAAAAAAAAAAAAAAAIABKMqyrDsDMAIURXFCkv+ucOXDSRYkeX9Zlg9XuHdQFEXRmuSpZVneXHcWAAAAAAAAAAAAAAAAAACAnaWQAhiwoiieluTHSZ5c0coHkpxSluXlFe0DAAAAAAAAAAAAAAAAAABgJ+xRdwBgRPhYqiujuDvJy8qyvK6ifQAAAAAAAAAAAAAAAAAAAOykoizLujMAw1hRFC9Mcm1F63qStJdlua6ifQAAAAAAAAAAAAAAAAAAAOyCUXUHAIavoihGJflIRet6k/yDMgoAAAAAAAAAAAAAAAAAAID6KaQABmJekudVtOutZVl+paJdAAAAAAAAAAAAAAAAAAAADEBRlmXdGYBhqCiKfZL8Isn4CtZ9uSzLV1SwBwAAAAAAAAAAAAAAAAAAgAqMqjsAMGz9c6opo/htktdVsAcAAAAAAAAAAAAAAAAAAICKFGVZ1p0BGGaKotg7yW1JnlTBuheXZXllBXsAAAAAAAAAAAAAAAAAAACoyKi6AwDD0qtTTRnFFcooAAAAAAAAAAAAAAAAAAAAhp6iLMu6MwDDSFEUo5L8PMkzB7jqoSSHlmW5YeCpAAAAAAAAAAAAAAAAAAAAqNKougMAw86JGXgZRZJcpIwCAAAAAAAAAAAAAAAAAABgaCrKsqw7AzCMFEVxfZLpA1yzJcmBZVneXkEkAAAAAAAAAAAAAAAAAAAAKjaq7gDA8FEUxZ9n4GUUSfIFZRQAAAAAAAAAAAAAAAAAAABDl0IKYGf8Y0V7PlLRHgAAAAAAAAAAAAAAAAAAAJpAIQWwM/6+gh3fLMvyhxXsAQAAAAAAAAAAAAAAAAAAoEkUUgANKYpiapK/qGDVJRXsAAAAAAAAAAAAAAAAAAAAoIkUUgCN+ocKdjyQ5CsV7AEAAAAAAAAAAAAAAAAAAKCJFFIAjTqpgh0ry7L8QwV7AAAAAAAAAAAAAAAAAAAAaCKFFMAOFUXx50kOqmDVlyrYAQAAAAAAAAAAAAAAAAAAQJMppAAacUwFOx5MsqKCPQAAAAAAAAAAAAAAAAAAADSZQgqgES+sYMfasizvq2APAAAAAAAAAAAAAAAAAAAATaaQAmjECyvYcU0FOw
AAAAAAAAAAAAAAAAAAABgECimA7SqK4jlJnlrBqm9UsAMAAAAAAAAAAAAAAAAAAIBBoJAC2JEXVrCjJ8m6CvYAAAA
"text/plain": [
"<Figure size 4800x3600 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"scatter_out_2(y_true_co, y_pred_co, name=r'$CO \\ (mg/m^3)$')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py37",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "e35e91facd2b4cfa08991d112893a00c4d14d1c91c990d1b62f3056d14d2f283"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}