coal_materials/multi-task0102.ipynb

1622 lines
60 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "6b84fefd-5936-4da4-ab6b-5b944329ad1d",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_DEVICE_ORDER'] = 'PCB_BUS_ID'\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9cf130e3-62ef-46e0-bbdc-b13d9d29318d",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib.pyplot as plt\n",
"#新增加的两行\n",
"from pylab import mpl\n",
"# 设置显示中文字体\n",
"mpl.rcParams[\"font.sans-serif\"] = [\"SimHei\"]\n",
"\n",
"mpl.rcParams[\"axes.unicode_minus\"] = False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "752381a5-0aeb-4c54-bc48-f9c3f8fc5d17",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead tr th {\n",
" text-align: left;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr>\n",
" <th></th>\n",
" <th>Unnamed: 0_level_0</th>\n",
" <th>氢</th>\n",
" <th>碳</th>\n",
" <th>氮</th>\n",
" <th>氧</th>\n",
" <th>弹筒发热量</th>\n",
" <th>挥发分</th>\n",
" <th>固定炭</th>\n",
" </tr>\n",
" <tr>\n",
" <th></th>\n",
" <th>化验编号</th>\n",
" <th>Had</th>\n",
" <th>Cad</th>\n",
" <th>Nad</th>\n",
" <th>Oad</th>\n",
" <th>Qb,ad</th>\n",
" <th>Vad</th>\n",
" <th>Fcad</th>\n",
" </tr>\n",
" <tr>\n",
" <th></th>\n",
" <th>Unnamed: 0_level_2</th>\n",
" <th>(%)</th>\n",
" <th>(%)</th>\n",
" <th>(%)</th>\n",
" <th>(%)</th>\n",
" <th>MJ/kg</th>\n",
" <th>(%)</th>\n",
" <th>(%)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2720110529</td>\n",
" <td>3.93</td>\n",
" <td>70.18</td>\n",
" <td>0.81</td>\n",
" <td>25.079</td>\n",
" <td>27.820</td>\n",
" <td>32.06</td>\n",
" <td>55.68</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2720096883</td>\n",
" <td>3.78</td>\n",
" <td>68.93</td>\n",
" <td>0.77</td>\n",
" <td>26.512</td>\n",
" <td>27.404</td>\n",
" <td>29.96</td>\n",
" <td>54.71</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2720109084</td>\n",
" <td>3.48</td>\n",
" <td>69.60</td>\n",
" <td>0.76</td>\n",
" <td>26.148</td>\n",
" <td>27.578</td>\n",
" <td>29.31</td>\n",
" <td>55.99</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2720084708</td>\n",
" <td>3.47</td>\n",
" <td>66.71</td>\n",
" <td>0.76</td>\n",
" <td>29.055</td>\n",
" <td>26.338</td>\n",
" <td>28.58</td>\n",
" <td>53.87</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2720062721</td>\n",
" <td>3.87</td>\n",
" <td>68.78</td>\n",
" <td>0.80</td>\n",
" <td>26.542</td>\n",
" <td>27.280</td>\n",
" <td>29.97</td>\n",
" <td>54.78</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>223</th>\n",
" <td>2720030490</td>\n",
" <td>4.12</td>\n",
" <td>68.85</td>\n",
" <td>0.97</td>\n",
" <td>26.055</td>\n",
" <td>27.864</td>\n",
" <td>32.94</td>\n",
" <td>51.89</td>\n",
" </tr>\n",
" <tr>\n",
" <th>224</th>\n",
" <td>2720028633</td>\n",
" <td>3.97</td>\n",
" <td>67.04</td>\n",
" <td>0.94</td>\n",
" <td>28.043</td>\n",
" <td>27.368</td>\n",
" <td>31.88</td>\n",
" <td>51.38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>225</th>\n",
" <td>2720028634</td>\n",
" <td>4.12</td>\n",
" <td>68.42</td>\n",
" <td>0.96</td>\n",
" <td>26.493</td>\n",
" <td>27.886</td>\n",
" <td>33.16</td>\n",
" <td>52.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>226</th>\n",
" <td>2720017683</td>\n",
" <td>3.88</td>\n",
" <td>67.42</td>\n",
" <td>0.94</td>\n",
" <td>27.760</td>\n",
" <td>26.616</td>\n",
" <td>31.65</td>\n",
" <td>50.56</td>\n",
" </tr>\n",
" <tr>\n",
" <th>227</th>\n",
" <td>2720017678</td>\n",
" <td>3.81</td>\n",
" <td>66.74</td>\n",
" <td>0.92</td>\n",
" <td>28.530</td>\n",
" <td>26.688</td>\n",
" <td>31.02</td>\n",
" <td>50.82</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>228 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Unnamed: 0_level_0 氢 碳 氮 氧 弹筒发热量 挥发分 固定炭\n",
" 化验编号 Had Cad Nad Oad Qb,ad Vad Fcad\n",
" Unnamed: 0_level_2 (%) (%) (%) (%) MJ/kg (%) (%)\n",
"0 2720110529 3.93 70.18 0.81 25.079 27.820 32.06 55.68\n",
"1 2720096883 3.78 68.93 0.77 26.512 27.404 29.96 54.71\n",
"2 2720109084 3.48 69.60 0.76 26.148 27.578 29.31 55.99\n",
"3 2720084708 3.47 66.71 0.76 29.055 26.338 28.58 53.87\n",
"4 2720062721 3.87 68.78 0.80 26.542 27.280 29.97 54.78\n",
".. ... ... ... ... ... ... ... ...\n",
"223 2720030490 4.12 68.85 0.97 26.055 27.864 32.94 51.89\n",
"224 2720028633 3.97 67.04 0.94 28.043 27.368 31.88 51.38\n",
"225 2720028634 4.12 68.42 0.96 26.493 27.886 33.16 52.00\n",
"226 2720017683 3.88 67.42 0.94 27.760 26.616 31.65 50.56\n",
"227 2720017678 3.81 66.74 0.92 28.530 26.688 31.02 50.82\n",
"\n",
"[228 rows x 8 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_0102 = pd.read_excel('./data/20240102/20240102.xlsx', header=[0,1,2])\n",
"data_0102"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "972f1e9c-3ebc-45cf-8d1f-7611645e5238",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['化验编号',\n",
" '氢Had(%)',\n",
" '碳Cad(%)',\n",
" '氮Nad(%)',\n",
" '氧Oad(%)',\n",
" '弹筒发热量Qb,adMJ/kg',\n",
" '挥发分Vad(%)',\n",
" '固定炭Fcad(%)']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cols = [''.join([y for y in x if 'Unnamed' not in y]) for x in data_0102.columns]\n",
"cols"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c95f1106-b3a4-43c6-88ec-3cdebf91d79a",
"metadata": {},
"outputs": [],
"source": [
"data_0102.columns = cols"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2e96af0a-feda-4a1f-a13e-9c8861c6f4d4",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>化验编号</th>\n",
" <th>氢Had(%)</th>\n",
" <th>碳Cad(%)</th>\n",
" <th>氮Nad(%)</th>\n",
" <th>氧Oad(%)</th>\n",
" <th>弹筒发热量Qb,adMJ/kg</th>\n",
" <th>挥发分Vad(%)</th>\n",
" <th>固定炭Fcad(%)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2720110529</td>\n",
" <td>3.93</td>\n",
" <td>70.18</td>\n",
" <td>0.81</td>\n",
" <td>25.079</td>\n",
" <td>27.820</td>\n",
" <td>32.06</td>\n",
" <td>55.68</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2720096883</td>\n",
" <td>3.78</td>\n",
" <td>68.93</td>\n",
" <td>0.77</td>\n",
" <td>26.512</td>\n",
" <td>27.404</td>\n",
" <td>29.96</td>\n",
" <td>54.71</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2720109084</td>\n",
" <td>3.48</td>\n",
" <td>69.60</td>\n",
" <td>0.76</td>\n",
" <td>26.148</td>\n",
" <td>27.578</td>\n",
" <td>29.31</td>\n",
" <td>55.99</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2720084708</td>\n",
" <td>3.47</td>\n",
" <td>66.71</td>\n",
" <td>0.76</td>\n",
" <td>29.055</td>\n",
" <td>26.338</td>\n",
" <td>28.58</td>\n",
" <td>53.87</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2720062721</td>\n",
" <td>3.87</td>\n",
" <td>68.78</td>\n",
" <td>0.80</td>\n",
" <td>26.542</td>\n",
" <td>27.280</td>\n",
" <td>29.97</td>\n",
" <td>54.78</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>223</th>\n",
" <td>2720030490</td>\n",
" <td>4.12</td>\n",
" <td>68.85</td>\n",
" <td>0.97</td>\n",
" <td>26.055</td>\n",
" <td>27.864</td>\n",
" <td>32.94</td>\n",
" <td>51.89</td>\n",
" </tr>\n",
" <tr>\n",
" <th>224</th>\n",
" <td>2720028633</td>\n",
" <td>3.97</td>\n",
" <td>67.04</td>\n",
" <td>0.94</td>\n",
" <td>28.043</td>\n",
" <td>27.368</td>\n",
" <td>31.88</td>\n",
" <td>51.38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>225</th>\n",
" <td>2720028634</td>\n",
" <td>4.12</td>\n",
" <td>68.42</td>\n",
" <td>0.96</td>\n",
" <td>26.493</td>\n",
" <td>27.886</td>\n",
" <td>33.16</td>\n",
" <td>52.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>226</th>\n",
" <td>2720017683</td>\n",
" <td>3.88</td>\n",
" <td>67.42</td>\n",
" <td>0.94</td>\n",
" <td>27.760</td>\n",
" <td>26.616</td>\n",
" <td>31.65</td>\n",
" <td>50.56</td>\n",
" </tr>\n",
" <tr>\n",
" <th>227</th>\n",
" <td>2720017678</td>\n",
" <td>3.81</td>\n",
" <td>66.74</td>\n",
" <td>0.92</td>\n",
" <td>28.530</td>\n",
" <td>26.688</td>\n",
" <td>31.02</td>\n",
" <td>50.82</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>228 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" 化验编号 氢Had(%) 碳Cad(%) 氮Nad(%) 氧Oad(%) 弹筒发热量Qb,adMJ/kg \\\n",
"0 2720110529 3.93 70.18 0.81 25.079 27.820 \n",
"1 2720096883 3.78 68.93 0.77 26.512 27.404 \n",
"2 2720109084 3.48 69.60 0.76 26.148 27.578 \n",
"3 2720084708 3.47 66.71 0.76 29.055 26.338 \n",
"4 2720062721 3.87 68.78 0.80 26.542 27.280 \n",
".. ... ... ... ... ... ... \n",
"223 2720030490 4.12 68.85 0.97 26.055 27.864 \n",
"224 2720028633 3.97 67.04 0.94 28.043 27.368 \n",
"225 2720028634 4.12 68.42 0.96 26.493 27.886 \n",
"226 2720017683 3.88 67.42 0.94 27.760 26.616 \n",
"227 2720017678 3.81 66.74 0.92 28.530 26.688 \n",
"\n",
" 挥发分Vad(%) 固定炭Fcad(%) \n",
"0 32.06 55.68 \n",
"1 29.96 54.71 \n",
"2 29.31 55.99 \n",
"3 28.58 53.87 \n",
"4 29.97 54.78 \n",
".. ... ... \n",
"223 32.94 51.89 \n",
"224 31.88 51.38 \n",
"225 33.16 52.00 \n",
"226 31.65 50.56 \n",
"227 31.02 50.82 \n",
"\n",
"[228 rows x 8 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_0102"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "04b177a7-2f02-4e23-8ea9-29f34cf3eafc",
"metadata": {},
"outputs": [],
"source": [
"out_cols = ['挥发分Vad(%)', '固定炭Fcad(%)']"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "31169fbf-d78e-42f7-87f3-71ba3dd0979d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['挥发分Vad(%)', '固定炭Fcad(%)']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out_cols"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "feaedd50-f999-45bf-b465-3d359b0c0110",
"metadata": {},
"outputs": [],
"source": [
"data = data_0102.copy()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a40bee0f-011a-4edb-80f8-4e2f40e755fd",
"metadata": {},
"outputs": [],
"source": [
"train_data = data.dropna(subset=out_cols).fillna(0)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "535d37b6-b9de-4025-ac8f-62f5bdbe2451",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-01-04 16:49:03.492957: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import tensorflow.keras.backend as K"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c2318ce6-60d2-495c-91cd-67ca53609cf8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmp/ipykernel_45930/337460670.py:1: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.config.list_physical_devices('GPU')` instead.\n"
]
},
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-01-04 16:49:04.396035: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-01-04 16:49:04.407586: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1\n",
"2024-01-04 16:49:04.465739: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal\n",
"2024-01-04 16:49:04.465795: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: zhaojh-yv621\n",
"2024-01-04 16:49:04.465807: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: zhaojh-yv621\n",
"2024-01-04 16:49:04.466010: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 520.61.5\n",
"2024-01-04 16:49:04.466041: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 520.61.5\n",
"2024-01-04 16:49:04.466045: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 520.61.5\n"
]
}
],
"source": [
"tf.test.is_gpu_available()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "1c85d462-f248-4ffb-908f-eb4b20eab179",
"metadata": {},
"outputs": [],
"source": [
"class TransformerBlock(layers.Layer):\n",
" def __init__(self, embed_dim, num_heads, ff_dim, name, rate=0.1):\n",
" super().__init__()\n",
" self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, name=name)\n",
" self.ffn = keras.Sequential(\n",
" [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
" )\n",
" self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
" self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
" self.dropout1 = layers.Dropout(rate)\n",
" self.dropout2 = layers.Dropout(rate)\n",
"\n",
" def call(self, inputs, training):\n",
" attn_output = self.att(inputs, inputs)\n",
" attn_output = self.dropout1(attn_output, training=training)\n",
" out1 = self.layernorm1(inputs + attn_output)\n",
" ffn_output = self.ffn(out1)\n",
" ffn_output = self.dropout2(ffn_output, training=training)\n",
" return self.layernorm2(out1 + ffn_output)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "790284a3-b9d3-4144-b481-38a7c3ecb4b9",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import Model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cd9a1ca1-d0ca-4cb5-9ef5-fd5d63576cd2",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.initializers import Constant"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9bc02f29-0fb7-420d-99a8-435eadc06e29",
"metadata": {},
"outputs": [],
"source": [
"# Custom loss layer\n",
"class CustomMultiLossLayer(layers.Layer):\n",
" def __init__(self, nb_outputs=2, **kwargs):\n",
" self.nb_outputs = nb_outputs\n",
" self.is_placeholder = True\n",
" super(CustomMultiLossLayer, self).__init__(**kwargs)\n",
" \n",
" def build(self, input_shape=None):\n",
" # initialise log_vars\n",
" self.log_vars = []\n",
" for i in range(self.nb_outputs):\n",
" self.log_vars += [self.add_weight(name='log_var' + str(i), shape=(1,),\n",
" initializer=tf.initializers.he_normal(), trainable=True)]\n",
" super(CustomMultiLossLayer, self).build(input_shape)\n",
"\n",
" def multi_loss(self, ys_true, ys_pred):\n",
" assert len(ys_true) == self.nb_outputs and len(ys_pred) == self.nb_outputs\n",
" loss = 0\n",
" for y_true, y_pred, log_var in zip(ys_true, ys_pred, self.log_vars):\n",
" mse = (y_true - y_pred) ** 2.\n",
" pre = K.exp(-log_var[0])\n",
" loss += tf.abs(tf.reduce_logsumexp(pre * mse + log_var[0], axis=-1))\n",
" return K.mean(loss)\n",
"\n",
" def call(self, inputs):\n",
" ys_true = inputs[:self.nb_outputs]\n",
" ys_pred = inputs[self.nb_outputs:]\n",
" loss = self.multi_loss(ys_true, ys_pred)\n",
" self.add_loss(loss, inputs=inputs)\n",
" # We won't actually use the output.\n",
" return K.concatenate(inputs, -1)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "a190207e-5a59-4813-9660-758760cf1b73",
"metadata": {},
"outputs": [],
"source": [
"num_heads, ff_dim = 3, 16"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "80f32155-e71f-4615-8d0c-01dfd04988fe",
"metadata": {},
"outputs": [],
"source": [
"def get_prediction_model():\n",
" def build_output(out, out_name):\n",
" self_block = TransformerBlock(64, num_heads, ff_dim, name=f'{out_name}_attn')\n",
" out = self_block(out)\n",
" out = layers.GlobalAveragePooling1D()(out)\n",
" out = layers.Dropout(0.1)(out)\n",
" out = layers.Dense(32, activation=\"relu\")(out)\n",
" # out = layers.Dense(1, name=out_name, activation=\"sigmoid\")(out)\n",
" return out\n",
" inputs = layers.Input(shape=(1,len(feature_cols)), name='input')\n",
" x = layers.Conv1D(filters=64, kernel_size=1, activation='relu')(inputs)\n",
" # x = layers.Dropout(rate=0.1)(x)\n",
" lstm_out = layers.Bidirectional(layers.LSTM(units=64, return_sequences=True))(x)\n",
" lstm_out = layers.Dense(128, activation='relu')(lstm_out)\n",
" transformer_block = TransformerBlock(128, num_heads, ff_dim, name='first_attn')\n",
" out = transformer_block(lstm_out)\n",
" out = layers.GlobalAveragePooling1D()(out)\n",
" out = layers.Dropout(0.1)(out)\n",
" out = layers.Dense(64, activation='relu')(out)\n",
" out = K.expand_dims(out, axis=1)\n",
"\n",
" bet = build_output(out, 'vad')\n",
" mesco = build_output(out, 'fcad')\n",
"\n",
" bet = layers.Dense(1, activation='sigmoid', name='vad')(bet)\n",
" mesco = layers.Dense(1, activation='sigmoid', name='fcad')(mesco)\n",
"\n",
" model = Model(inputs=[inputs], outputs=[bet, mesco])\n",
" return model\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "264001b1-5e4a-4786-96fd-2b5c70ab3212",
"metadata": {},
"outputs": [],
"source": [
"def get_trainable_model(prediction_model):\n",
" inputs = layers.Input(shape=(1,len(feature_cols)), name='input')\n",
" bet, mesco = prediction_model(inputs)\n",
" bet_real = layers.Input(shape=(1,), name='vad_real')\n",
" mesco_real = layers.Input(shape=(1,), name='fcad_real')\n",
" out = CustomMultiLossLayer(nb_outputs=2)([bet_real, mesco_real, bet, mesco])\n",
" return Model([inputs, bet_real, mesco_real], out)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "1eebdab3-1f88-48a1-b5e0-bc8787528c1b",
"metadata": {},
"outputs": [],
"source": [
"maxs = train_data.max()\n",
"mins = train_data.min()\n",
"for col in train_data.columns:\n",
" if maxs[col] - mins[col] == 0:\n",
" continue\n",
" train_data[col] = (train_data[col] - mins[col]) / (maxs[col] - mins[col])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7f27bd56-4f6b-4242-9f79-c7d6b3ee2f13",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>化验编号</th>\n",
" <th>氢Had(%)</th>\n",
" <th>碳Cad(%)</th>\n",
" <th>氮Nad(%)</th>\n",
" <th>氧Oad(%)</th>\n",
" <th>弹筒发热量Qb,adMJ/kg</th>\n",
" <th>挥发分Vad(%)</th>\n",
" <th>固定炭Fcad(%)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.996547</td>\n",
" <td>0.773973</td>\n",
" <td>0.835414</td>\n",
" <td>0.456522</td>\n",
" <td>0.171463</td>\n",
" <td>0.811249</td>\n",
" <td>0.847737</td>\n",
" <td>0.828147</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.851118</td>\n",
" <td>0.671233</td>\n",
" <td>0.799943</td>\n",
" <td>0.369565</td>\n",
" <td>0.210254</td>\n",
" <td>0.782038</td>\n",
" <td>0.674897</td>\n",
" <td>0.794606</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.981147</td>\n",
" <td>0.465753</td>\n",
" <td>0.818956</td>\n",
" <td>0.347826</td>\n",
" <td>0.200401</td>\n",
" <td>0.794256</td>\n",
" <td>0.621399</td>\n",
" <td>0.838866</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.721367</td>\n",
" <td>0.458904</td>\n",
" <td>0.736947</td>\n",
" <td>0.347826</td>\n",
" <td>0.279094</td>\n",
" <td>0.707183</td>\n",
" <td>0.561317</td>\n",
" <td>0.765560</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.487046</td>\n",
" <td>0.732877</td>\n",
" <td>0.795687</td>\n",
" <td>0.434783</td>\n",
" <td>0.211066</td>\n",
" <td>0.773331</td>\n",
" <td>0.675720</td>\n",
" <td>0.797026</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>223</th>\n",
" <td>0.143553</td>\n",
" <td>0.904110</td>\n",
" <td>0.797673</td>\n",
" <td>0.804348</td>\n",
" <td>0.197883</td>\n",
" <td>0.814339</td>\n",
" <td>0.920165</td>\n",
" <td>0.697095</td>\n",
" </tr>\n",
" <tr>\n",
" <th>224</th>\n",
" <td>0.123762</td>\n",
" <td>0.801370</td>\n",
" <td>0.746311</td>\n",
" <td>0.739130</td>\n",
" <td>0.251699</td>\n",
" <td>0.779510</td>\n",
" <td>0.832922</td>\n",
" <td>0.679461</td>\n",
" </tr>\n",
" <tr>\n",
" <th>225</th>\n",
" <td>0.123773</td>\n",
" <td>0.904110</td>\n",
" <td>0.785471</td>\n",
" <td>0.782609</td>\n",
" <td>0.209740</td>\n",
" <td>0.815884</td>\n",
" <td>0.938272</td>\n",
" <td>0.700899</td>\n",
" </tr>\n",
" <tr>\n",
" <th>226</th>\n",
" <td>0.007066</td>\n",
" <td>0.739726</td>\n",
" <td>0.757094</td>\n",
" <td>0.739130</td>\n",
" <td>0.244038</td>\n",
" <td>0.726705</td>\n",
" <td>0.813992</td>\n",
" <td>0.651107</td>\n",
" </tr>\n",
" <tr>\n",
" <th>227</th>\n",
" <td>0.007012</td>\n",
" <td>0.691781</td>\n",
" <td>0.737798</td>\n",
" <td>0.695652</td>\n",
" <td>0.264882</td>\n",
" <td>0.731760</td>\n",
" <td>0.762140</td>\n",
" <td>0.660097</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>228 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" 化验编号 氢Had(%) 碳Cad(%) 氮Nad(%) 氧Oad(%) 弹筒发热量Qb,adMJ/kg \\\n",
"0 0.996547 0.773973 0.835414 0.456522 0.171463 0.811249 \n",
"1 0.851118 0.671233 0.799943 0.369565 0.210254 0.782038 \n",
"2 0.981147 0.465753 0.818956 0.347826 0.200401 0.794256 \n",
"3 0.721367 0.458904 0.736947 0.347826 0.279094 0.707183 \n",
"4 0.487046 0.732877 0.795687 0.434783 0.211066 0.773331 \n",
".. ... ... ... ... ... ... \n",
"223 0.143553 0.904110 0.797673 0.804348 0.197883 0.814339 \n",
"224 0.123762 0.801370 0.746311 0.739130 0.251699 0.779510 \n",
"225 0.123773 0.904110 0.785471 0.782609 0.209740 0.815884 \n",
"226 0.007066 0.739726 0.757094 0.739130 0.244038 0.726705 \n",
"227 0.007012 0.691781 0.737798 0.695652 0.264882 0.731760 \n",
"\n",
" 挥发分Vad(%) 固定炭Fcad(%) \n",
"0 0.847737 0.828147 \n",
"1 0.674897 0.794606 \n",
"2 0.621399 0.838866 \n",
"3 0.561317 0.765560 \n",
"4 0.675720 0.797026 \n",
".. ... ... \n",
"223 0.920165 0.697095 \n",
"224 0.832922 0.679461 \n",
"225 0.938272 0.700899 \n",
"226 0.813992 0.651107 \n",
"227 0.762140 0.660097 \n",
"\n",
"[228 rows x 8 columns]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "baf45a3d-dc01-44fc-9f0b-456964ac2cdb",
"metadata": {},
"outputs": [],
"source": [
"# feature_cols = [x for x in train_data.columns if x not in out_cols and '第二次' not in x]\n",
"feature_cols = [x for x in train_data.columns if x not in out_cols]\n",
"use_cols = feature_cols + out_cols"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f2d27538-d2bc-4202-b0cf-d3e0949b4686",
"metadata": {},
"outputs": [],
"source": [
"use_data = train_data.copy()\n",
"for col in use_cols:\n",
" use_data[col] = use_data[col].astype('float32')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "54c1df2c-c297-4b8d-be8a-3a99cff22545",
"metadata": {},
"outputs": [],
"source": [
"train, valid = train_test_split(use_data[use_cols], test_size=0.3, random_state=42, shuffle=True)\n",
"valid, test = train_test_split(valid, test_size=0.3, random_state=42, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "e7a914da-b9c2-40d9-96e0-459b0888adba",
"metadata": {},
"outputs": [],
"source": [
"prediction_model = get_prediction_model()\n",
"trainable_model = get_trainable_model(prediction_model)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "4f832a1e-48e2-4467-b381-35b9d2f1271a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model_3\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input (InputLayer) [(None, 1, 6)] 0 \n",
"__________________________________________________________________________________________________\n",
"conv1d_3 (Conv1D) (None, 1, 64) 448 input[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_3 (Bidirectional) (None, 1, 128) 66048 conv1d_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_30 (Dense) (None, 1, 128) 16512 bidirectional_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_9 (Transforme (None, 1, 128) 202640 dense_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_9 (Glo (None, 128) 0 transformer_block_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_29 (Dropout) (None, 128) 0 global_average_pooling1d_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_33 (Dense) (None, 64) 8256 dropout_29[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf.expand_dims_3 (TFOpLambda) (None, 1, 64) 0 dense_33[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_10 (Transform (None, 1, 64) 52176 tf.expand_dims_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"transformer_block_11 (Transform (None, 1, 64) 52176 tf.expand_dims_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_10 (Gl (None, 64) 0 transformer_block_10[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_average_pooling1d_11 (Gl (None, 64) 0 transformer_block_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_32 (Dropout) (None, 64) 0 global_average_pooling1d_10[0][0]\n",
"__________________________________________________________________________________________________\n",
"dropout_35 (Dropout) (None, 64) 0 global_average_pooling1d_11[0][0]\n",
"__________________________________________________________________________________________________\n",
"dense_36 (Dense) (None, 32) 2080 dropout_32[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_39 (Dense) (None, 32) 2080 dropout_35[0][0] \n",
"__________________________________________________________________________________________________\n",
"vad (Dense) (None, 1) 33 dense_36[0][0] \n",
"__________________________________________________________________________________________________\n",
"fcad (Dense) (None, 1) 33 dense_39[0][0] \n",
"==================================================================================================\n",
"Total params: 402,482\n",
"Trainable params: 402,482\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"prediction_model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "9289f452-a5a4-40c4-b942-f6cb2e348548",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras import optimizers\n",
"from tensorflow.python.keras.utils.vis_utils import plot_model"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "2494ef5a-5b2b-4f11-b6cd-dc39503c9106",
"metadata": {},
"outputs": [],
"source": [
"X = np.expand_dims(train[feature_cols].values, axis=1)\n",
"Y = [x for x in train[out_cols].values.T]\n",
"Y_valid = [x for x in valid[out_cols].values.T]"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "9a62dea1-4f05-411b-9756-a91623580581",
"metadata": {},
"outputs": [],
"source": [
"from keras.callbacks import ReduceLROnPlateau\n",
"reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=10, mode='auto')"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "cf869e4d-0fce-45a2-afff-46fd9b30fd1c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/120\n",
"20/20 [==============================] - 5s 59ms/step - loss: 1.8316 - val_loss: 1.8096\n",
"Epoch 2/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 1.7903 - val_loss: 1.7691\n",
"Epoch 3/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 1.7506 - val_loss: 1.7307\n",
"Epoch 4/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 1.7110 - val_loss: 1.6914\n",
"Epoch 5/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 1.6711 - val_loss: 1.6497\n",
"Epoch 6/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 1.6314 - val_loss: 1.6098\n",
"Epoch 7/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 1.5909 - val_loss: 1.5695\n",
"Epoch 8/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 1.5506 - val_loss: 1.5296\n",
"Epoch 9/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 1.5109 - val_loss: 1.4891\n",
"Epoch 10/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 1.4706 - val_loss: 1.4500\n",
"Epoch 11/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 1.4306 - val_loss: 1.4104\n",
"Epoch 12/120\n",
"20/20 [==============================] - 0s 22ms/step - loss: 1.3907 - val_loss: 1.3746\n",
"Epoch 13/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 1.3508 - val_loss: 1.3296\n",
"Epoch 14/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 1.3106 - val_loss: 1.2895\n",
"Epoch 15/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 1.2706 - val_loss: 1.2515\n",
"Epoch 16/120\n",
"20/20 [==============================] - 0s 22ms/step - loss: 1.2315 - val_loss: 1.2104\n",
"Epoch 17/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 1.1908 - val_loss: 1.1702\n",
"Epoch 18/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 1.1508 - val_loss: 1.1320\n",
"Epoch 19/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 1.1114 - val_loss: 1.0917\n",
"Epoch 20/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 1.0718 - val_loss: 1.0513\n",
"Epoch 21/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 1.0315 - val_loss: 1.0178\n",
"Epoch 22/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 0.9918 - val_loss: 0.9704\n",
"Epoch 23/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.9511 - val_loss: 0.9321\n",
"Epoch 24/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.9114 - val_loss: 0.8913\n",
"Epoch 25/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.8718 - val_loss: 0.8520\n",
"Epoch 26/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.8314 - val_loss: 0.8124\n",
"Epoch 27/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.7922 - val_loss: 0.7727\n",
"Epoch 28/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.7519 - val_loss: 0.7307\n",
"Epoch 29/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.7119 - val_loss: 0.6932\n",
"Epoch 30/120\n",
"20/20 [==============================] - 0s 22ms/step - loss: 0.6720 - val_loss: 0.6531\n",
"Epoch 31/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.6336 - val_loss: 0.6155\n",
"Epoch 32/120\n",
"20/20 [==============================] - 1s 26ms/step - loss: 0.5931 - val_loss: 0.5738\n",
"Epoch 33/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.5517 - val_loss: 0.5324\n",
"Epoch 34/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.5135 - val_loss: 0.4943\n",
"Epoch 35/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.4724 - val_loss: 0.4602\n",
"Epoch 36/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.4326 - val_loss: 0.4126\n",
"Epoch 37/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.3947 - val_loss: 0.3758\n",
"Epoch 38/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.3558 - val_loss: 0.3350\n",
"Epoch 39/120\n",
"20/20 [==============================] - 0s 22ms/step - loss: 0.3154 - val_loss: 0.3031\n",
"Epoch 40/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.2771 - val_loss: 0.2592\n",
"Epoch 41/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.2459 - val_loss: 0.2370\n",
"Epoch 42/120\n",
"20/20 [==============================] - 1s 27ms/step - loss: 0.2267 - val_loss: 0.2210\n",
"Epoch 43/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 0.2050 - val_loss: 0.1947\n",
"Epoch 44/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.1840 - val_loss: 0.1728\n",
"Epoch 45/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.1628 - val_loss: 0.1533\n",
"Epoch 46/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.1430 - val_loss: 0.1322\n",
"Epoch 47/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.1230 - val_loss: 0.1147\n",
"Epoch 48/120\n",
"20/20 [==============================] - 1s 24ms/step - loss: 0.1026 - val_loss: 0.0940\n",
"Epoch 49/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0830 - val_loss: 0.0750\n",
"Epoch 50/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0639 - val_loss: 0.0529\n",
"Epoch 51/120\n",
"20/20 [==============================] - 0s 22ms/step - loss: 0.0436 - val_loss: 0.0352\n",
"Epoch 52/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 0.0241 - val_loss: 0.0162\n",
"Epoch 53/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0092 - val_loss: 0.0084\n",
"Epoch 54/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0067 - val_loss: 0.0074\n",
"Epoch 55/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0080 - val_loss: 0.0071\n",
"Epoch 56/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0070 - val_loss: 0.0063\n",
"Epoch 57/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0062 - val_loss: 0.0076\n",
"Epoch 58/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0056 - val_loss: 0.0048\n",
"Epoch 59/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0050 - val_loss: 0.0071\n",
"Epoch 60/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0057 - val_loss: 0.0054\n",
"Epoch 61/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0044 - val_loss: 0.0092\n",
"Epoch 62/120\n",
"20/20 [==============================] - 1s 26ms/step - loss: 0.0068 - val_loss: 0.0070\n",
"Epoch 63/120\n",
"20/20 [==============================] - 1s 24ms/step - loss: 0.0059 - val_loss: 0.0065\n",
"Epoch 64/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0055 - val_loss: 0.0060\n",
"Epoch 65/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0053 - val_loss: 0.0056\n",
"Epoch 66/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0058 - val_loss: 0.0077\n",
"Epoch 67/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0051 - val_loss: 0.0054\n",
"Epoch 68/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0047 - val_loss: 0.0048\n",
"Epoch 69/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0041 - val_loss: 0.0048\n",
"Epoch 70/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0037 - val_loss: 0.0049\n",
"Epoch 71/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0041 - val_loss: 0.0049\n",
"Epoch 72/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0036 - val_loss: 0.0049\n",
"Epoch 73/120\n",
"20/20 [==============================] - 1s 24ms/step - loss: 0.0038 - val_loss: 0.0048\n",
"Epoch 74/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0037 - val_loss: 0.0050\n",
"Epoch 75/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0037 - val_loss: 0.0048\n",
"Epoch 76/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0038 - val_loss: 0.0048\n",
"Epoch 77/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0037 - val_loss: 0.0048\n",
"Epoch 78/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0038 - val_loss: 0.0048\n",
"Epoch 79/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0036 - val_loss: 0.0048\n",
"Epoch 80/120\n",
"20/20 [==============================] - 0s 22ms/step - loss: 0.0034 - val_loss: 0.0048\n",
"Epoch 81/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0034 - val_loss: 0.0047\n",
"Epoch 82/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0037 - val_loss: 0.0047\n",
"Epoch 83/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0033 - val_loss: 0.0047\n",
"Epoch 84/120\n",
"20/20 [==============================] - 0s 22ms/step - loss: 0.0037 - val_loss: 0.0047\n",
"Epoch 85/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 86/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0034 - val_loss: 0.0047\n",
"Epoch 87/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0034 - val_loss: 0.0047\n",
"Epoch 88/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0034 - val_loss: 0.0047\n",
"Epoch 89/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 90/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 91/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 92/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0034 - val_loss: 0.0047\n",
"Epoch 93/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0034 - val_loss: 0.0047\n",
"Epoch 94/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0037 - val_loss: 0.0047\n",
"Epoch 95/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0035 - val_loss: 0.0047\n",
"Epoch 96/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0032 - val_loss: 0.0047\n",
"Epoch 97/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0035 - val_loss: 0.0047\n",
"Epoch 98/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0038 - val_loss: 0.0047\n",
"Epoch 99/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0033 - val_loss: 0.0047\n",
"Epoch 100/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 101/120\n",
"20/20 [==============================] - 1s 26ms/step - loss: 0.0033 - val_loss: 0.0047\n",
"Epoch 102/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0035 - val_loss: 0.0047\n",
"Epoch 103/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0035 - val_loss: 0.0047\n",
"Epoch 104/120\n",
"20/20 [==============================] - 0s 22ms/step - loss: 0.0035 - val_loss: 0.0047\n",
"Epoch 105/120\n",
"20/20 [==============================] - 0s 23ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 106/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0034 - val_loss: 0.0047\n",
"Epoch 107/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 108/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0034 - val_loss: 0.0047\n",
"Epoch 109/120\n",
"20/20 [==============================] - 1s 24ms/step - loss: 0.0037 - val_loss: 0.0047\n",
"Epoch 110/120\n",
"20/20 [==============================] - 0s 24ms/step - loss: 0.0038 - val_loss: 0.0047\n",
"Epoch 111/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 112/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0035 - val_loss: 0.0047\n",
"Epoch 113/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0037 - val_loss: 0.0047\n",
"Epoch 114/120\n",
"20/20 [==============================] - 0s 20ms/step - loss: 0.0035 - val_loss: 0.0047\n",
"Epoch 115/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0033 - val_loss: 0.0047\n",
"Epoch 116/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 117/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0032 - val_loss: 0.0047\n",
"Epoch 118/120\n",
"20/20 [==============================] - 0s 26ms/step - loss: 0.0036 - val_loss: 0.0047\n",
"Epoch 119/120\n",
"20/20 [==============================] - 0s 25ms/step - loss: 0.0037 - val_loss: 0.0047\n",
"Epoch 120/120\n",
"20/20 [==============================] - 0s 21ms/step - loss: 0.0036 - val_loss: 0.0047\n"
]
}
],
"source": [
"trainable_model.compile(optimizer='adam', loss=None)\n",
"hist = trainable_model.fit([X, Y[0], Y[1]], epochs=120, batch_size=8, verbose=1, \n",
" validation_data=[np.expand_dims(valid[feature_cols].values, axis=1), Y_valid[0], Y_valid[1]],\n",
" callbacks=[reduce_lr]\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "67bfbe88-5f2c-4659-b2dc-eb9f1b824d04",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([[0.73740077],\n",
" [0.89292204],\n",
" [0.7599046 ],\n",
" [0.67802393],\n",
" [0.6815233 ],\n",
" [0.88627005],\n",
" [0.6121343 ],\n",
" [0.7072234 ],\n",
" [0.8561135 ],\n",
" [0.52762157],\n",
" [0.8325021 ],\n",
" [0.50241977],\n",
" [0.8242289 ],\n",
" [0.68957335],\n",
" [0.6980361 ],\n",
" [0.82116604],\n",
" [0.8566438 ],\n",
" [0.53687835],\n",
" [0.56832707],\n",
" [0.78476715],\n",
" [0.85638577]], dtype=float32),\n",
" array([[0.68600863],\n",
" [0.78454906],\n",
" [0.8179163 ],\n",
" [0.94351083],\n",
" [0.86383885],\n",
" [0.69705516],\n",
" [0.6913491 ],\n",
" [0.80277354],\n",
" [0.93557894],\n",
" [0.82278305],\n",
" [0.82674253],\n",
" [0.93518937],\n",
" [0.8094449 ],\n",
" [0.9206344 ],\n",
" [0.7747319 ],\n",
" [0.9137207 ],\n",
" [0.9491073 ],\n",
" [0.93225 ],\n",
" [0.6185102 ],\n",
" [0.8867341 ],\n",
" [0.82890105]], dtype=float32)]"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rst = prediction_model.predict(np.expand_dims(test[feature_cols], axis=1))\n",
"rst"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "7de501e9-05a2-424c-a5f4-85d43ad37592",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.9991559102070927, 0.9998196796918477]"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[np.exp(K.get_value(log_var[0]))**0.5 for log_var in trainable_model.layers[-1].log_vars]"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "b0d5d8ad-aadd-4218-b5b7-9691a2d3eeef",
"metadata": {},
"outputs": [],
"source": [
"pred_rst = pd.DataFrame.from_records(np.squeeze(np.asarray(rst), axis=2).T, columns=out_cols)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "0a2bcb45-da86-471b-a61d-314e29430d6a",
"metadata": {},
"outputs": [],
"source": [
"real_rst = test[out_cols].copy()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "e124f7c0-fdd5-43b9-b649-ff7d9dd59641",
"metadata": {},
"outputs": [],
"source": [
"for col in out_cols:\n",
" pred_rst[col] = pred_rst[col] * (maxs[col] - mins[col]) + mins[col]\n",
" real_rst[col] = real_rst[col] * (maxs[col] - mins[col]) + mins[col]"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "5c69d03b-34fd-4dbf-aec6-c15093bb22ab",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['挥发分Vad(%)', '固定炭Fcad(%)'], dtype='object')"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"real_rst.columns"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "21739f82-d82a-4bde-8537-9504b68a96d5",
"metadata": {},
"outputs": [],
"source": [
"y_pred_vad = pred_rst['挥发分Vad(%)'].values.reshape(-1,)\n",
"y_pred_fcad = pred_rst['固定炭Fcad(%)'].values.reshape(-1,)\n",
"y_true_vad = real_rst['挥发分Vad(%)'].values.reshape(-1,)\n",
"y_true_fcad = real_rst['固定炭Fcad(%)'].values.reshape(-1,)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "26ea6cfa-efad-443c-9dd9-844f8be42b91",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, mean_absolute_percentage_error"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "28072e7c-c9d5-4ff6-940d-e94ae879afc9",
"metadata": {},
"outputs": [],
"source": [
"def print_eva(y_true, y_pred, tp):\n",
" MSE = mean_squared_error(y_true, y_pred)\n",
" RMSE = np.sqrt(MSE)\n",
" MAE = mean_absolute_error(y_true, y_pred)\n",
" MAPE = mean_absolute_percentage_error(y_true, y_pred)\n",
" R_2 = r2_score(y_true, y_pred)\n",
" print(f\"COL: {tp}, MSE: {format(MSE, '.2E')}\", end=',')\n",
" print(f'RMSE: {round(RMSE, 3)}', end=',')\n",
" print(f'MAPE: {round(MAPE * 100, 3)} %', end=',')\n",
" print(f'MAE: {round(MAE, 3)}', end=',')\n",
" print(f'R_2: {round(R_2, 3)}')\n",
" return [MSE, RMSE, MAE, MAPE, R_2]"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "4ec4caa9-7c46-4fc8-a94b-cb659e924304",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"COL: 挥发分Vad, MSE: 3.35E-01,RMSE: 0.579,MAPE: 1.639 %,MAE: 0.504,R_2: 0.87\n",
"COL: 固定炭Fcad, MSE: 1.11E+00,RMSE: 1.055,MAPE: 1.497 %,MAE: 0.814,R_2: 0.876\n"
]
}
],
"source": [
"pm25_eva = print_eva(y_true_vad, y_pred_vad, tp='挥发分Vad')\n",
"pm10_eva = print_eva(y_true_fcad, y_pred_fcad, tp='固定炭Fcad')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac4a4339-ec7d-4266-8197-5276c2395288",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f15cbb91-1ce7-4fb0-979a-a4bdc452a1ec",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}