ICEEMDAN-Solar_power-forecast/ConvBigru_IRPE_Attention.ipynb

938 lines
377 KiB
Plaintext
Raw Normal View History

2024-08-01 10:47:23 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\asus\\AppData\\Roaming\\Python\\Python39\\site-packages\\pandas\\core\\computation\\expressions.py:21: UserWarning: Pandas requires version '2.8.4' or newer of 'numexpr' (version '2.8.3' currently installed).\n",
" from pandas.core.computation.check import NUMEXPR_INSTALLED\n",
"C:\\Users\\asus\\AppData\\Roaming\\Python\\Python39\\site-packages\\pandas\\core\\arrays\\masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
" from pandas.core import (\n"
]
}
],
"source": [
"from math import sqrt\n",
"from numpy import concatenate\n",
"from matplotlib import pyplot\n",
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import mean_squared_error\n",
"from tensorflow.keras import Sequential\n",
"\n",
"from tensorflow.keras.layers import Dense\n",
"from tensorflow.keras.layers import LSTM\n",
"from tensorflow.keras.layers import Dropout\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这段代码是一个函数 time_series_to_supervised它用于将时间序列数据转换为监督学习问题的数据集。下面是该函数的各个部分的含义\n",
"\n",
"data: 输入的时间序列数据可以是列表或2D NumPy数组。\n",
"n_in: 作为输入的滞后观察数即用多少个时间步的观察值作为输入。默认值为96表示使用前96个时间步的观察值作为输入。\n",
"n_out: 作为输出的观测数量即预测多少个时间步的观察值。默认值为10表示预测未来10个时间步的观察值。\n",
"dropnan: 布尔值表示是否删除具有NaN值的行。默认为True即删除具有NaN值的行。\n",
"函数首先检查输入数据的维度并初始化一些变量。然后它创建一个新的DataFrame对象 df 来存储输入数据,并保存原始的列名。接着,它创建了两个空列表 cols 和 names用于存储新的特征列和列名。\n",
"\n",
"接下来,函数开始构建特征列和对应的列名。首先,它将原始的观察序列添加到 cols 列表中,并将其列名添加到 names 列表中。然后,它依次将滞后的观察序列添加到 cols 列表中,并构建相应的列名,格式为 (原始列名)(t-滞后时间)。这样就创建了输入特征的部分。\n",
"\n",
"接着,函数开始构建输出特征的部分。它依次将未来的观察序列添加到 cols 列表中,并构建相应的列名,格式为 (原始列名)(t+未来时间)。\n",
"\n",
"最后函数将所有的特征列拼接在一起构成一个新的DataFrame对象 agg。如果 dropnan 参数为True则删除具有NaN值的行。最后函数返回处理后的数据集 agg。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def time_series_to_supervised(data, n_in=96, n_out=10,dropnan=True):\n",
" \"\"\"\n",
" :param data:作为列表或2D NumPy数组的观察序列。需要。\n",
" :param n_in:作为输入的滞后观察数X。值可以在[1..len数据]之间可选。默认为1。\n",
" :param n_out:作为输出的观测数量y。值可以在[0..len数据]之间。可选的。默认为1。\n",
" :param dropnan:Boolean是否删除具有NaN值的行。可选的。默认为True。\n",
" :return:\n",
" \"\"\"\n",
" n_vars = 1 if type(data) is list else data.shape[1]\n",
" df = pd.DataFrame(data)\n",
" origNames = df.columns\n",
" cols, names = list(), list()\n",
" cols.append(df.shift(0))\n",
" names += [('%s' % origNames[j]) for j in range(n_vars)]\n",
" n_in = max(0, n_in)\n",
" for i in range(n_in, 0, -1):\n",
" time = '(t-%d)' % i\n",
" cols.append(df.shift(i))\n",
" names += [('%s%s' % (origNames[j], time)) for j in range(n_vars)]\n",
" n_out = max(n_out, 0)\n",
" for i in range(1, n_out+1):\n",
" time = '(t+%d)' % i\n",
" cols.append(df.shift(-i))\n",
" names += [('%s%s' % (origNames[j], time)) for j in range(n_vars)]\n",
" agg = pd.concat(cols, axis=1)\n",
" agg.columns = names\n",
" if dropnan:\n",
" agg.dropna(inplace=True)\n",
" return agg"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Temp Humidity GHI DHI Rainfall Power\n",
"0 19.779453 40.025826 3.232706 1.690531 0.0 0.0\n",
"1 19.714937 39.605961 3.194991 1.576346 0.0 0.0\n",
"2 19.549330 39.608631 3.070866 1.576157 0.0 0.0\n",
"3 19.405870 39.680702 3.038623 1.482489 0.0 0.0\n",
"4 19.387363 39.319881 2.656474 1.134153 0.0 0.0\n",
"(104256, 6)\n"
]
}
],
"source": [
"# 加载数据\n",
"path1 = r\"D:\\project\\小论文1-基于ICEEMDAN分解的时序高维变化的短期光伏功率预测模型\\CEEMAN-PosConv1dbiLSTM-LSTM\\模型代码流程\\data6.csv\"#数据所在路径\n",
"#我的数据是excel表若是csv文件用pandas的read_csv()函数替换即可。\n",
"datas1 = pd.DataFrame(pd.read_csv(path1))\n",
"#我只取了data表里的第3、23、16、17、18、19、20、21、27列如果取全部列的话这一行可以去掉\n",
"# data1 = datas1.iloc[:,np.r_[3,23,16:22,27]]\n",
"data1=datas1.interpolate()\n",
"values1 = data1.values\n",
"print(data1.head())\n",
"print(data1.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# data2= data1.drop(['date','Air_P','RH'], axis = 1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(104256, 6)\n"
]
}
],
"source": [
"# 使用MinMaxScaler进行归一化\n",
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
"scaledData1 = scaler.fit_transform(data1)\n",
"print(scaledData1.shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0 1 2 3 4 5 0(t-96) \\\n",
"96 0.555631 0.349673 0.190042 0.040558 0.0 0.236302 0.490360 \n",
"97 0.564819 0.315350 0.211335 0.044613 0.0 0.258204 0.489088 \n",
"98 0.576854 0.288321 0.229657 0.047549 0.0 0.279860 0.485824 \n",
"99 0.581973 0.268243 0.247775 0.053347 0.0 0.301336 0.482997 \n",
"100 0.586026 0.264586 0.266058 0.057351 0.0 0.322851 0.482632 \n",
"\n",
" 1(t-96) 2(t-96) 3(t-96) ... 2(t-1) 3(t-1) 4(t-1) 5(t-1) \\\n",
"96 0.369105 0.002088 0.002013 ... 0.166009 0.036794 0.0 0.214129 \n",
"97 0.364859 0.002061 0.001839 ... 0.190042 0.040558 0.0 0.236302 \n",
"98 0.364886 0.001973 0.001839 ... 0.211335 0.044613 0.0 0.258204 \n",
"99 0.365615 0.001950 0.001697 ... 0.229657 0.047549 0.0 0.279860 \n",
"100 0.361965 0.001679 0.001167 ... 0.247775 0.053347 0.0 0.301336 \n",
"\n",
" 0(t+1) 1(t+1) 2(t+1) 3(t+1) 4(t+1) 5(t+1) \n",
"96 0.564819 0.315350 0.211335 0.044613 0.0 0.258204 \n",
"97 0.576854 0.288321 0.229657 0.047549 0.0 0.279860 \n",
"98 0.581973 0.268243 0.247775 0.053347 0.0 0.301336 \n",
"99 0.586026 0.264586 0.266058 0.057351 0.0 0.322851 \n",
"100 0.590772 0.258790 0.282900 0.060958 0.0 0.343360 \n",
"\n",
"[5 rows x 588 columns]\n"
]
}
],
"source": [
"n_steps_in =96 #历史时间长度\n",
"n_steps_out=1#预测时间长度\n",
"processedData1 = time_series_to_supervised(scaledData1,n_steps_in,n_steps_out)\n",
"print(processedData1.head())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"data_x = processedData1.loc[:,'0(t-96)':'5(t-1)']#去除power剩下的做标签列\n",
"data_y = processedData1.loc[:,'5']"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104159, 576)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_x.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"96 0.236302\n",
"97 0.258204\n",
"98 0.279860\n",
"99 0.301336\n",
"100 0.322851\n",
" ... \n",
"104250 0.000000\n",
"104251 0.000000\n",
"104252 0.000000\n",
"104253 0.000000\n",
"104254 0.000000\n",
"Name: 5, Length: 104159, dtype: float64"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_y"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104159,)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(83328, 96, 6) (83328,) (20831, 96, 6) (20831,)\n"
]
}
],
"source": [
"# 7.划分训练集和测试集\n",
"\n",
"test_size = int(len(data_x) * 0.2)\n",
"# 计算训练集和测试集的索引范围\n",
"train_indices = range(len(data_x) - test_size)\n",
"test_indices = range(len(data_x) - test_size, len(data_x))\n",
"\n",
"# 根据索引范围划分数据集\n",
"train_X1 = data_x.iloc[train_indices].values.reshape((-1, n_steps_in, scaledData1.shape[1]))\n",
"test_X1 = data_x.iloc[test_indices].values.reshape((-1, n_steps_in, scaledData1.shape[1]))\n",
"train_y = data_y.iloc[train_indices].values\n",
"test_y = data_y.iloc[test_indices].values\n",
"\n",
"\n",
"# # 多次运行代码时希望得到相同的数据分割,可以设置 random_state 参数为一个固定的整数值\n",
"# train_X1,test_X1, train_y, test_y = train_test_split(data_x.values, data_y.values, test_size=0.2, random_state=343)\n",
"# reshape input to be 3D [samples, timesteps, features]\n",
"train_X = train_X1.reshape((train_X1.shape[0], n_steps_in, scaledData1.shape[1]))\n",
"test_X = test_X1.reshape((test_X1.shape[0], n_steps_in,scaledData1.shape[1]))\n",
"print(train_X.shape, train_y.shape, test_X.shape, test_y.shape)\n",
"# 使用train_test_split函数划分训练集和测试集测试集的比重是40%。\n",
"# 然后将train_X1、test_X1进行一个升维变成三维维数分别是[samples,timesteps,features]。\n",
"# 打印一下他们的shape\\\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(83328, 96, 6)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_X1.shape"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From d:\\Anaconda3\\lib\\site-packages\\keras\\src\\backend\\tensorflow\\core.py:192: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
"\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">6</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ conv1d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv1D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">832</span> │ input_layer[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ max_pooling1d │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv1d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling1D</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ bidirectional │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">49,920</span> │ max_pooling1d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ attention_with_imp… │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ <span style=\"color: #00af00; text-decoration-color: #00af00\">66,304</span> │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">AttentionWithImpr…</span> │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, │ │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)] │ │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ global_average_poo… │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ attention_with_i… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalAveragePool…</span> │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dense_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">129</span> │ global_average_p… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ conv1d (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m832\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ max_pooling1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mMaxPooling1D\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ bidirectional │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m49,920\u001b[0m │ max_pooling1d[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ attention_with_imp… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m66,304\u001b[0m │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mAttentionWithImpr…\u001b[0m │ \u001b[38;5;34m128\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m, │ │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)] │ │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ global_average_poo… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ attention_with_i… │\n",
"│ (\u001b[38;5;33mGlobalAveragePool…\u001b[0m │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m129\u001b[0m │ global_average_p… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Input, Conv1D, Bidirectional, GlobalAveragePooling1D, Dense, GRU, MaxPooling1D\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.initializers import RandomUniform\n",
"class AttentionWithImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
" def __init__(self, d_model, num_heads, max_len=5000):\n",
" super(AttentionWithImproveRelativePositionEncoding, self).__init__()\n",
" self.num_heads = num_heads\n",
" self.d_model = d_model\n",
" self.max_len = max_len\n",
" self.wq = tf.keras.layers.Dense(d_model)\n",
" self.wk = tf.keras.layers.Dense(d_model)\n",
" self.wv = tf.keras.layers.Dense(d_model)\n",
" self.dense = tf.keras.layers.Dense(d_model)\n",
" self.position_encoding = ImproveRelativePositionEncoding(d_model)\n",
"\n",
" def call(self, v, k, q, mask):\n",
" batch_size = tf.shape(q)[0]\n",
" q = self.wq(q)\n",
" k = self.wk(k)\n",
" v = self.wv(v)\n",
"\n",
" # 添加位置编码\n",
" k += self.position_encoding (k)\n",
" q += self.position_encoding (q)\n",
"\n",
" q = self.split_heads(q, batch_size)\n",
" k = self.split_heads(k, batch_size)\n",
" v = self.split_heads(v, batch_size)\n",
"\n",
" scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)\n",
" scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])\n",
" concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))\n",
" output = self.dense(concat_attention)\n",
" return output, attention_weights\n",
"\n",
" def split_heads(self, x, batch_size):\n",
" x = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_model // self.num_heads))\n",
" return tf.transpose(x, perm=[0, 2, 1, 3])\n",
"\n",
" def scaled_dot_product_attention(self, q, k, v, mask):\n",
" matmul_qk = tf.matmul(q, k, transpose_b=True)\n",
" dk = tf.cast(tf.shape(k)[-1], tf.float32)\n",
" scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n",
"\n",
" if mask is not None:\n",
" scaled_attention_logits += (mask * -1e9)\n",
"\n",
" attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)\n",
" output = tf.matmul(attention_weights, v)\n",
" return output, attention_weights\n",
"\n",
"class ImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
" def __init__(self, d_model, max_len=5000):\n",
" super(ImproveRelativePositionEncoding, self).__init__()\n",
" self.max_len = max_len\n",
" self.d_model = d_model\n",
" # 引入可变化的参数u和v进行线性变化\n",
" self.u = self.add_weight(shape=(self.d_model,),\n",
" initializer=RandomUniform(),\n",
" trainable=True)\n",
" self.v = self.add_weight(shape=(self.d_model,),\n",
" initializer=RandomUniform(),\n",
" trainable=True)\n",
" def call(self, inputs):\n",
" seq_length = inputs.shape[1]\n",
" pos_encoding = self.relative_positional_encoding(seq_length, self.d_model)\n",
" \n",
" # 调整原始的相对位置编码公式将u和v参数融入其中\n",
" pe_with_params = pos_encoding * self.u+ pos_encoding * self.v\n",
" return inputs + pe_with_params\n",
"\n",
" def relative_positional_encoding(self, position, d_model):\n",
" pos = tf.range(position, dtype=tf.float32)\n",
" i = tf.range(d_model, dtype=tf.float32)\n",
" \n",
" angles = 1 / tf.pow(10000.0, (2 * (i // 2)) / tf.cast(d_model, tf.float32))\n",
" angle_rads = tf.einsum('i,j->ij', pos, angles)\n",
" #保留了sinous机制\n",
" # Apply sin to even indices; 2i\n",
" angle_rads_sin = tf.sin(angle_rads[:, 0::2])\n",
" # Apply cos to odd indices; 2i+1\n",
" angle_rads_cos = tf.cos(angle_rads[:, 1::2])\n",
"\n",
" pos_encoding = tf.stack([angle_rads_sin, angle_rads_cos], axis=2)\n",
" pos_encoding = tf.reshape(pos_encoding, [1, position, d_model])\n",
"\n",
" return pos_encoding\n",
"\n",
"\n",
"\n",
"def PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads):\n",
" inputs = Input(shape=input_shape)\n",
" # CNN layer\n",
" cnn_layer = Conv1D(filters=64, kernel_size=2, activation='relu')(inputs)\n",
" cnn_layer = MaxPooling1D(pool_size=1)(cnn_layer)\n",
" gru_output = Bidirectional(GRU(gru_units, return_sequences=True))(cnn_layer)\n",
" \n",
" # Apply Self-Attention\n",
" self_attention =AttentionWithImproveRelativePositionEncoding(d_model=gru_units*2, num_heads=num_heads)\n",
" gru_output, _ = self_attention(gru_output, gru_output, gru_output, mask=None)\n",
" \n",
" pool1 = GlobalAveragePooling1D()(gru_output)\n",
" output = Dense(1)(pool1)\n",
" \n",
" return Model(inputs=inputs, outputs=output)\n",
"\n",
"\n",
"input_shape = (96, 6)\n",
"gru_units = 64\n",
"num_heads = 8\n",
"\n",
"# Create model\n",
"model = PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads)\n",
"model.compile(optimizer='adam', loss='mse')\n",
"model.summary()\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m82s\u001b[0m 61ms/step - loss: 0.0187 - val_loss: 0.0021\n",
"Epoch 2/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m80s\u001b[0m 62ms/step - loss: 0.0014 - val_loss: 0.0025\n",
"Epoch 3/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m84s\u001b[0m 64ms/step - loss: 0.0013 - val_loss: 0.0021\n",
"Epoch 4/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m87s\u001b[0m 67ms/step - loss: 0.0012 - val_loss: 0.0020\n",
"Epoch 5/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m76s\u001b[0m 58ms/step - loss: 0.0013 - val_loss: 0.0019\n",
"Epoch 6/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m71s\u001b[0m 55ms/step - loss: 0.0011 - val_loss: 0.0020\n",
"Epoch 7/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m70s\u001b[0m 54ms/step - loss: 0.0011 - val_loss: 0.0019\n",
"Epoch 8/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m81s\u001b[0m 62ms/step - loss: 0.0011 - val_loss: 0.0020\n",
"Epoch 9/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m82s\u001b[0m 63ms/step - loss: 0.0012 - val_loss: 0.0019\n",
"Epoch 10/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m85s\u001b[0m 65ms/step - loss: 0.0011 - val_loss: 0.0018\n",
"Epoch 11/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m88s\u001b[0m 68ms/step - loss: 0.0011 - val_loss: 0.0019\n",
"Epoch 12/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m90s\u001b[0m 69ms/step - loss: 0.0011 - val_loss: 0.0018\n",
"Epoch 13/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m79s\u001b[0m 61ms/step - loss: 0.0011 - val_loss: 0.0020\n",
"Epoch 14/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m83s\u001b[0m 64ms/step - loss: 0.0011 - val_loss: 0.0019\n",
"Epoch 15/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m82s\u001b[0m 63ms/step - loss: 0.0011 - val_loss: 0.0018\n",
"Epoch 16/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m80s\u001b[0m 61ms/step - loss: 0.0011 - val_loss: 0.0019\n",
"Epoch 17/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m81s\u001b[0m 63ms/step - loss: 0.0010 - val_loss: 0.0020\n",
"Epoch 18/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m77s\u001b[0m 59ms/step - loss: 0.0011 - val_loss: 0.0018\n",
"Epoch 19/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m78s\u001b[0m 60ms/step - loss: 0.0011 - val_loss: 0.0018\n",
"Epoch 20/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m77s\u001b[0m 59ms/step - loss: 0.0011 - val_loss: 0.0018\n",
"\u001b[1m651/651\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 17ms/step\n"
]
}
],
"source": [
"# Compile and train the model\n",
"model.compile(optimizer='adam', loss='mean_squared_error')\n",
"from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
"\n",
"# 定义早停机制\n",
"early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=0, mode='min')\n",
"\n",
"# 拟合模型,并添加早停机制和模型检查点\n",
"history = model.fit(train_X, train_y, epochs=100, batch_size=64, validation_data=(test_X, test_y), \n",
" callbacks=[early_stopping])\n",
"# 预测\n",
"lstm_pred = model.predict(test_X)\n",
"# 将预测结果的形状修改为与原始数据相同的形状"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGdCAYAAADqsoKGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOSklEQVR4nO3de1xUZeI/8M+ZOwwwgCgDikKGqYl3JdT9WrtsWG4rtZW5u3lZV9v9ta0u2cVWpdtGabXm5Ru1m2m739LcynbLLKN7EuatstLMBVFxUCQYGGCGmTm/P87MwMBwGZgrfN6v17xm5sxzDs9hHOfDczuCKIoiiIiIiMKcLNgVICIiIvIFhhoiIiLqExhqiIiIqE9gqCEiIqI+gaGGiIiI+gSGGiIiIuoTGGqIiIioT2CoISIioj5BEewKBIrdbkdFRQWio6MhCEKwq0NERETdIIoi6urqkJycDJms87aYfhNqKioqkJKSEuxqEBERUQ+cPn0aQ4YM6bRMvwk10dHRAKRfSkxMTJBrQ0RERN1hNBqRkpLi+h7vTL8JNc4up5iYGIYaIiKiMNOdoSMcKExERER9AkMNERER9QkMNURERNQn9JsxNURERP4iiiKsVitsNluwqxJ25HI5FAqFT5ZbYaghIiLqBYvFgnPnzqGhoSHYVQlbkZGRSEpKgkql6tVxGGqIiIh6yG63o7S0FHK5HMnJyVCpVFzg1QuiKMJiseDChQsoLS1Fenp6lwvsdYahhoiIqIcsFgvsdjtSUlIQGRkZ7OqEpYiICCiVSpw6dQoWiwUajabHx+JAYSIiol7qTesC+e73x3eBiIiI+gSGGiIiIuoTGGqIiIioV1JTU7F+/fpgV4MDhYmIiPqjK6+8EuPHj/dJGPn888+h1Wp7X6leYqjppe8q6/Dy56cxIEqN3185PNjVISIi8glRFGGz2aBQdB0VBg4cGIAadY3dT710rrYJf/+kFK8fORvsqhARUZCJoogGizUoN1EUu13PhQsX4sMPP8RTTz0FQRAgCAK2bt0KQRDw1ltvYdKkSVCr1fjkk09w8uRJzJkzB4mJiYiKisKUKVPw7rvvuh2vbfeTIAj4+9//juuvvx6RkZFIT0/Hv//9b1/9mjvElppeiotUAgBqGpqDXBMiIgq2xmYbRq95Oyg/+5sHcxCp6t7X+lNPPYXvvvsOY8aMwYMPPggA+PrrrwEA9957Lx5//HFccskliIuLw+nTp3HttdfiL3/5C9RqNV544QVcd911OH78OIYOHdrhz3jggQewdu1arFu3Dhs3bsSvfvUrnDp1CvHx8b0/2Q6wpaaX4iKlJZ1/aLAEuSZERETdo9PpoFKpEBkZCb1eD71eD7lcDgB48MEH8dOf/hTDhw9HfHw8xo0bh9tuuw1jxoxBeno6HnroIQwfPrzLlpeFCxdi3rx5uPTSS/HII4+gvr4e+/fv9+t59ailZvPmzVi3bh0MBgPGjRuHjRs3YurUqR2W37lzJ1avXo2ysjKkp6fjsccew7XXXut6XRRF5Ofn429/+xtqamowffp0PP3000hPT3c7zptvvokHH3wQX375JTQaDWbOnIldu3b15BR8JtbRUmO22tFosSFCJQ9qfYiIKHgilHJ882BO0H62L0yePNnteX19Pe6//368+eabOHfuHKxWKxobG1FeXt7pccaOHet6rNVqERMTg/Pnz/ukjh3xuqVmx44dyMvLQ35+Pg4dOoRx48YhJyenw4ru27cP8+bNw+LFi3H48GHk5uYiNzcXR48edZVZu3YtNmzYgMLCQpSUlECr1SInJwdNTU2uMq+88gpuvfVWLFq0CF988QU+/fRT/PKXv+zBKftWlFoBhUy6zgdba4iI+jdBEBCpUgTl5qtrTrWdxbRixQq89tpreOSRR/Dxxx/jyJEjyMjIgMXS+XeeUqls97ux2+0+qWNHvA41Tz75JJYsWYJFixZh9OjRKCwsRGRkJLZs2e
Kx/FNPPYVZs2bhrrvuwqhRo/DQQw9h4sSJ2LRpEwCplWb9+vVYtWoV5syZg7Fjx+KFF15ARUWFqxXGarVi2bJlWLduHX73u99hxIgRGD16NG6++eaen7mPCIKAOC27oIiIKLyoVCrYbLYuy3366adYuHAhrr/+emRkZECv16OsrMz/FewBr0KNxWLBwYMHkZ2d3XIAmQzZ2dkoLi72uE9xcbFbeQDIyclxlS8tLYXBYHAro9PpkJmZ6Spz6NAhnD17FjKZDBMmTEBSUhKuueYat9aetsxmM4xGo9vNX5yDhX8wcbAwERGFh9TUVJSUlKCsrAxVVVUdtqKkp6fj1VdfxZEjR/DFF1/gl7/8pd9bXHrKq1BTVVUFm82GxMREt+2JiYkwGAwe9zEYDJ2Wd953Vua///0vAOD+++/HqlWr8MYbbyAuLg5XXnklqqurPf7cgoIC6HQ61y0lJcWbU/VKLAcLExFRmFmxYgXkcjlGjx6NgQMHdjhG5sknn0RcXBymTZuG6667Djk5OZg4cWKAa9s9YTGl25kI//znP+MXv/gFAOD555/HkCFDsHPnTtx2223t9lm5ciXy8vJcz41Go9+CTcu0boYaIiIKDyNGjGjXy7Jw4cJ25VJTU/Hee++5bbv99tvdnrftjvK0Zk5NTU2P6ukNr1pqEhISIJfLUVlZ6ba9srISer3e4z56vb7T8s77zsokJSUBAEaPHu16Xa1W45JLLukwWarVasTExLjd/KVlWje7n4iIiILFq1CjUqkwadIkFBUVubbZ7XYUFRUhKyvL4z5ZWVlu5QFg7969rvJpaWnQ6/VuZYxGI0pKSlxlnCsbHj9+3FWmubkZZWVlGDZsmDen4BfsfiIiIgo+r7uf8vLysGDBAkyePBlTp07F+vXrYTKZsGjRIgDA/PnzMXjwYBQUFAAAli1bhpkzZ+KJJ57A7NmzsX37dhw4cADPPvssAGn20PLly/Hwww8jPT0daWlpWL16NZKTk5GbmwsAiImJwe9+9zvk5+cjJSUFw4YNw7p16wAAN910ky9+D73CVYWJiIiCz+tQM3fuXFy4cAFr1qyBwWDA+PHjsWfPHtdA3/LycshkLQ1A06ZNw4svvohVq1bhvvvuQ3p6Onbt2oUxY8a4ytx9990wmUxYunQpampqMGPGDOzZswcajcZVZt26dVAoFLj11lvR2NiIzMxMvPfee4iLi+vN+fsEp3QTEREFnyB6cwWsMGY0GqHT6VBbW+vz8TV7v6nEkhcOYNwQHV7/wwyfHpuIiEJXU1MTSktLkZaW5vaHOHmns9+jN9/fvPaTD7jWqWH3ExERUdAw1PgABwoTEREFH0ONDzhbauqarLDaQnOVRSIior6OocYHdBEtF+2qaWQXFBERUTAw1PiAQi5DjEaaSMZVhYmIKBxceeWVWL58uc+Ot3DhQtdSLMHCUOMj8VquKkxERBRMDDU+4hosbGJLDRERhbaFCxfiww8/xFNPPQVBECAIAsrKynD06FFcc801iIqKQmJiIm699VZUVVW59vvXv/6FjIwMREREYMCAAcjOzobJZML999+Pbdu24fXXX3cd74MPPgj4eYXFBS3DQcu0boYaIqJ+SxSB5obg/GxlJCAI3Sr61FNP4bvvvsOYMWPw4IMPSrsrlZg6dSp++9vf4q9//SsaGxtxzz334Oabb8Z7772Hc+fOYd68eVi7di2uv/561NXV4eOPP4YoilixYgW+/fZbGI1GPP/88wCA+Ph4v51qRxhqfIQXtSQiIjQ3AI8kB+dn31cBqLTdKqrT6aBSqRAZGem6ePTDDz+MCRMm4JFHHnGV27JlC1JSUvDdd9+hvr4eVqsVN9xwg+u6ixkZGa6yERERMJvNHV7gOhAYanyEa9UQEVE4++KLL/D+++8jKiqq3WsnT57E1VdfjZ/85CfIyMhATk4Orr76atx4440hcbkiJ4YaH3Fd1NLElhoion5LGSm1mATrZ/dCfX09rrvuOj
z22GPtXktKSoJcLsfevXuxb98+vPPOO9i4cSP+/Oc/o6SkBGlpab362b7CUOMjsbyoJRERCUK3u4CCTaVSwWazuZ5
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(history.history['loss'], label='train')\n",
"plt.plot(history.history['val_loss'], label='test')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(20831, 1)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lstm_pred.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(20831,)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Turn the 1-D target vector into a single column. -1 lets NumPy infer the row count,\n",
"# so the cell keeps working when the test split is not exactly 20831 rows long.\n",
"test_y1 = test_y.reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.42300856],\n",
" [0.26651022],\n",
" [0.28093082],\n",
" ...,\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_y1"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Repeat the single prediction column across 6 columns so the array can be handed to\n",
"# scaler.inverse_transform in a later cell (6 is presumably the number of features the\n",
"# scaler was fitted on - TODO confirm). The row count comes from lstm_pred itself\n",
"# instead of a hard-coded 20831.\n",
"results1 = np.broadcast_to(lstm_pred, (lstm_pred.shape[0], 6))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Same widening for the true targets: repeat the single column across 6 columns\n",
"# (presumably the scaler's feature count - TODO confirm), taking the row count from\n",
"# test_y1 instead of a hard-coded 20831.\n",
"test_y2 = np.broadcast_to(test_y1, (test_y1.shape[0], 6))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Undo the normalization: map the widened predictions, then the widened targets,\n",
"# back through the scaler.\n",
"inv_forecast_y, inv_test_y = (\n",
"    scaler.inverse_transform(scaled_block) for scaled_block in (results1, test_y2)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.63622638e+01, 4.53556200e+01, 5.96057328e+02,\n",
" 2.78607105e+02, 1.00676074e+01, 2.18633342e+00],\n",
" [ 8.42201514e+00, 2.98816195e+01, 3.75644883e+02,\n",
" 1.75667855e+02, 6.34294548e+00, 1.37746668e+00],\n",
" [ 9.15367247e+00, 3.13074773e+01, 3.95954874e+02,\n",
" 1.85153232e+02, 6.68615591e+00, 1.45200002e+00],\n",
" ...,\n",
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00],\n",
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00],\n",
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv_test_y"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test RMSE: 0.221\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABP4AAAKTCAYAAACJusZ+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9eZwkR3Utjt/IrO6eRZrRviLQgkDsZkeYzXgBbLw9288/Aw/jLzbYzzwQhocXwMZgkG2wAQMGHmA2AWJfhUBIRhKS0K7Rvo5mJI1Go9m37pnuqsz4/VEVETejsrrjRNRUdk3f8/nwUU+TWZGdlRlx49xzz1Vaa00CgUAgEAgEAoFAIBAIBAKB4KBC1vQFCAQCgUAgEAgEAoFAIBAIBILhQ4g/gUAgEAgEAoFAIBAIBAKB4CCEEH8CgUAgEAgEAoFAIBAIBALBQQgh/gQCgUAgEAgEAoFAIBAIBIKDEEL8CQQCgUAgEAgEAoFAIBAIBAchhPgTCAQCgUAgEAgEAoFAIBAIDkII8ScQCAQCgUAgEAgEAoFAIBAchGiNesCyLGnjxo106KGHklJq1MMLBAKBQCAQCAQCgUAgEAgEYw2tNe3Zs4dOOOEEyrLBur6RE38bN26kk046adTDCgQCgUAgEAgEAoFAIBAIBAcVHnjgAXrEIx4x8P8fOfF36KGHElH3wlatWjXq4QUCgUAgEAgEAoFAIBAIBIKxxu7du+mkk06yPNsgjJz4M+W9q1atEuJPIBAIBAKBQCAQCAQCgUAgiMRCNnrS3EMgEAgEAoFAIBAIBAKBQCA4CCHEn0AgEAgEAoFAIBAIBAKBQHAQQog/gUAgEAgEAoFAIBAIBAKB4CDEyD3+BAKBQCAQCAQCgUAgEAgESw9FUVC73W76MsYCExMTlOd58ucI8ScQCAQCgUAgEAgEAoFAIDhg0FrTpk2baOfOnU1fyljhsMMOo+OOO27BBh7zQYg/gUAgEAgEAoFAIBAIBALBAYMh/Y455hhasWJFEpG1FKC1ppmZGdq8eTMRER1//PHRnyXEn0AgEAgEAoFAIBAIBAKB4ICgKApL+h155JFNX87YYPny5UREtHnzZjrmmGOiy36luYdAIBAIBAKBQCAQCAQCgeCAwHj6rVixouErGT+Ye5biiyjEn0AgEAgEAoFAIBAIBAKB4IBCyntxDOOeCfEnEAgEAoFAIBAIBAKBQCAQHIQQ4k8gEAgEAoFAIBAIBAKBQCA4CCHEn0AgEAgEAoFAIBAIBAKBQHAQQog/gUAgEAgEAoFAIBAIBAKBwMOLXvQiOuuss5q+jCQI8ScQCAQCgUAgEAgEAoFAIBCA0FpTp9Np+jLmhRB/AoFAIBAIBAKBQCAQCASCkUFrTTNznZH/T2sdfI2vec1r6JJLLqEPf/jDpJQipRR97nOfI6UUnX/++fT0pz+dpqam6LLLLqPXvOY19Du/8zuV88866yx60YteZP9dliWdffbZdMopp9Dy5cvpKU95Cn3jG98Y0h0djNYBH0EgEAgEAoFAIBAIBAKBQCDoYV+7oMf//Y9HPu5t734JrZgMo8I+/OEP01133UVPfOIT6d3vfjcREd16661ERPQ3f/M39IEPfIBOPfVUOvzww4M+7+yzz6ZzzjmHPvGJT9Dpp59Ol156Kb3qVa+io48+ml74whfG/UEBEOJPIBAIBAKBQCAQCAQCgUAgYFi9ejVNTk7SihUr6LjjjiMiojvuuIOIiN797nfTr/7qrwZ/1uzsLL3vfe+jCy+8kM4880wiIjr11FPpsssuo09+8pNC/AkEAoFAIBAIBAKBQCAQCA4OLJ/I6bZ3v6SRcYeBZzzjGdDx99xzD83MzPSRhXNzc/TUpz51KNc0CEL8CQQCgUAgEAgEAoFAIBAIRgalVHDJ7WLEypUrK//OsqzPP7Ddbtuf9+7dS0RE5513Hp144omV46ampg7QVXYxvndZIBAIBAKBQCAQCAQCgU
AgOECYnJykoigWPO7oo4+mW265pfK7NWvW0MTEBBERPf7xj6epqSm6//77D2hZbx2E+BMIBAKBQCAQCAQCgUAgEAg8nHzyyXTVVVfR+vXr6ZBDDqGyLGuPe/GLX0zvf//76Qtf+AKdeeaZdM4559Att9xiy3gPPfRQeutb30pvfvObqSxLet7znke7du2iyy+/nFatWkV//Md/fMD+huyAfbJAIBAIBAKBQCAQCAQCgUAwpnjrW99KeZ7T4x//eDr66KPp/vvvrz3uJS95Cb3zne+kt73tbfTMZz6T9uzZQ69+9asrx7znPe+hd77znXT22WfT4x73OHrpS19K5513Hp1yyikH9G9Q2i9CPsDYvXs3rV69mnbt2kWrVq0a5dACgUAgEAgEAoFAIBAIBIIRYv/+/bRu3To65ZRTaNmyZU1fzlhhvnsXyq+J4k8gEAgEAoFAIBAIBAKBQCA4CCHEn0AgECyAotT0t9+6ib52zQNNX4pAIBAIBAKBQCAQCATBEOJPIBAIFsAFt26ir1z9AL3tmzc1fSkCgUAgEAgEAoFAIBAEQ4g/gUAgWAA797WbvgSBQCAQCAQCgUAgEAhgCPEnEAgEAoFAIBAIBAKBQCAQHIQQ4k8gEAgA7G8XTV+CQCAQCAQCgUAgEAgEQRDiTyAQCAB8/or1TV+CQCAQCAQCgUAgEAgEQRDiTyAQCACI359AIBAIBAKBQCAQCMYFQvwJBAIBAK2bvgKBQCAQCAQCgUAgEAjCIMSfQCAQANAkzJ9AIBAIBAKBQCAQCMYDEPH3rne9i5RSlf+dccYZB+raBAKBQCAQCAQCgUAgEAgEgrHA3Nxc05fQB1jx94QnPIEeeugh+7/LLrvsQFyXQCAQLE6I4E8gEAgEAoFAIBAIlgRe9KIX0Rve8AZ6wxveQKtXr6ajjjqK3vnOd5LueUCdfPLJ9J73vIde/epX06pVq+h1r3sdERFddtll9PznP5+WL19OJ510Er3xjW+k6enpRv4GmPhrtVp03HHH2f8dddRRB+K6BAKBQCAQCAQHObTW9N01D9K9W/Y2fSkCgUAgEAhGCa2J5qZH/78I0/bPf/7z1Gq16Oqrr6YPf/jD9O///u/06U9/2v7/H/jAB+gpT3kK3XDDDfTOd76T1q5dSy996Uvp937v9+imm26ir371q3TZZZfRG97whmHewWC00BPuvvtuOuGEE2jZsmV05pln0tlnn02PfOQjBx4/OztLs7Oz9t+7d++Ou1KBQCBYBBDBn0AgEAwP5938EL3p3DVERLT+n3+j2YsRCAQCgUAwOrRniN53wujH/buNRJMroVNOOukk+uAHP0hKKXrsYx9LN998M33wgx+kP/uzPyMiohe/+MX0lre8xR7/p3/6p/TKV76SzjrrLCIiOv300+k//uM/6IUvfCF9/OMfp2XLlg3tzwkBpPh79rOfTZ/73OfoRz/6EX384x+ndevW0fOf/3zas2fPwHPOPvtsWr16tf3fSSedlHzRAoFAIBAIBILxx/X37Wz6EgQCgUAgEAjmxXOe8xxSStl/n3nmmXT33XdTURRERPSMZzyjcvyNN95In/vc5+iQQw6x/3vJS15CZVnSunXrRnrtRKDi72Uve5n9+clPfjI9+9nPpkc96lH0ta99jV772tfWnvO3f/u39Fd/9Vf237t37xbyTyAQCAQCgUAgEAgEAoFgqWJiRVd918S4Q8bKlVUF4d69e+n1r389vfGNb+w7dr6K2QMFuNSX47DDDqPHPOYxdM899ww8ZmpqiqamplKGEQgEAoFAIBAchGDJc4FAIBAIBEsJSsElt03hqquuqvz7yiuvpNNPP53yPK89/mlPexrddttt9OhHP3oUl7cg4OYeHHv37qW1a9fS8ccfP6zrEQgEgkUNHWEGKxAIBIJ6CO8nEAgEAoFgseP++++nv/qrv6I777yTvvKVr9BHPvIRetOb3jTw+L/+67+mK664gt7whjfQmjVr6O6776
bvfve749Hc461vfSv95m/+Jj3qUY+ijRs30j/8wz9Qnuf0R3/0Rwfq+gQCgYCIiIpS0+u/eB097vhD6S2/9timLwd
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 计算均方根误差\n",
"rmse = sqrt(mean_squared_error(inv_test_y[:,5], inv_forecast_y[:,5]))\n",
"print('Test RMSE: %.3f' % rmse)\n",
"#画图\n",
"plt.figure(figsize=(16,8))\n",
"plt.plot(inv_test_y[:,5], label='true')\n",
"plt.plot(inv_forecast_y[:,5], label='pre')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean_squared_error: 0.001827188351681114\n",
"mean_absolute_error: 0.013142039763722593\n",
"rmse: 0.04274562377227772\n",
"r2 score: 0.9903982756173205\n"
]
}
],
"source": [
"from sklearn.metrics import mean_squared_error, mean_absolute_error # 评价指标\n",
"# 使用sklearn调用衡量线性回归的MSE 、 RMSE、 MAE、r2\n",
"from math import sqrt\n",
"from sklearn.metrics import mean_absolute_error\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.metrics import r2_score\n",
"print('mean_squared_error:', mean_squared_error(lstm_pred, test_y)) # mse)\n",
"print(\"mean_absolute_error:\", mean_absolute_error(lstm_pred, test_y)) # mae\n",
"print(\"rmse:\", sqrt(mean_squared_error(lstm_pred,test_y)))\n",
"#r2对比区域\n",
"print(\"r2 score:\", r2_score(inv_test_y[5000:10000], inv_forecast_y[5000:10000]))#预测50天数据"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# inv_test_y, as built by the inverse-transform cell, has 6 columns, so passing the whole\n",
"# array with a single column name raises a ValueError on a fresh run. Export only column 5,\n",
"# the column used for the RMSE and the plot.\n",
"df1 = pd.DataFrame(inv_test_y[:, 5], columns=['column_name'])"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# Write the DataFrame to test.csv without the index column.\n",
"df1.to_csv('test.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# inv_forecast_y, as built by the inverse-transform cell, has 6 columns, so passing the\n",
"# whole array with a single column name raises a ValueError on a fresh run. Export only\n",
"# column 5, the column used for the RMSE and the plot.\n",
"df2 = pd.DataFrame(inv_forecast_y[:, 5], columns=['column_name'])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# Write the DataFrame to forecast.csv without the index column.\n",
"df2.to_csv('forecast.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}