ICEEMDAN-Solar_power-forecast/ConvBigru_IRPE_Attention特定数...

956 lines
172 KiB
Plaintext
Raw Permalink Normal View History

2024-08-12 07:42:30 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from math import sqrt\n",
"from numpy import concatenate\n",
"from matplotlib import pyplot\n",
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import mean_squared_error\n",
"from tensorflow.keras import Sequential\n",
"\n",
"from tensorflow.keras.layers import Dense\n",
"from tensorflow.keras.layers import LSTM\n",
"from tensorflow.keras.layers import Dropout\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这段代码是一个函数 time_series_to_supervised它用于将时间序列数据转换为监督学习问题的数据集。下面是该函数的各个部分的含义\n",
"\n",
"data: 输入的时间序列数据可以是列表或2D NumPy数组。\n",
"n_in: 作为输入的滞后观察数即用多少个时间步的观察值作为输入。默认值为96表示使用前96个时间步的观察值作为输入。\n",
"n_out: 作为输出的观测数量即预测多少个时间步的观察值。默认值为1表示预测未来1个时间步的观察值。\n",
"dropnan: 布尔值表示是否删除具有NaN值的行。默认为True即删除具有NaN值的行。\n",
"函数首先检查输入数据的维度并初始化一些变量。然后它创建一个新的DataFrame对象 df 来存储输入数据,并保存原始的列名。接着,它创建了两个空列表 cols 和 names用于存储新的特征列和列名。\n",
"\n",
"接下来,函数开始构建特征列和对应的列名。首先,它将原始的观察序列添加到 cols 列表中,并将其列名添加到 names 列表中。然后,它依次将滞后的观察序列添加到 cols 列表中,并构建相应的列名,格式为 (原始列名)(t-滞后时间)。这样就创建了输入特征的部分。\n",
"\n",
"接着,函数开始构建输出特征的部分。它依次将未来的观察序列添加到 cols 列表中,并构建相应的列名,格式为 (原始列名)(t+未来时间)。\n",
"\n",
"最后函数将所有的特征列拼接在一起构成一个新的DataFrame对象 agg。如果 dropnan 参数为True则删除具有NaN值的行。最后函数返回处理后的数据集 agg。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def time_series_to_supervised(data, n_in=96, n_out=1,dropnan=True):\n",
" \"\"\"\n",
" :param data:作为列表或2D NumPy数组的观察序列。需要。\n",
" :param n_in:作为输入的滞后观察数X。值可以在[1..len数据]之间可选。默认为1。\n",
" :param n_out:作为输出的观测数量y。值可以在[0..len数据]之间。可选的。默认为1。\n",
" :param dropnan:Boolean是否删除具有NaN值的行。可选的。默认为True。\n",
" :return:\n",
" \"\"\"\n",
" n_vars = 1 if type(data) is list else data.shape[1]\n",
" df = pd.DataFrame(data)\n",
" origNames = df.columns\n",
" cols, names = list(), list()\n",
" cols.append(df.shift(0))\n",
" names += [('%s' % origNames[j]) for j in range(n_vars)]\n",
" n_in = max(0, n_in)\n",
" for i in range(n_in, 0, -1):\n",
" time = '(t-%d)' % i\n",
" cols.append(df.shift(i))\n",
" names += [('%s%s' % (origNames[j], time)) for j in range(n_vars)]\n",
" n_out = max(n_out, 0)\n",
" for i in range(1, n_out+1):\n",
" time = '(t+%d)' % i\n",
" cols.append(df.shift(-i))\n",
" names += [('%s%s' % (origNames[j], time)) for j in range(n_vars)]\n",
" agg = pd.concat(cols, axis=1)\n",
" agg.columns = names\n",
" if dropnan:\n",
" agg.dropna(inplace=True)\n",
" return agg"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Temp Humidity GHI DHI Rainfall Power\n",
"0 19.779453 40.025826 3.232706 1.690531 0.0 0.0\n",
"1 19.714937 39.605961 3.194991 1.576346 0.0 0.0\n",
"2 19.549330 39.608631 3.070866 1.576157 0.0 0.0\n",
"3 19.405870 39.680702 3.038623 1.482489 0.0 0.0\n",
"4 19.387363 39.319881 2.656474 1.134153 0.0 0.0\n",
"(104256, 6)\n"
]
}
],
"source": [
"# 加载数据\n",
"path1 = r\"D:\\project\\小论文1-基于ICEEMDAN分解的时序高维变化的短期光伏功率预测模型\\CEEMAN-PosConv1dbiLSTM-LSTM\\模型代码流程\\data6.csv\"#数据所在路径\n",
"#我的数据是excel表若是csv文件用pandas的read_csv()函数替换即可。\n",
"datas1 = pd.DataFrame(pd.read_csv(path1))\n",
"#我只取了data表里的第3、23、16、17、18、19、20、21、27列如果取全部列的话这一行可以去掉\n",
"# data1 = datas1.iloc[:,np.r_[3,23,16:22,27]]\n",
"data1=datas1.interpolate()\n",
"values1 = data1.values\n",
"print(data1.head())\n",
"print(data1.shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(104256, 6)\n"
]
}
],
"source": [
"# 使用MinMaxScaler进行归一化\n",
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
"scaledData1 = scaler.fit_transform(data1)\n",
"print(scaledData1.shape)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0 1 2 3 4 5 0(t-96) \\\n",
"96 0.555631 0.349673 0.190042 0.040558 0.0 0.236302 0.490360 \n",
"97 0.564819 0.315350 0.211335 0.044613 0.0 0.258204 0.489088 \n",
"98 0.576854 0.288321 0.229657 0.047549 0.0 0.279860 0.485824 \n",
"99 0.581973 0.268243 0.247775 0.053347 0.0 0.301336 0.482997 \n",
"100 0.586026 0.264586 0.266058 0.057351 0.0 0.322851 0.482632 \n",
"\n",
" 1(t-96) 2(t-96) 3(t-96) ... 2(t-1) 3(t-1) 4(t-1) 5(t-1) \\\n",
"96 0.369105 0.002088 0.002013 ... 0.166009 0.036794 0.0 0.214129 \n",
"97 0.364859 0.002061 0.001839 ... 0.190042 0.040558 0.0 0.236302 \n",
"98 0.364886 0.001973 0.001839 ... 0.211335 0.044613 0.0 0.258204 \n",
"99 0.365615 0.001950 0.001697 ... 0.229657 0.047549 0.0 0.279860 \n",
"100 0.361965 0.001679 0.001167 ... 0.247775 0.053347 0.0 0.301336 \n",
"\n",
" 0(t+1) 1(t+1) 2(t+1) 3(t+1) 4(t+1) 5(t+1) \n",
"96 0.564819 0.315350 0.211335 0.044613 0.0 0.258204 \n",
"97 0.576854 0.288321 0.229657 0.047549 0.0 0.279860 \n",
"98 0.581973 0.268243 0.247775 0.053347 0.0 0.301336 \n",
"99 0.586026 0.264586 0.266058 0.057351 0.0 0.322851 \n",
"100 0.590772 0.258790 0.282900 0.060958 0.0 0.343360 \n",
"\n",
"[5 rows x 588 columns]\n"
]
}
],
"source": [
"n_steps_in =96 #历史时间长度\n",
"n_steps_out=1#预测时间长度\n",
"processedData1 = time_series_to_supervised(scaledData1,n_steps_in,n_steps_out)\n",
"print(processedData1.head())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"data_x = processedData1.loc[:,'0(t-96)':'5(t-1)']#去除power剩下的做标签列\n",
"data_y = processedData1.loc[:,'5']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"冒号\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104159, 576)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_x.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"96 0.236302\n",
"97 0.258204\n",
"98 0.279860\n",
"99 0.301336\n",
"100 0.322851\n",
" ... \n",
"104250 0.000000\n",
"104251 0.000000\n",
"104252 0.000000\n",
"104253 0.000000\n",
"104254 0.000000\n",
"Name: 5, Length: 104159, dtype: float64"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_y"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104159,)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(93743, 96, 6) (93743,) (8854, 96, 6) (8854,) (1562, 96, 6) (1562,)\n"
]
}
],
"source": [
"# 计算训练集、验证集和测试集的大小\n",
"train_size = int(len(data_x) * 0.90)\n",
"test_size = int(len(data_x) * 0.015)\n",
"val_size = len(data_x) - train_size - test_size\n",
"\n",
"# 计算训练集、验证集和测试集的索引范围\n",
"train_indices = range(train_size)\n",
"val_indices = range(train_size, train_size + val_size)\n",
"test_indices = range(train_size + val_size, len(data_x))\n",
"\n",
"# 根据索引范围划分数据集\n",
"train_X1 = data_x.iloc[train_indices].values.reshape((-1, n_steps_in, scaledData1.shape[1]))\n",
"val_X1 = data_x.iloc[val_indices].values.reshape((-1, n_steps_in, scaledData1.shape[1]))\n",
"test_X1 = data_x.iloc[test_indices].values.reshape((-1, n_steps_in, scaledData1.shape[1]))\n",
"train_y = data_y.iloc[train_indices].values\n",
"val_y = data_y.iloc[val_indices].values\n",
"test_y = data_y.iloc[test_indices].values\n",
"\n",
"# reshape input to be 3D [samples, timesteps, features]\n",
"train_X = train_X1.reshape((train_X1.shape[0], n_steps_in, scaledData1.shape[1]))\n",
"val_X = val_X1.reshape((val_X1.shape[0], n_steps_in, scaledData1.shape[1]))\n",
"test_X = test_X1.reshape((test_X1.shape[0], n_steps_in, scaledData1.shape[1]))\n",
"\n",
"print(train_X.shape, train_y.shape, val_X.shape, val_y.shape, test_X.shape, test_y.shape)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(93743, 96, 6)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_X1.shape"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_4\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional_4\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">6</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ conv1d_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv1D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">832</span> │ input_layer_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ max_pooling1d_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv1d_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling1D</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ bidirectional_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">49,920</span> │ max_pooling1d_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ attention_with_imp… │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ <span style=\"color: #00af00; text-decoration-color: #00af00\">66,304</span> │ bidirectional_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">AttentionWithImpr…</span> │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, │ │ bidirectional_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)] │ │ bidirectional_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ global_average_poo… │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ attention_with_i… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalAveragePool…</span> │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dense_24 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">129</span> │ global_average_p… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ conv1d_4 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m832\u001b[0m │ input_layer_4[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ max_pooling1d_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mMaxPooling1D\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ bidirectional_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m49,920\u001b[0m │ max_pooling1d_4[\u001b[38;5;34m…\u001b[0m │\n",
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ attention_with_imp… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m66,304\u001b[0m │ bidirectional_4[\u001b[38;5;34m…\u001b[0m │\n",
"│ (\u001b[38;5;33mAttentionWithImpr…\u001b[0m │ \u001b[38;5;34m128\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m, │ │ bidirectional_4[\u001b[38;5;34m…\u001b[0m │\n",
"│ │ \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)] │ │ bidirectional_4[\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ global_average_poo… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ attention_with_i… │\n",
"│ (\u001b[38;5;33mGlobalAveragePool…\u001b[0m │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dense_24 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m129\u001b[0m │ global_average_p… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Input, Conv1D, Bidirectional, GlobalAveragePooling1D, Dense, GRU, MaxPooling1D\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.initializers import RandomUniform\n",
"class AttentionWithImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
" def __init__(self, d_model, num_heads, max_len=5000):\n",
" super(AttentionWithImproveRelativePositionEncoding, self).__init__()\n",
" self.num_heads = num_heads\n",
" self.d_model = d_model\n",
" self.max_len = max_len\n",
" self.wq = tf.keras.layers.Dense(d_model)\n",
" self.wk = tf.keras.layers.Dense(d_model)\n",
" self.wv = tf.keras.layers.Dense(d_model)\n",
" self.dense = tf.keras.layers.Dense(d_model)\n",
" self.position_encoding = ImproveRelativePositionEncoding(d_model)\n",
"\n",
" def call(self, v, k, q, mask):\n",
" batch_size = tf.shape(q)[0]\n",
" q = self.wq(q)\n",
" k = self.wk(k)\n",
" v = self.wv(v)\n",
"\n",
" # 添加位置编码\n",
" k += self.position_encoding (k)\n",
" q += self.position_encoding (q)\n",
"\n",
" q = self.split_heads(q, batch_size)\n",
" k = self.split_heads(k, batch_size)\n",
" v = self.split_heads(v, batch_size)\n",
"\n",
" scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)\n",
" scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])\n",
" concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))\n",
" output = self.dense(concat_attention)\n",
" return output, attention_weights\n",
"\n",
" def split_heads(self, x, batch_size):\n",
" x = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_model // self.num_heads))\n",
" return tf.transpose(x, perm=[0, 2, 1, 3])\n",
"\n",
" def scaled_dot_product_attention(self, q, k, v, mask):\n",
" matmul_qk = tf.matmul(q, k, transpose_b=True)\n",
" dk = tf.cast(tf.shape(k)[-1], tf.float32)\n",
" scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n",
"\n",
" if mask is not None:\n",
" scaled_attention_logits += (mask * -1e9)\n",
"\n",
" attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)\n",
" output = tf.matmul(attention_weights, v)\n",
" return output, attention_weights\n",
"\n",
"class ImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
" def __init__(self, d_model, max_len=5000):\n",
" super(ImproveRelativePositionEncoding, self).__init__()\n",
" self.max_len = max_len\n",
" self.d_model = d_model\n",
" # 引入可变化的参数u和v进行线性变化\n",
" self.u = self.add_weight(shape=(self.d_model,),\n",
" initializer=RandomUniform(),\n",
" trainable=True)\n",
" self.v = self.add_weight(shape=(self.d_model,),\n",
" initializer=RandomUniform(),\n",
" trainable=True)\n",
" def call(self, inputs):\n",
" seq_length = inputs.shape[1]\n",
" pos_encoding = self.relative_positional_encoding(seq_length, self.d_model)\n",
" \n",
" # 调整原始的相对位置编码公式将u和v参数融入其中\n",
" pe_with_params = pos_encoding * self.u+ pos_encoding * self.v\n",
" return inputs + pe_with_params\n",
"\n",
" def relative_positional_encoding(self, position, d_model):\n",
" pos = tf.range(position, dtype=tf.float32)\n",
" i = tf.range(d_model, dtype=tf.float32)\n",
" \n",
" angles = 1 / tf.pow(10000.0, (2 * (i // 2)) / tf.cast(d_model, tf.float32))\n",
" angle_rads = tf.einsum('i,j->ij', pos, angles)\n",
" #保留了sinous机制\n",
" # Apply sin to even indices; 2i\n",
" angle_rads_sin = tf.sin(angle_rads[:, 0::2])\n",
" # Apply cos to odd indices; 2i+1\n",
" angle_rads_cos = tf.cos(angle_rads[:, 1::2])\n",
"\n",
" pos_encoding = tf.stack([angle_rads_sin, angle_rads_cos], axis=2)\n",
" pos_encoding = tf.reshape(pos_encoding, [1, position, d_model])\n",
"\n",
" return pos_encoding\n",
"\n",
"\n",
"\n",
"def PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads):\n",
" inputs = Input(shape=input_shape)\n",
" # CNN layer\n",
" cnn_layer = Conv1D(filters=64, kernel_size=2, activation='relu')(inputs)\n",
" cnn_layer = MaxPooling1D(pool_size=1)(cnn_layer)\n",
" gru_output = Bidirectional(GRU(gru_units, return_sequences=True))(cnn_layer)\n",
" \n",
" # Apply Self-Attention\n",
" self_attention =AttentionWithImproveRelativePositionEncoding(d_model=gru_units*2, num_heads=num_heads)\n",
" gru_output, _ = self_attention(gru_output, gru_output, gru_output, mask=None)\n",
" \n",
" pool1 = GlobalAveragePooling1D()(gru_output)\n",
" output = Dense(1)(pool1)\n",
" \n",
" return Model(inputs=inputs, outputs=output)\n",
"\n",
"\n",
"input_shape = (96, 6)\n",
"gru_units = 64\n",
"num_heads = 8\n",
"\n",
"# Create model\n",
"model = PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads)\n",
"model.compile(optimizer='adam', loss='mse')\n",
"model.summary()\n"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m106s\u001b[0m 71ms/step - loss: 0.0198 - val_loss: 0.0016\n",
"Epoch 2/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 73ms/step - loss: 0.0016 - val_loss: 0.0015\n",
"Epoch 3/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m108s\u001b[0m 74ms/step - loss: 0.0015 - val_loss: 0.0015\n",
"Epoch 4/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 73ms/step - loss: 0.0015 - val_loss: 0.0014\n",
"Epoch 5/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m106s\u001b[0m 73ms/step - loss: 0.0014 - val_loss: 0.0016\n",
"Epoch 6/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 71ms/step - loss: 0.0014 - val_loss: 0.0015\n",
"Epoch 7/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 71ms/step - loss: 0.0014 - val_loss: 0.0014\n",
"Epoch 8/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 73ms/step - loss: 0.0013 - val_loss: 0.0014\n",
"Epoch 9/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 73ms/step - loss: 0.0013 - val_loss: 0.0014\n",
"Epoch 10/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m106s\u001b[0m 72ms/step - loss: 0.0013 - val_loss: 0.0015\n",
"Epoch 11/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 71ms/step - loss: 0.0013 - val_loss: 0.0014\n",
"Epoch 12/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 72ms/step - loss: 0.0013 - val_loss: 0.0015\n",
"Epoch 13/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 73ms/step - loss: 0.0013 - val_loss: 0.0014\n",
"Epoch 14/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m108s\u001b[0m 74ms/step - loss: 0.0012 - val_loss: 0.0014\n",
"Epoch 15/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 73ms/step - loss: 0.0013 - val_loss: 0.0014\n",
"Epoch 16/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 71ms/step - loss: 0.0013 - val_loss: 0.0013\n",
"Epoch 17/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 72ms/step - loss: 0.0013 - val_loss: 0.0013\n",
"Epoch 18/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 72ms/step - loss: 0.0012 - val_loss: 0.0014\n",
"Epoch 19/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 69ms/step - loss: 0.0012 - val_loss: 0.0013\n",
"Epoch 20/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 70ms/step - loss: 0.0012 - val_loss: 0.0014\n",
"Epoch 21/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m103s\u001b[0m 70ms/step - loss: 0.0012 - val_loss: 0.0014\n",
"Epoch 22/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 70ms/step - loss: 0.0011 - val_loss: 0.0014\n",
"Epoch 23/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 69ms/step - loss: 0.0012 - val_loss: 0.0018\n",
"Epoch 24/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m101s\u001b[0m 69ms/step - loss: 0.0012 - val_loss: 0.0014\n",
"Epoch 25/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 70ms/step - loss: 0.0011 - val_loss: 0.0014\n",
"Epoch 26/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 70ms/step - loss: 0.0012 - val_loss: 0.0015\n",
"Epoch 27/100\n",
"\u001b[1m1465/1465\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m97s\u001b[0m 66ms/step - loss: 0.0012 - val_loss: 0.0015\n"
]
}
],
"source": [
"# Compile and train the model\n",
"model.compile(optimizer='adam', loss='mean_squared_error')\n",
"from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
"\n",
"# 定义早停机制\n",
"early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=0, mode='min')\n",
"\n",
"# 拟合模型,并添加早停机制和模型检查点\n",
"history = model.fit(train_X, train_y, epochs=100, batch_size=64, validation_data=(val_X, val_y), \n",
" callbacks=[early_stopping])\n"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m49/49\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step\n"
]
}
],
"source": [
"# 预测\n",
"lstm_pred = model.predict(test_X)\n",
"# 将预测结果的形状修改为与原始数据相同的形状\n"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"test_y_pre=test_y"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGdCAYAAADqsoKGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAABKLUlEQVR4nO3de3zT9aE//lfuaZs0UApNWwqtUEDGpQpSy5jorBZlzqpzyJxcfhzQjXnAikwYFHVod0AcopzD8exMtp0hyFeGlzEmq3epRW46pnITLArpBWzThjbXz++PT/JJA2lp2s8nn7a8no9HHkk+eeeTd0JoX31fNYIgCCAiIiLq4bRqV4CIiIhIDgw1RERE1Csw1BAREVGvwFBDREREvQJDDREREfUKDDVERETUKzDUEBERUa/AUENERES9gl7tCsRLIBDA6dOnYbVaodFo1K4OERERdYAgCGhsbERGRga02vbbYi6bUHP69GlkZWWpXQ0iIiLqhFOnTmHgwIHtlrlsQo3VagUgfijJyckq14aIiIg6wul0IisrS/o93p7LJtSEupySk5MZaoiIiHqYjgwd4UBhIiIi6hUYaoiIiKhXYKghIiKiXuGyGVNDRESkFEEQ4PP54Pf71a5Kj6PT6aDX62VZboWhhoiIqAs8Hg/OnDmD8+fPq12VHisxMRHp6ekwGo1dOg9DDRERUScFAgGcOHECOp0OGRkZMBqNXOA1BoIgwOPxoLa2FidOnEBubu4lF9hrD0MNERFRJ3k8HgQCAWRlZSExMVHt6vRICQkJMBgM+Oqrr+DxeGA2mzt9Lg4UJiIi6qKutC6QfJ8f/xWIiIioV2CoISIiol6BoYaIiIi6JDs7G2vXrlW7GhwoTEREdDm6/vrrkZeXJ0sY+fjjj5GUlNT1SnURQ00XHa1uxOaPT6G/1YQHJg9RuzpERESyEAQBfr8fev2lo0L//v3jUKNLY/dTF51uaMH/fnACrx48rXZViIhIZYIg4LzHp8pFEIQO13PWrFl499138eyzz0Kj0UCj0WDjxo3QaDT429/+hnHjxsFkMuGDDz7A8ePHcfvttyMtLQ0WiwXXXHMN/vGPf0Sc78LuJ41Gg9/97ne44447kJiYiNzcXLz22mtyfcxtYktNF1lM4kfY5PaqXBMiIlJbs9ePkaV/V+W1P3uiCInGjv1af/bZZ3HkyBGMGjUKTzzxBADgX//6FwDg0UcfxdNPP40rrrgCffv2xalTp3DrrbfiySefhMlkwh//+EfcdtttOHz4MAYNGtTmazz++ONYtWoVVq9ejeeeew733nsvvvrqK6SkpHT9zbaBLTVdZDUHQ02LT+WaEBERdYzNZoPRaERiYiLsdjvsdjt0Oh0A4IknnsBNN92EIUOGICUlBWPHjsX999+PUaNGITc3F7/+9a8xZMiQS7a8zJo1C9OnT8fQoUPx1FNPoampCXv27FH0fbGlpovCLTVi0x+XxyYiunwlGHT47Iki1V5bDuPHj4+439TUhMceewx//etfcebMGfh8PjQ3N6Oqqqrd84wZM0a6nZSUhOTkZNTU1MhSx7Yw1HRRqKXG6xfg9gVglulLRUREPY9Go+lwF1B3deEspkWLFmHXrl14+umnMXToUCQkJOBHP/oRPB5Pu+cxGAwR9zUaDQKBgOz1ba1nf/LdQFKrL29ji4+hhoiIegSj0Qi/33/Jch9++CFmzZqFO+64A4DYcnPy5EmFa9c5HFPTRVqtJqILioiIqCfIzs5GZWUlTp48ibq6ujZbUXJzc7Ft2zYcPHgQn3zyCX7yk58o3uLSWQw1MpBCDQcLExFRD7Fo0SLodDqMHDkS/fv3b3OMzDPPPIO+ffti4sSJuO2221BUVISrr746zrXtGHY/ycBi1gNOoJHTuomIqIcYNmwYKioqIo7NmjXronLZ2dl46623Io7Nnz8/4v6F3VHR1sypr6/vVD1jwZYaGXBaNxERkfoYamQQ6n
5qZKghIiJSDUONDKSWGg4UJiIiUg1DjQw4+4mIiEh9DDUysJjEBYbY/URERKQehhoZhLufOPuJiIhILQw1MgiFGrbUEBERqYehRgZcfI+IiEh9DDUysIRaajhQmIiISDUMNTJgSw0REfU0119/PRYuXCjb+WbNmoXi4mLZztcZDDUysJqDs584UJiIiEg1DDUy4DYJRETUk8yaNQvvvvsunn32WWg0Gmg0Gpw8eRKHDh3CLbfcAovFgrS0NNx3332oq6uTnvf//t//w+jRo5GQkIB+/fqhsLAQLpcLjz32GP7whz/g1Vdflc73zjvvxP19cUNLGbRefE8QBGg0GpVrREREqhAEwHtendc2JAId/P3z7LPP4siRIxg1ahSeeOIJ8ekGAyZMmIB/+7d/w29/+1s0Nzfjl7/8JX784x/jrbfewpkzZzB9+nSsWrUKd9xxBxobG/H+++9DEAQsWrQIn3/+OZxOJ1588UUAQEpKimJvtS0MNTIIDRT2+gW4fQGYDTqVa0RERKrwngeeylDntZeeBoxJHSpqs9lgNBqRmJgIu90OAFi5ciWuuuoqPPXUU1K53//+98jKysKRI0fQ1NQEn8+HO++8E4MHDwYAjB49WiqbkJAAt9stnU8Nnep+Wr9+PbKzs2E2m5Gfn489e/a0W37r1q0YMWIEzGYzRo8ejR07dkQ8LggCSktLkZ6ejoSEBBQWFuLo0aMXneevf/0r8vPzkZCQgL59+6o+ICnEYgxnQ26VQEREPdEnn3yCt99+GxaLRbqMGDECAHD8+HGMHTsWN954I0aPHo27774b//M//4Nvv/1W5VpHirmlZsuWLSgpKcGGDRuQn5+PtWvXoqioCIcPH8aAAQMuKr97925Mnz4dZWVl+MEPfoBNmzahuLgY+/fvx6hRowAAq1atwrp16/CHP/wBOTk5WL58OYqKivDZZ5/BbDYDAF555RXMnTsXTz31FL7//e/D5/Ph0KFDXXz78tBqNbCY9Ghy+9DY4kOqxaR2lYiISA2GRLHFRK3X7oKmpibcdttt+I//+I+LHktPT4dOp8OuXbuwe/duvPnmm3juuefwq1/9CpWVlcjJyenSa8tGiNGECROE+fPnS/f9fr+QkZEhlJWVRS3/4x//WJg6dWrEsfz8fOH+++8XBEEQAoGAYLfbhdWrV0uP19fXCyaTSXjppZcEQRAEr9crZGZmCr/73e9ira6koaFBACA0NDR0+hztyX/yH8LgX74hfHqqXpHzExFR99Pc3Cx89tlnQnNzs9pVidlNN90k/OIXv5DuL126VBg+fLjg9Xo79HyfzydkZmYKa9asEQRBEObOnSv84Ac/6FRd2vscY/n9HVP3k8fjwb59+1BYWCgd02q1KCwsREVFRdTnVFRURJQHgKKiIqn8iRMn4HA4IsrYbDbk5+dLZfbv349vvvkGWq0WV111FdLT03HLLbd0m5YaoPUCfJzWTURE3V92djYqKytx8uRJ1NXVYf78+Th37hymT5+Ojz/+GMePH8ff//53zJ49G36/H5WVlXjqqaewd+9eVFVVYdu2baitrcWVV14pne/TTz/F4cOHUVdXB683/r8PYwo1dXV18Pv9SEtLizielpYGh8MR9TkOh6Pd8qHr9sp8+eWXAIDHHnsMy5YtwxtvvIG+ffvi+uuvx7lz56K+rtvthtPpjLgoiQvwERFRT7Jo0SLodDqMHDkS/fv3h8fjwYcffgi/34+bb74Zo0ePxsKFC9GnTx9otVokJyfjvffew6233ophw4Zh2bJlWLNmDW655RYAwNy5czF8+HCMHz8e/fv3x4cffhj399QjZj8FAgEAwK9+9SvcddddAIAXX3wRAwcOxNatW3H//fdf9JyysjI8/vjjcatjeKduhhoiIur+hg0bFrWXZdu2bVHLX3nlldi5c2eb5+vfvz/efPNN2erXGTG11KSmpkKn06G6ujrieHV1dZtTuOx2e7vlQ9ftlUlPTwcAjBw5Un
rcZDLhiiuuQFVVVdTXXbJkCRoaGqTLqVOnOvo2O4U7dRMREakrplBjNBoxbtw4lJeXS8cCgQDKy8tRUFAQ9TkFBQU
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training vs. validation loss curves from the fit history\n",
"for hist_key, curve_label in [('loss', 'train'), ('val_loss', 'test')]:\n",
"    plt.plot(history.history[hist_key], label=curve_label)\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1562, 1)"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.shape(lstm_pred)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1562,)"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.shape(test_y_pre)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"# Reshape to a column vector; -1 infers the length so this works for any test-set size\n",
"test_y_pre1 = test_y_pre.reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.90540195],\n",
" [0.90466702],\n",
" [0.89696645],\n",
" ...,\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]])"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_y_pre1"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"# Tile the single-column prediction across all 6 scaler features so inverse_transform works;\n",
"# derive the row count from the array instead of hard-coding 1562\n",
"results1 = np.broadcast_to(lstm_pred, (lstm_pred.shape[0], 6))"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"# Same broadcast for the ground-truth column; row count taken from the array, not hard-coded\n",
"test_y2 = np.broadcast_to(test_y_pre1, (test_y_pre1.shape[0], 6))"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"# Invert the MinMax scaling to recover values in the original units\n",
"inv_forecast_y = scaler.inverse_transform(results1)\n",
"inv_test_y = scaler.inverse_transform(test_y2)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 4.08374339e+01, 9.30529669e+01, 1.27546074e+03,\n",
" 5.95908965e+02, 2.15485743e+01, 4.67959929e+00],\n",
" [ 4.08001461e+01, 9.29803001e+01, 1.27442567e+03,\n",
" 5.95425556e+02, 2.15310831e+01, 4.67580080e+00],\n",
" [ 4.04094426e+01, 9.22188951e+01, 1.26358018e+03,\n",
" 5.90360385e+02, 2.13478095e+01, 4.63600016e+00],\n",
" ...,\n",
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00],\n",
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00],\n",
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00]])"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv_test_y"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test RMSE: 0.063\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABP4AAAKTCAYAAACJusZ+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd5xseV3n/9e3clVX6nhznDt5mGFmSEMcBxBEBcT8Q5E1u6CyigF32V1BHddlUXSVRdffD1ERFl3XXVlEUKLAAAPDDJPv3Bw6V87h/P74nuruOzd3V9Wp8H4+HvO4t/p2d30u3NPnnM/5BOM4joOIiIiIiIiIiIiMFJ/XAYiIiIiIiIiIiEj3KfEnIiIiIiIiIiIygpT4ExERERERERERGUFK/ImIiIiIiIiIiIwgJf5ERERERERERERGkBJ/IiIiIiIiIiIiI0iJPxERERERERERkREU6Pcbttttzpw5QyKRwBjT77cXEREREREREREZao7jUCgU2LlzJz7fxev6+p74O3PmDHv27On324qIiIiIiIiIiIyUkydPsnv37ov+ed8Tf4lEArCBJZPJfr+9iIiIiIiIiIjIUMvn8+zZs2ctz3YxfU/8ddp7k8mkEn8iIiIiIiIiIiKbdLkxelruISIiIiIiIiIiMoKU+BMRERERERERERlBSvyJiIiIiIiIiIiMoL7P+BMRERERERERkfHTarVoNBpehzEUgsEgfr9/y99HiT8REREREREREekZx3GYn58nm816HcpQSafTbN++/bILPC7lqhJ///E//kd+/dd//ZyPXX/99Tz22GObDkBEREREREREREZXJ+k3NzdHLBbbUiJrHDiOQ7lcZnFxEYAdO3Zs+ntddcXfzTffzCc/+cn1bxBQ0aCIiIiIiIiIiJyv1WqtJf2mp6e9DmdoRKNRABYXF5mbm9t02+9VZ+0CgQDbt2/f1JuJiIiIiIiIiMj46Mz0i8ViHkcyfDr/mzUajU0n/q56q++TTz7Jzp07OXjwIK9//es5ceLEJT+/VquRz+fP+U9ERERERERERMaH2nuvXjf+N7uqxN9zn/tc3v/+9/MP//APvPe97+Xo0aO86EUvolAoXPRr7r33XlKp1Np/e/bs2XLQIiIiIiIiIiIicmnGcRxns1+czWbZt28f7373u/mxH/uxC35OrVajVqutvc7n8+zZs4dcLkcymdzsW4uIiIiIiIiIyICrVqscPXqUAwcOEIlEvA5nqFzqf7t8Pk8qlbpsfm1LmznS6TTXXXcdhw8fvujnhMNhwuHwVt5GRERERERERERErtJVz/jbqFgs8tRTT21prbCIiIiIiIiIiMigufvuu3nLW97idRhbclWJv7e+9a185jOf4dixY3zhC1/gu77ru/D7/fzgD/5gr+ITEREREREREREZOI7j0Gw2vQ7jkq4q8Xfq1Cl+8Ad/kOuvv57v+77vY3p6mi996UvMzs72Kj4RERERERERERkhjuNQrjf7/t/VrLl44xvfyGc+8xne8573YIzBGMP73/9+jDF87GMf48477yQcDvP5z3+eN77xjbz2ta895+vf8pa3cPfdd6+9brfb3HvvvRw4cIBoNMptt93GX//1X3fpf9GLu6oZfx/60Id6FYeIiIiIiIiIiIyBSqPFTf/+431/30fe8QpioStLhb3nPe/hiSee4JZbbuEd73gHAA8//DAAv/qrv8q73vUuDh48yOTk5BV9v3vvvZe/+Iu/4L/9t//Gtddey2c/+1l+6Id+iNnZWV7ykpds7i90Bba03ENERERERERERGTUpFIpQqEQsViM7du3A/DYY48B8I53vIOXv/zlV/y9arUav/Vbv8UnP/lJ7rrrLgAOHjzI5z//ed73vvcp8SciIiIiIiIiIqMhGvTzyDte4cn7dsOznvWsq/r8w4cPUy6Xz0sW1ut1br/99q7EdDFK/ImIiIiIiIiISN8YY6645XYQTUxMnP
Pa5/OdNz+w0Wis/b5YLALw0Y9+lF27dp3zeeFwuEdRWsP7v7KIiIiIiIiIiEiPhEIhWq3WZT9vdnaWb37zm+d87IEHHiAYDAJw0003EQ6HOXHiRE/bei9EiT8REREREREREZGn2b9/P/fddx/Hjh0jHo/Tbrcv+Hn33HMP//k//2c+8IEPcNddd/EXf/EXfPOb31xr400kErz1rW/l3/ybf0O73eaFL3whuVyOf/mXfyGZTPIjP/IjPfs7+Hr2nUVERERERERERIbUW9/6Vvx+PzfddBOzs7OcOHHigp/3ile8gre//e388i//Ms9+9rMpFAq84Q1vOOdz3vnOd/L2t7+de++9lxtvvJFXvvKVfPSjH+XAgQM9/TsY5+lNyD2Wz+dJpVLkcjmSyWQ/31pERERERERERPqoWq1y9OhRDhw4QCQS8TqcoXKp/+2uNL+mij8REREREREREZERpMSfiIiIiIiIiMgWtdsOxVoTgL++/xRv/cg3yJUbl/kqkd7Scg8RERERERERkS16+999k4989RQf+LHn8Ct/8yCttsPDZ/L83597IcYYr8OTMaWKPxERERERERGRLVgt1fnL+05Qb7X5gT/+Eq22Xafw6Nk8H3943uPoZJwp8SciIiIiIiIisgUffejsRf/sb79+uo+RiJxLiT8RERERERGhWGvy+//0JO/+xBNUGy2vwxEZKt84mT3vY3unYgCczlb6HI3IOs34ExERERERET7wxWO8+xNPAJCOBvnRFx7wOCKR4fHYfB6A3/v+Z/LAySyrpTqvesYOfvov7udstupxdDLOlPgTERERERER/uXw8trv3/NPT/Ldd+7GcRz+22eO8IJD07zo2lkPoxMZXM1WmycWigDcvjfNa2/fBUCmVAdgpVSn2mgRCfo9i1HGlxJ/IiIiIiIiY67aaPHVY5m117lKg9t+/R/XXv+3zzzF//ipu3jOgSkvwhMZWN88neOn/vx+6s02EyE/eyZja3+WjgUJB3zUmm0W8lX2TU94GKmMK834ExERERERGXOffWKJWrPN9mSEd33vbcTD59eIfPaJJQ8iExlclXqLn/rz+9dm+L3ujt34fGbtz40x7ExHATibq1Kpa3am9J8SfyIiIiIiImOs1mzxJ587AsCrn7mT77lzNx/+qefx2mfu5JdecT3fc+duAI6ulHjwVJb3feYp6s22lyGLDIS/vO/4WtLvBYem+eVXXn/e5+xIRQD4k88e4ab/8A98+Csn+hqj9Fe9Xvc6hPMo8SciIiIiIjLG3v2JJ/jKsQx+n+F73STfzTtT/N4P3M6bvuUQr7h5OwBHl0r8yP/7Ze792GP84acOexmyiOfabYf3f+EYAL/1Xc/gL3/8eSQiwfM+r7PZ958eW8Rx4Ff+5qF+hilbdPfdd/PmN7+ZN7/5zaRSKWZmZnj729+O4zgA7N+/n3e+85284Q1vIJlM8pM/+ZMAfP7zn+dFL3oR0WiUPXv28HM/93OUSiVP/g5K/IlIT33qsUU+eJ+eaomIiIgMolbb4W/uPw3Af/zOm7h2W+K8zzkwY+eSPXI2T6bcAOCPP3uEZktVfzK+TmbKnMpUCAd8fJe7zONCDs3Fz/tYJ2k01hwH6qX+/7eJ/+3/7M/+jEAgwJe//GXe85738O53v5v//t//+9qfv+td7+K2227j61//Om9/+9t56qmneOUrX8l3f/d38+CDD/LhD3+Yz3/+87z5zW/u5v+CV0zLPUSk606ulvmTzx3hmtk4/+F/PwzArbtT3LIrxXKxRioaJOjXcwcRERERr33l2Ora9dn3P3vvBT+nU7G0UaXR4qHTOW7fO9nrEEUG0pFlW711YGaCaGjDtt7HP2YTTLd8NxjDNRdI/B1bKa8l1MdWowy/tbP/7/trZyB0df/b79mzh9/93d/FGMP111/PQw89xO/+7u/yEz/xEwDcc889/OIv/uLa5//4j/84r3/963nLW94CwLXXXsvv//7v85KXvI
T3vve9RCKRrv11roTuvEWk637sz77CB754fC3pB/Do2TxfPbbKc3/rn/jZD37dw+hEREREpOPLR1cBeMl1s4QCF74
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Root-mean-square error on the de-normalized target column (index 5)\n",
"rmse = sqrt(mean_squared_error(inv_test_y[:, 5], inv_forecast_y[:, 5]))\n",
"print('Test RMSE: %.3f' % rmse)\n",
"# Plot the true series against the forecast\n",
"fig, ax = plt.subplots(figsize=(16, 8))\n",
"ax.plot(inv_test_y[:, 5], label='true')\n",
"ax.plot(inv_forecast_y[:, 5], label='pre')\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean_squared_error: 0.00014629570256978046\n",
"mean_absolute_error: 0.008445659571024366\n",
"rmse: 0.01209527604355438\n",
"r2 score: 0.9988370101682903\n"
]
}
],
"source": [
"from math import sqrt\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"# sklearn metric functions take (y_true, y_pred) in that order; MSE/MAE are symmetric\n",
"# so the printed values are unchanged, but the convention matters for R^2.\n",
"print('mean_squared_error:', mean_squared_error(test_y_pre, lstm_pred))  # MSE on normalized scale\n",
"print('mean_absolute_error:', mean_absolute_error(test_y_pre, lstm_pred))  # MAE on normalized scale\n",
"print('rmse:', sqrt(mean_squared_error(test_y_pre, lstm_pred)))\n",
"# R^2 on the de-normalized target column (index 5), consistent with the RMSE cell above;\n",
"# the previous version averaged R^2 over all 6 broadcast columns, which is misleading.\n",
"print('r2 score:', r2_score(inv_test_y[:, 5], inv_forecast_y[:, 5]))"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"# Wrap the ground-truth target column in a single-column DataFrame\n",
"df1 = pd.DataFrame({'column_name': inv_test_y[:, 5]})"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"# Persist the ground-truth series to CSV (no index column)\n",
"df1.to_csv('test.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# Select the target column (index 5): inv_forecast_y is 2-D (rows, 6), so passing it whole\n",
"# with a single column name raises ValueError; this also mirrors how df1 is built.\n",
"df2 = pd.DataFrame(inv_forecast_y[:, 5], columns=['column_name'])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# Persist the forecast series to CSV (no index column)\n",
"df2.to_csv('forecast.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}