938 lines
377 KiB
Plaintext
938 lines
377 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"C:\\Users\\asus\\AppData\\Roaming\\Python\\Python39\\site-packages\\pandas\\core\\computation\\expressions.py:21: UserWarning: Pandas requires version '2.8.4' or newer of 'numexpr' (version '2.8.3' currently installed).\n",
|
|||
|
" from pandas.core.computation.check import NUMEXPR_INSTALLED\n",
|
|||
|
"C:\\Users\\asus\\AppData\\Roaming\\Python\\Python39\\site-packages\\pandas\\core\\arrays\\masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
|
|||
|
" from pandas.core import (\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from math import sqrt\n",
|
|||
|
"from numpy import concatenate\n",
|
|||
|
"from matplotlib import pyplot\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn.preprocessing import MinMaxScaler\n",
|
|||
|
"from sklearn.preprocessing import LabelEncoder\n",
|
|||
|
"from sklearn.metrics import mean_squared_error\n",
|
|||
|
"from tensorflow.keras import Sequential\n",
|
|||
|
"\n",
|
|||
|
"from tensorflow.keras.layers import Dense\n",
|
|||
|
"from tensorflow.keras.layers import LSTM\n",
|
|||
|
"from tensorflow.keras.layers import Dropout\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"import matplotlib.pyplot as plt"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"这段代码是一个函数 time_series_to_supervised,它用于将时间序列数据转换为监督学习问题的数据集。下面是该函数的各个部分的含义:\n",
|
|||
|
"\n",
|
|||
|
"data: 输入的时间序列数据,可以是列表或2D NumPy数组。\n",
|
|||
|
"n_in: 作为输入的滞后观察数,即用多少个时间步的观察值作为输入。默认值为96,表示使用前96个时间步的观察值作为输入。\n",
|
|||
|
"n_out: 作为输出的观测数量,即预测多少个时间步的观察值。默认值为10,表示预测未来10个时间步的观察值。\n",
|
|||
|
"dropnan: 布尔值,表示是否删除具有NaN值的行。默认为True,即删除具有NaN值的行。\n",
|
|||
|
"函数首先检查输入数据的维度,并初始化一些变量。然后,它创建一个新的DataFrame对象 df 来存储输入数据,并保存原始的列名。接着,它创建了两个空列表 cols 和 names,用于存储新的特征列和列名。\n",
|
|||
|
"\n",
|
|||
|
"接下来,函数开始构建特征列和对应的列名。首先,它将当前时刻 t 的观察序列添加到 cols 列表中,列名即原始列名(如 0、5)。然后,它按从最远(t-n_in)到最近(t-1)的顺序,依次将滞后的观察序列添加到 cols 列表中,并构建相应的列名,格式为 原始列名(t-滞后步数),例如 0(t-96)。这样就创建了输入特征的部分。\n",
|
|||
|
"\n",
|
|||
|
"接着,函数开始构建输出特征的部分。它按 t+1 到 t+n_out 的顺序,依次将未来的观察序列添加到 cols 列表中,并构建相应的列名,格式为 原始列名(t+未来步数),例如 0(t+1)。\n",
|
|||
|
"\n",
|
|||
|
"最后,函数将所有的特征列拼接在一起,构成一个新的DataFrame对象 agg。如果 dropnan 参数为True,则删除因移位而产生NaN值的行。最终,函数返回处理后的数据集 agg。"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def time_series_to_supervised(data, n_in=96, n_out=10, dropnan=True):\n",
|
|||
|
"    \"\"\"Frame a time series as a supervised-learning table.\n",
|
|||
|
"\n",
|
|||
|
"    Output columns, in order: the current observation of every variable\n",
|
|||
|
"    (named after the original column, e.g. '5'), then the lags t-n_in ... t-1\n",
|
|||
|
"    (e.g. '5(t-96)'), then the leads t+1 ... t+n_out (e.g. '5(t+1)').\n",
|
|||
|
"\n",
|
|||
|
"    :param data: observations as a list, a list of rows, a 1-D/2-D NumPy\n",
|
|||
|
"        array, a Series or a DataFrame. Required.\n",
|
|||
|
"    :param n_in: number of lagged steps used as inputs (X); negative values\n",
|
|||
|
"        are treated as 0. Default 96.\n",
|
|||
|
"    :param n_out: number of future steps added as outputs (y); negative\n",
|
|||
|
"        values are treated as 0. Default 10.\n",
|
|||
|
"    :param dropnan: drop the rows containing NaN introduced by shifting.\n",
|
|||
|
"        Default True.\n",
|
|||
|
"    :return: DataFrame with n_vars * (1 + n_in + n_out) columns.\n",
|
|||
|
"    \"\"\"\n",
|
|||
|
"    df = pd.DataFrame(data)\n",
|
|||
|
"    # Count the variables from the frame itself so 1-D arrays, Series and\n",
|
|||
|
"    # lists of rows are handled (a list/shape[1] test mis-counts them).\n",
|
|||
|
"    n_vars = df.shape[1]\n",
|
|||
|
"    origNames = df.columns\n",
|
|||
|
"    n_in = max(0, n_in)\n",
|
|||
|
"    n_out = max(0, n_out)\n",
|
|||
|
"\n",
|
|||
|
"    # Current observation (t).\n",
|
|||
|
"    cols = [df.shift(0)]\n",
|
|||
|
"    names = ['%s' % origNames[j] for j in range(n_vars)]\n",
|
|||
|
"    # Lagged inputs, oldest first: t-n_in, ..., t-1.\n",
|
|||
|
"    for i in range(n_in, 0, -1):\n",
|
|||
|
"        cols.append(df.shift(i))\n",
|
|||
|
"        names += ['%s(t-%d)' % (origNames[j], i) for j in range(n_vars)]\n",
|
|||
|
"    # Future observations: t+1, ..., t+n_out.\n",
|
|||
|
"    for i in range(1, n_out + 1):\n",
|
|||
|
"        cols.append(df.shift(-i))\n",
|
|||
|
"        names += ['%s(t+%d)' % (origNames[j], i) for j in range(n_vars)]\n",
|
|||
|
"\n",
|
|||
|
"    agg = pd.concat(cols, axis=1)\n",
|
|||
|
"    agg.columns = names\n",
|
|||
|
"    if dropnan:\n",
|
|||
|
"        agg = agg.dropna()\n",
|
|||
|
"    return agg"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Temp Humidity GHI DHI Rainfall Power\n",
|
|||
|
"0 19.779453 40.025826 3.232706 1.690531 0.0 0.0\n",
|
|||
|
"1 19.714937 39.605961 3.194991 1.576346 0.0 0.0\n",
|
|||
|
"2 19.549330 39.608631 3.070866 1.576157 0.0 0.0\n",
|
|||
|
"3 19.405870 39.680702 3.038623 1.482489 0.0 0.0\n",
|
|||
|
"4 19.387363 39.319881 2.656474 1.134153 0.0 0.0\n",
|
|||
|
"(104256, 6)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Load the raw data from CSV (columns shown below: Temp, Humidity, GHI, DHI, Rainfall, Power).\n",
|
|||
|
"# NOTE(review): machine-specific absolute path; consider making it configurable.\n",
|
|||
|
"path1 = r\"D:\\project\\小论文1-基于ICEEMDAN分解的时序高维变化的短期光伏功率预测模型\\CEEMAN-PosConv1dbiLSTM-LSTM\\模型代码流程\\data6.csv\"  # data file path\n",
|
|||
|
"datas1 = pd.read_csv(path1)\n",
|
|||
|
"# Fill missing values by interpolating between neighbouring rows (pandas default: linear).\n",
|
|||
|
"data1 = datas1.interpolate()\n",
|
|||
|
"values1 = data1.values\n",
|
|||
|
"print(data1.head())\n",
|
|||
|
"print(data1.shape)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Kept for reference only (commented out): the loaded data has no 'date', 'Air_P' or 'RH' columns.\n",
|
|||
|
"# data2= data1.drop(['date','Air_P','RH'], axis = 1)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"(104256, 6)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Min-max scale every column to [0, 1].\n",
|
|||
|
"# Fit the scaler on the chronologically first TRAIN_FRACTION of the rows only:\n",
|
|||
|
"# fitting on the whole series leaks the test period's min/max into training.\n",
|
|||
|
"# With n_steps_in=96 and n_steps_out=1 (set below) this prefix ends slightly\n",
|
|||
|
"# before the last row used by the training windows, so no test rows are seen.\n",
|
|||
|
"TRAIN_FRACTION = 0.8\n",
|
|||
|
"n_fit_rows = int(len(data1) * TRAIN_FRACTION)\n",
|
|||
|
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
|
|||
|
"scaler.fit(data1.iloc[:n_fit_rows])\n",
|
|||
|
"scaledData1 = scaler.transform(data1)\n",
|
|||
|
"print(scaledData1.shape)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" 0 1 2 3 4 5 0(t-96) \\\n",
|
|||
|
"96 0.555631 0.349673 0.190042 0.040558 0.0 0.236302 0.490360 \n",
|
|||
|
"97 0.564819 0.315350 0.211335 0.044613 0.0 0.258204 0.489088 \n",
|
|||
|
"98 0.576854 0.288321 0.229657 0.047549 0.0 0.279860 0.485824 \n",
|
|||
|
"99 0.581973 0.268243 0.247775 0.053347 0.0 0.301336 0.482997 \n",
|
|||
|
"100 0.586026 0.264586 0.266058 0.057351 0.0 0.322851 0.482632 \n",
|
|||
|
"\n",
|
|||
|
" 1(t-96) 2(t-96) 3(t-96) ... 2(t-1) 3(t-1) 4(t-1) 5(t-1) \\\n",
|
|||
|
"96 0.369105 0.002088 0.002013 ... 0.166009 0.036794 0.0 0.214129 \n",
|
|||
|
"97 0.364859 0.002061 0.001839 ... 0.190042 0.040558 0.0 0.236302 \n",
|
|||
|
"98 0.364886 0.001973 0.001839 ... 0.211335 0.044613 0.0 0.258204 \n",
|
|||
|
"99 0.365615 0.001950 0.001697 ... 0.229657 0.047549 0.0 0.279860 \n",
|
|||
|
"100 0.361965 0.001679 0.001167 ... 0.247775 0.053347 0.0 0.301336 \n",
|
|||
|
"\n",
|
|||
|
" 0(t+1) 1(t+1) 2(t+1) 3(t+1) 4(t+1) 5(t+1) \n",
|
|||
|
"96 0.564819 0.315350 0.211335 0.044613 0.0 0.258204 \n",
|
|||
|
"97 0.576854 0.288321 0.229657 0.047549 0.0 0.279860 \n",
|
|||
|
"98 0.581973 0.268243 0.247775 0.053347 0.0 0.301336 \n",
|
|||
|
"99 0.586026 0.264586 0.266058 0.057351 0.0 0.322851 \n",
|
|||
|
"100 0.590772 0.258790 0.282900 0.060958 0.0 0.343360 \n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 588 columns]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Window sizes used to frame the series as a supervised problem.\n",
|
|||
|
"n_steps_in = 96   # history length: number of past time steps per sample\n",
|
|||
|
"n_steps_out = 1   # prediction length\n",
|
|||
|
"processedData1 = time_series_to_supervised(\n",
|
|||
|
"    scaledData1,\n",
|
|||
|
"    n_in=n_steps_in,\n",
|
|||
|
"    n_out=n_steps_out,\n",
|
|||
|
")\n",
|
|||
|
"print(processedData1.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Features: every variable (power included) at the lags t-n_steps_in ... t-1.\n",
|
|||
|
"# Target: power at time t, i.e. the last input column ('5' for 6 variables).\n",
|
|||
|
"# Labels are derived from n_steps_in and the variable count instead of being\n",
|
|||
|
"# hard-coded, so this cell stays correct if either changes.\n",
|
|||
|
"n_features = scaledData1.shape[1]\n",
|
|||
|
"target_col = str(n_features - 1)\n",
|
|||
|
"data_x = processedData1.loc[:, f'0(t-{n_steps_in})':f'{n_features - 1}(t-1)']\n",
|
|||
|
"data_y = processedData1.loc[:, target_col]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(104159, 576)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Expected shape: (n_samples, n_steps_in * n_features).\n",
|
|||
|
"data_x.shape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"96 0.236302\n",
|
|||
|
"97 0.258204\n",
|
|||
|
"98 0.279860\n",
|
|||
|
"99 0.301336\n",
|
|||
|
"100 0.322851\n",
|
|||
|
" ... \n",
|
|||
|
"104250 0.000000\n",
|
|||
|
"104251 0.000000\n",
|
|||
|
"104252 0.000000\n",
|
|||
|
"104253 0.000000\n",
|
|||
|
"104254 0.000000\n",
|
|||
|
"Name: 5, Length: 104159, dtype: float64"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Target series: scaled power at time t, indexed by the original row number.\n",
|
|||
|
"data_y"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(104159,)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# One target per sample; should equal data_x.shape[0].\n",
|
|||
|
"data_y.shape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"(83328, 96, 6) (83328,) (20831, 96, 6) (20831,)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 7. Chronological train/test split (no shuffling, since this is a time series):\n",
|
|||
|
"# the last TEST_FRACTION of the windowed samples form the test set.\n",
|
|||
|
"TEST_FRACTION = 0.2\n",
|
|||
|
"n_features = scaledData1.shape[1]\n",
|
|||
|
"test_size = int(len(data_x) * TEST_FRACTION)\n",
|
|||
|
"train_indices = range(len(data_x) - test_size)\n",
|
|||
|
"test_indices = range(len(data_x) - test_size, len(data_x))\n",
|
|||
|
"\n",
|
|||
|
"# Reshape the flat lag columns to 3-D [samples, timesteps, features]. This works\n",
|
|||
|
"# because data_x columns are time-major: all variables at t-n_steps_in, then all\n",
|
|||
|
"# variables at t-n_steps_in+1, ..., then all variables at t-1.\n",
|
|||
|
"train_X1 = data_x.iloc[train_indices].values.reshape((-1, n_steps_in, n_features))\n",
|
|||
|
"test_X1 = data_x.iloc[test_indices].values.reshape((-1, n_steps_in, n_features))\n",
|
|||
|
"train_y = data_y.iloc[train_indices].values\n",
|
|||
|
"test_y = data_y.iloc[test_indices].values\n",
|
|||
|
"\n",
|
|||
|
"# The arrays are already 3-D; train_X/test_X are the names used for training below.\n",
|
|||
|
"train_X, test_X = train_X1, test_X1\n",
|
|||
|
"assert len(train_X) + len(test_X) == len(data_x)\n",
|
|||
|
"assert len(train_X) == len(train_y) and len(test_X) == len(test_y)\n",
|
|||
|
"print(train_X.shape, train_y.shape, test_X.shape, test_y.shape)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(83328, 96, 6)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Training tensor shape: (samples, timesteps, features).\n",
|
|||
|
"train_X1.shape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"WARNING:tensorflow:From d:\\Anaconda3\\lib\\site-packages\\keras\\src\\backend\\tensorflow\\core.py:192: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional\"</span>\n",
|
|||
|
"</pre>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"\u001b[1mModel: \"functional\"\u001b[0m\n"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
|
|||
|
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
|
|||
|
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
|
|||
|
"│ input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">6</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
|||
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ conv1d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv1D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">832</span> │ input_layer[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ max_pooling1d │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv1d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|||
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling1D</span>) │ │ │ │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ bidirectional │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">49,920</span> │ max_pooling1d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|||
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ │ │ │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ attention_with_imp… │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ <span style=\"color: #00af00; text-decoration-color: #00af00\">66,304</span> │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|||
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">AttentionWithImpr…</span> │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, │ │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|||
|
"│ │ <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)] │ │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ global_average_poo… │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ attention_with_i… │\n",
|
|||
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalAveragePool…</span> │ │ │ │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ dense_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">129</span> │ global_average_p… │\n",
|
|||
|
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
|
|||
|
"</pre>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
|
|||
|
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
|
|||
|
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
|
|||
|
"│ input_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
|||
|
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ conv1d (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m832\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ max_pooling1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|||
|
"│ (\u001b[38;5;33mMaxPooling1D\u001b[0m) │ │ │ │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ bidirectional │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m49,920\u001b[0m │ max_pooling1d[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|||
|
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ attention_with_imp… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m66,304\u001b[0m │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|||
|
"│ (\u001b[38;5;33mAttentionWithImpr…\u001b[0m │ \u001b[38;5;34m128\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m, │ │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|||
|
"│ │ \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)] │ │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ global_average_poo… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ attention_with_i… │\n",
|
|||
|
"│ (\u001b[38;5;33mGlobalAveragePool…\u001b[0m │ │ │ │\n",
|
|||
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|||
|
"│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m129\u001b[0m │ global_average_p… │\n",
|
|||
|
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
|
|||
|
"</pre>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
|
|||
|
"</pre>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
|
|||
|
"</pre>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import tensorflow as tf\n",
|
|||
|
"from tensorflow.keras.layers import Input, Conv1D, Bidirectional, GlobalAveragePooling1D, Dense, GRU, MaxPooling1D\n",
|
|||
|
"from tensorflow.keras.models import Model\n",
|
|||
|
"from tensorflow.keras.initializers import RandomUniform\n",
|
|||
|
"class AttentionWithImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
|
|||
|
"    \"\"\"Multi-head scaled dot-product attention with a learnable sinusoidal\n",
|
|||
|
"    position encoding added to the projected queries and keys.\n",
|
|||
|
"\n",
|
|||
|
"    Args:\n",
|
|||
|
"        d_model: width of the q/k/v projections and of the output; must be\n",
|
|||
|
"            divisible by num_heads.\n",
|
|||
|
"        num_heads: number of attention heads.\n",
|
|||
|
"        max_len: stored on the layer; not used in the computation.\n",
|
|||
|
"        **kwargs: forwarded to tf.keras.layers.Layer (e.g. name).\n",
|
|||
|
"\n",
|
|||
|
"    call(v, k, q, mask) returns (output, attention_weights): output is\n",
|
|||
|
"    (batch, seq_q, d_model), attention_weights is (batch, num_heads, seq_q, seq_k).\n",
|
|||
|
"    \"\"\"\n",
|
|||
|
"\n",
|
|||
|
"    def __init__(self, d_model, num_heads, max_len=5000, **kwargs):\n",
|
|||
|
"        super(AttentionWithImproveRelativePositionEncoding, self).__init__(**kwargs)\n",
|
|||
|
"        if d_model % num_heads != 0:\n",
|
|||
|
"            raise ValueError(\n",
|
|||
|
"                'd_model (%d) must be divisible by num_heads (%d)' % (d_model, num_heads))\n",
|
|||
|
"        self.num_heads = num_heads\n",
|
|||
|
"        self.d_model = d_model\n",
|
|||
|
"        self.max_len = max_len\n",
|
|||
|
"        self.wq = tf.keras.layers.Dense(d_model)\n",
|
|||
|
"        self.wk = tf.keras.layers.Dense(d_model)\n",
|
|||
|
"        self.wv = tf.keras.layers.Dense(d_model)\n",
|
|||
|
"        self.dense = tf.keras.layers.Dense(d_model)\n",
|
|||
|
"        self.position_encoding = ImproveRelativePositionEncoding(d_model)\n",
|
|||
|
"\n",
|
|||
|
"    def call(self, v, k, q, mask):\n",
|
|||
|
"        batch_size = tf.shape(q)[0]\n",
|
|||
|
"        q = self.wq(q)\n",
|
|||
|
"        k = self.wk(k)\n",
|
|||
|
"        v = self.wv(v)\n",
|
|||
|
"\n",
|
|||
|
"        # Add the position encoding. position_encoding(x) already returns\n",
|
|||
|
"        # x + encoding, so assign the result; `k += ...` would give 2 * k + encoding.\n",
|
|||
|
"        k = self.position_encoding(k)\n",
|
|||
|
"        q = self.position_encoding(q)\n",
|
|||
|
"\n",
|
|||
|
"        q = self.split_heads(q, batch_size)\n",
|
|||
|
"        k = self.split_heads(k, batch_size)\n",
|
|||
|
"        v = self.split_heads(v, batch_size)\n",
|
|||
|
"\n",
|
|||
|
"        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)\n",
|
|||
|
"        # (batch, heads, seq, depth) -> (batch, seq, heads, depth) -> (batch, seq, d_model)\n",
|
|||
|
"        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])\n",
|
|||
|
"        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))\n",
|
|||
|
"        output = self.dense(concat_attention)\n",
|
|||
|
"        return output, attention_weights\n",
|
|||
|
"\n",
|
|||
|
"    def split_heads(self, x, batch_size):\n",
|
|||
|
"        \"\"\"Reshape (batch, seq, d_model) -> (batch, num_heads, seq, d_model // num_heads).\"\"\"\n",
|
|||
|
"        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_model // self.num_heads))\n",
|
|||
|
"        return tf.transpose(x, perm=[0, 2, 1, 3])\n",
|
|||
|
"\n",
|
|||
|
"    def scaled_dot_product_attention(self, q, k, v, mask):\n",
|
|||
|
"        \"\"\"Return (softmax(q k^T / sqrt(depth) + mask * -1e9) v, attention weights).\n",
|
|||
|
"\n",
|
|||
|
"        Entries where mask is 1 get a -1e9 logit and are effectively ignored.\n",
|
|||
|
"        \"\"\"\n",
|
|||
|
"        matmul_qk = tf.matmul(q, k, transpose_b=True)\n",
|
|||
|
"        dk = tf.cast(tf.shape(k)[-1], tf.float32)\n",
|
|||
|
"        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n",
|
|||
|
"\n",
|
|||
|
"        if mask is not None:\n",
|
|||
|
"            scaled_attention_logits += (mask * -1e9)\n",
|
|||
|
"\n",
|
|||
|
"        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)\n",
|
|||
|
"        output = tf.matmul(attention_weights, v)\n",
|
|||
|
"        return output, attention_weights\n",
|
|||
|
"\n",
|
|||
|
"    def get_config(self):\n",
|
|||
|
"        \"\"\"Constructor arguments, needed to serialize the layer with the model.\"\"\"\n",
|
|||
|
"        config = super(AttentionWithImproveRelativePositionEncoding, self).get_config()\n",
|
|||
|
"        config.update({'d_model': self.d_model, 'num_heads': self.num_heads,\n",
|
|||
|
"                       'max_len': self.max_len})\n",
|
|||
|
"        return config\n",
|
|||
|
"\n",
|
|||
|
"class ImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
|
|||
|
"    \"\"\"Add a learnably scaled sinusoidal position encoding to the input.\n",
|
|||
|
"\n",
|
|||
|
"    The fixed sin/cos table (sin on even channels, cos on odd channels) is\n",
|
|||
|
"    scaled per channel by two trainable vectors u and v, as pe * u + pe * v,\n",
|
|||
|
"    and added to the input.\n",
|
|||
|
"\n",
|
|||
|
"    Args:\n",
|
|||
|
"        d_model: size of the input's last axis.\n",
|
|||
|
"        max_len: stored on the layer; not used in the computation.\n",
|
|||
|
"        **kwargs: forwarded to tf.keras.layers.Layer (e.g. name).\n",
|
|||
|
"    \"\"\"\n",
|
|||
|
"\n",
|
|||
|
"    def __init__(self, d_model, max_len=5000, **kwargs):\n",
|
|||
|
"        super(ImproveRelativePositionEncoding, self).__init__(**kwargs)\n",
|
|||
|
"        self.max_len = max_len\n",
|
|||
|
"        self.d_model = d_model\n",
|
|||
|
"        # Trainable per-channel scales u and v applied to the fixed table.\n",
|
|||
|
"        self.u = self.add_weight(shape=(self.d_model,),\n",
|
|||
|
"                                 initializer=RandomUniform(),\n",
|
|||
|
"                                 trainable=True)\n",
|
|||
|
"        self.v = self.add_weight(shape=(self.d_model,),\n",
|
|||
|
"                                 initializer=RandomUniform(),\n",
|
|||
|
"                                 trainable=True)\n",
|
|||
|
"\n",
|
|||
|
"    def call(self, inputs):\n",
|
|||
|
"        # Dynamic length, so this also works when the time axis is unknown statically.\n",
|
|||
|
"        seq_length = tf.shape(inputs)[1]\n",
|
|||
|
"        pos_encoding = self.relative_positional_encoding(seq_length, self.d_model)\n",
|
|||
|
"\n",
|
|||
|
"        # Blend the trainable parameters u and v into the encoding.\n",
|
|||
|
"        pe_with_params = pos_encoding * self.u + pos_encoding * self.v\n",
|
|||
|
"        return inputs + pe_with_params\n",
|
|||
|
"\n",
|
|||
|
"    def relative_positional_encoding(self, position, d_model):\n",
|
|||
|
"        \"\"\"Return a (1, position, d_model) sinusoidal table.\n",
|
|||
|
"\n",
|
|||
|
"        Channel j uses the angle rate 1 / 10000 ** (2 * (j // 2) / d_model);\n",
|
|||
|
"        even channels take sin, odd channels take cos. Selecting per channel\n",
|
|||
|
"        (instead of stacking even/odd halves) also works for odd d_model.\n",
|
|||
|
"        \"\"\"\n",
|
|||
|
"        pos = tf.cast(tf.range(position), tf.float32)\n",
|
|||
|
"        i = tf.range(d_model, dtype=tf.float32)\n",
|
|||
|
"        angles = 1 / tf.pow(10000.0, (2 * (i // 2)) / tf.cast(d_model, tf.float32))\n",
|
|||
|
"        angle_rads = tf.einsum('i,j->ij', pos, angles)\n",
|
|||
|
"        # Keep the sinusoidal scheme: sin on even indices (2i), cos on odd indices (2i+1).\n",
|
|||
|
"        is_even_channel = tf.equal(tf.math.floormod(tf.range(d_model), 2), 0)\n",
|
|||
|
"        pos_encoding = tf.where(is_even_channel, tf.sin(angle_rads), tf.cos(angle_rads))\n",
|
|||
|
"        return pos_encoding[tf.newaxis, :, :]\n",
|
|||
|
"\n",
|
|||
|
"    def get_config(self):\n",
|
|||
|
"        \"\"\"Constructor arguments, needed to serialize the layer with the model.\"\"\"\n",
|
|||
|
"        config = super(ImproveRelativePositionEncoding, self).get_config()\n",
|
|||
|
"        config.update({'d_model': self.d_model, 'max_len': self.max_len})\n",
|
|||
|
"        return config\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"def PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads):\n",
|
|||
|
"    \"\"\"Build the Conv1D -> BiGRU -> self-attention -> pooling -> Dense(1) regressor.\n",
|
|||
|
"\n",
|
|||
|
"    Args:\n",
|
|||
|
"        input_shape: (timesteps, features) of one sample.\n",
|
|||
|
"        gru_units: units per GRU direction; attention width is 2 * gru_units.\n",
|
|||
|
"        num_heads: number of attention heads.\n",
|
|||
|
"\n",
|
|||
|
"    Returns:\n",
|
|||
|
"        An uncompiled Model producing one value per input window.\n",
|
|||
|
"    \"\"\"\n",
|
|||
|
"    x_in = Input(shape=input_shape)\n",
|
|||
|
"\n",
|
|||
|
"    # Local feature extraction; pool_size=1 leaves the sequence length unchanged.\n",
|
|||
|
"    conv_out = Conv1D(filters=64, kernel_size=2, activation='relu')(x_in)\n",
|
|||
|
"    conv_out = MaxPooling1D(pool_size=1)(conv_out)\n",
|
|||
|
"\n",
|
|||
|
"    # Full output sequence from both directions, so attention can look across time.\n",
|
|||
|
"    recurrent_out = Bidirectional(GRU(gru_units, return_sequences=True))(conv_out)\n",
|
|||
|
"\n",
|
|||
|
"    # Self-attention: the same tensor is used as values, keys and queries.\n",
|
|||
|
"    attention_layer = AttentionWithImproveRelativePositionEncoding(\n",
|
|||
|
"        d_model=2 * gru_units, num_heads=num_heads)\n",
|
|||
|
"    attended = attention_layer(recurrent_out, recurrent_out, recurrent_out, mask=None)[0]\n",
|
|||
|
"\n",
|
|||
|
"    pooled = GlobalAveragePooling1D()(attended)\n",
|
|||
|
"    prediction = Dense(1)(pooled)\n",
|
|||
|
"    return Model(inputs=x_in, outputs=prediction)\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# One input sample is a window of shape (timesteps, features); derive it from\n",
|
|||
|
"# the data preparation above instead of hard-coding (96, 6).\n",
|
|||
|
"input_shape = (n_steps_in, scaledData1.shape[1])\n",
|
|||
|
"gru_units = 64  # units per GRU direction (attention width = 2 * gru_units)\n",
|
|||
|
"num_heads = 8   # 2 * gru_units must be divisible by num_heads\n",
|
|||
|
"\n",
|
|||
|
"# Seed the Python, NumPy and TensorFlow RNGs so weight initialisation and\n",
|
|||
|
"# training are more repeatable between runs.\n",
|
|||
|
"tf.keras.utils.set_random_seed(42)\n",
|
|||
|
"\n",
|
|||
|
"# Create model\n",
|
|||
|
"model = PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads)\n",
|
|||
|
"model.compile(optimizer='adam', loss='mse')\n",
|
|||
|
"model.summary()\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Epoch 1/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m82s\u001b[0m 61ms/step - loss: 0.0187 - val_loss: 0.0021\n",
|
|||
|
"Epoch 2/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m80s\u001b[0m 62ms/step - loss: 0.0014 - val_loss: 0.0025\n",
|
|||
|
"Epoch 3/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m84s\u001b[0m 64ms/step - loss: 0.0013 - val_loss: 0.0021\n",
|
|||
|
"Epoch 4/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m87s\u001b[0m 67ms/step - loss: 0.0012 - val_loss: 0.0020\n",
|
|||
|
"Epoch 5/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m76s\u001b[0m 58ms/step - loss: 0.0013 - val_loss: 0.0019\n",
|
|||
|
"Epoch 6/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m71s\u001b[0m 55ms/step - loss: 0.0011 - val_loss: 0.0020\n",
|
|||
|
"Epoch 7/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m70s\u001b[0m 54ms/step - loss: 0.0011 - val_loss: 0.0019\n",
|
|||
|
"Epoch 8/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m81s\u001b[0m 62ms/step - loss: 0.0011 - val_loss: 0.0020\n",
|
|||
|
"Epoch 9/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m82s\u001b[0m 63ms/step - loss: 0.0012 - val_loss: 0.0019\n",
|
|||
|
"Epoch 10/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m85s\u001b[0m 65ms/step - loss: 0.0011 - val_loss: 0.0018\n",
|
|||
|
"Epoch 11/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m88s\u001b[0m 68ms/step - loss: 0.0011 - val_loss: 0.0019\n",
|
|||
|
"Epoch 12/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m90s\u001b[0m 69ms/step - loss: 0.0011 - val_loss: 0.0018\n",
|
|||
|
"Epoch 13/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m79s\u001b[0m 61ms/step - loss: 0.0011 - val_loss: 0.0020\n",
|
|||
|
"Epoch 14/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m83s\u001b[0m 64ms/step - loss: 0.0011 - val_loss: 0.0019\n",
|
|||
|
"Epoch 15/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m82s\u001b[0m 63ms/step - loss: 0.0011 - val_loss: 0.0018\n",
|
|||
|
"Epoch 16/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m80s\u001b[0m 61ms/step - loss: 0.0011 - val_loss: 0.0019\n",
|
|||
|
"Epoch 17/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m81s\u001b[0m 63ms/step - loss: 0.0010 - val_loss: 0.0020\n",
|
|||
|
"Epoch 18/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m77s\u001b[0m 59ms/step - loss: 0.0011 - val_loss: 0.0018\n",
|
|||
|
"Epoch 19/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m78s\u001b[0m 60ms/step - loss: 0.0011 - val_loss: 0.0018\n",
|
|||
|
"Epoch 20/100\n",
|
|||
|
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m77s\u001b[0m 59ms/step - loss: 0.0011 - val_loss: 0.0018\n",
|
|||
|
"\u001b[1m651/651\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 17ms/step\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Compile and train the LSTM model\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"\n",
"model.compile(optimizer='adam', loss='mean_squared_error')\n",
"\n",
"# Early stopping: stop once val_loss has not improved for 10 epochs and roll the\n",
"# weights back to the best epoch. Without restore_best_weights the model keeps the\n",
"# weights of the last epoch, which can be up to `patience` epochs past the best one.\n",
"# NOTE(review): validation_data is the test set, so early stopping selects on the same\n",
"# data the metrics below are reported on (optimistic bias). Consider holding out a\n",
"# separate validation split from the training data instead.\n",
"early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=0, mode='min',\n",
"                               restore_best_weights=True)\n",
"\n",
"# Fit the model with early stopping\n",
"history = model.fit(train_X, train_y, epochs=100, batch_size=64, validation_data=(test_X, test_y),\n",
"                    callbacks=[early_stopping])\n",
"\n",
"# Predict on the test set; the cells below reshape and inverse-scale lstm_pred\n",
"lstm_pred = model.predict(test_X)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGdCAYAAADqsoKGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOSklEQVR4nO3de1xUZeI/8M+ZOwwwgCgDikKGqYl3JdT9WrtsWG4rtZW5u3lZV9v9ta0u2cVWpdtGabXm5Ru1m2m739LcynbLLKN7EuatstLMBVFxUCQYGGCGmTm/P87MwMBwGZgrfN6v17xm5sxzDs9hHOfDczuCKIoiiIiIiMKcLNgVICIiIvIFhhoiIiLqExhqiIiIqE9gqCEiIqI+gaGGiIiI+gSGGiIiIuoTGGqIiIioT2CoISIioj5BEewKBIrdbkdFRQWio6MhCEKwq0NERETdIIoi6urqkJycDJms87aYfhNqKioqkJKSEuxqEBERUQ+cPn0aQ4YM6bRMvwk10dHRAKRfSkxMTJBrQ0RERN1hNBqRkpLi+h7vTL8JNc4up5iYGIYaIiKiMNOdoSMcKExERER9AkMNERER9QkMNURERNQn9JsxNURERP4iiiKsVitsNluwqxJ25HI5FAqFT5ZbYaghIiLqBYvFgnPnzqGhoSHYVQlbkZGRSEpKgkql6tVxGGqIiIh6yG63o7S0FHK5HMnJyVCpVFzg1QuiKMJiseDChQsoLS1Fenp6lwvsdYahhoiIqIcsFgvsdjtSUlIQGRkZ7OqEpYiICCiVSpw6dQoWiwUajabHx+JAYSIiol7qTesC+e73x3eBiIiI+gSGGiIiIuoTGGqIiIioV1JTU7F+/fpgV4MDhYmIiPqjK6+8EuPHj/dJGPn888+h1Wp7X6leYqjppe8q6/Dy56cxIEqN3185PNjVISIi8glRFGGz2aBQdB0VBg4cGIAadY3dT710rrYJf/+kFK8fORvsqhARUZCJoogGizUoN1EUu13PhQsX4sMPP8RTTz0FQRAgCAK2bt0KQRDw1ltvYdKkSVCr1fjkk09w8uRJzJkzB4mJiYiKisKUKVPw7rvvuh2vbfeTIAj4+9//juuvvx6RkZFIT0/Hv//9b1/9mjvElppeiotUAgBqGpqDXBMiIgq2xmYbRq95Oyg/+5sHcxCp6t7X+lNPPYXvvvsOY8aMwYMPPggA+PrrrwEA9957Lx5//HFccskliIuLw+nTp3HttdfiL3/5C9RqNV544QVcd911OH78OIYOHdrhz3jggQewdu1arFu3Dhs3bsSvfvUrnDp1CvHx8b0/2Q6wpaaX4iKlJZ1/aLAEuSZERETdo9PpoFKpEBkZCb1eD71eD7lcDgB48MEH8dOf/hTDhw9HfHw8xo0bh9tuuw1jxoxBeno6HnroIQwfPrzLlpeFCxdi3rx5uPTSS/HII4+gvr4e+/fv9+t59ailZvPmzVi3bh0MBgPGjRuHjRs3YurUqR2W37lzJ1avXo2ysjKkp6fjsccew7XXXut6XRRF5Ofn429/+xtqamowffp0PP3000hPT3c7zptvvokHH3wQX375JTQaDWbOnIldu3b15BR8JtbRUmO22tFosSFCJQ9qfYiIKHgilHJ882BO0H62L0yePNnteX19Pe6//368+eabOHfuHKxWKxobG1FeXt7pccaOHet6rNVqERMTg/Pnz/ukjh3xuqVmx44dyMvLQ35+Pg4dOoRx48YhJyenw4ru27cP8+bNw+LFi3H48GHk5uYiNzcXR48edZVZu3YtNmzYgMLCQpSUlECr1SInJwdNTU2uMq+88gpuvfVWLFq0CF988QU+/fRT/PKXv+zBKftWlFoBhUy6zgdba4iI+jdBEBCpUgTl5qtrTrWdxbRixQq89tpreOSRR/Dxxx/jyJEjyMjIgMXS+XeeUqls97ux2+0+qWNHvA41Tz75JJYsWYJFixZh9OjRKCwsRGRkJLZs2e
Kx/FNPPYVZs2bhrrvuwqhRo/DQQw9h4sSJ2LRpEwCplWb9+vVYtWoV5syZg7Fjx+KFF15ARUWFqxXGarVi2bJlWLduHX73u99hxIgRGD16NG6++eaen7mPCIKAOC27oIiIKLyoVCrYbLYuy3366adYuHAhrr/+emRkZECv16OsrMz/FewBr0KNxWLBwYMHkZ2d3XIAmQzZ2dkoLi72uE9xcbFbeQDIyclxlS8tLYXBYHAro9PpkJmZ6Spz6NAhnD17FjKZDBMmTEBSUhKuueYat9aetsxmM4xGo9vNX5yDhX8wcbAwERGFh9TUVJSUlKCsrAxVVVUdtqKkp6fj1VdfxZEjR/DFF1/gl7/8pd9bXHrKq1BTVVUFm82GxMREt+2JiYkwGAwe9zEYDJ2Wd953Vua///0vAOD+++/HqlWr8MYbbyAuLg5XXnklqqurPf7cgoIC6HQ61y0lJcWbU/VKLAcLExFRmFmxYgXkcjlGjx6NgQMHdjhG5sknn0RcXBymTZuG6667Djk5OZg4cWKAa9s9YTGl25kI//znP+MXv/gFAOD555/HkCFDsHPnTtx2223t9lm5ciXy8vJcz41Go9+CTcu0boYaIiIKDyNGjGjXy7Jw4cJ25VJTU/Hee++5bbv99tvdnrftjvK0Zk5NTU2P6ukNr1pqEhISIJfLUVlZ6ba9srISer3e4z56vb7T8s77zsokJSUBAEaPHu16Xa1W45JLLukwWarVasTExLjd/KVlWje7n4iIiILFq1CjUqkwadIkFBUVubbZ7XYUFRUhKyvL4z5ZWVlu5QFg7969rvJpaWnQ6/VuZYxGI0pKSlxlnCsbHj9+3FWmubkZZWVlGDZsmDen4BfsfiIiIgo+r7uf8vLysGDBAkyePBlTp07F+vXrYTKZsGjRIgDA/PnzMXjwYBQUFAAAli1bhpkzZ+KJJ57A7NmzsX37dhw4cADPPvssAGn20PLly/Hwww8jPT0daWlpWL16NZKTk5GbmwsAiImJwe9+9zvk5+cjJSUFw4YNw7p16wAAN910ky9+D73CVYWJiIiCz+tQM3fuXFy4cAFr1qyBwWDA+PHjsWfPHtdA3/LycshkLQ1A06ZNw4svvohVq1bhvvvuQ3p6Onbt2oUxY8a4ytx9990wmUxYunQpampqMGPGDOzZswcajcZVZt26dVAoFLj11lvR2NiIzMxMvPfee4iLi+vN+fsEp3QTEREFnyB6cwWsMGY0GqHT6VBbW+vz8TV7v6nEkhcOYNwQHV7/wwyfHpuIiEJXU1MTSktLkZaW5vaHOHmns9+jN9/fvPaTD7jWqWH3ExERUdAw1PgABwoTEREFH0ONDzhbauqarLDaQnOVRSIior6OocYHdBEtF+2qaWQXFBERUTAw1PiAQi5DjEaaSMZVhYmIKBxceeWVWL58uc+Ot3DhQtdSLMHCUOMj8VquKkxERBRMDDU+4hosbGJLDRERhbaFCxfiww8/xFNPPQVBECAIAsrKynD06FFcc801iIqKQmJiIm699VZUVVW59vvXv/6FjIwMREREYMCAAcjOzobJZML999+Pbdu24fXXX3cd74MPPgj4eYXFBS3DQcu0boYaIqJ+SxSB5obg/GxlJCAI3Sr61FNP4bvvvsOYMWPw4IMPSrsrlZg6dSp++9vf4q9//SsaGxtxzz334Oabb8Z7772Hc+fOYd68eVi7di2uv/561NXV4eOPP4YoilixYgW+/fZbGI1GPP/88wCA+Ph4v51qRxhqfIQXtSQiIjQ3AI8kB+dn31cBqLTdKqrT6aBSqRAZGem6ePTDDz+MCRMm4JFHHnGV27JlC1JSUvDdd9+hvr4eVqsVN9xwg+u6ixkZGa6yERERMJvNHV7gOhAYanyEa9UQEVE4++KLL/D+++8jKiqq3WsnT57E1VdfjZ/85CfIyMhATk4Orr76atx4440hcbkiJ4YaH3Fd1NLElhoion5LGSm1mATrZ/dCfX09rrvuOj
z22GPtXktKSoJcLsfevXuxb98+vPPOO9i4cSP+/Oc/o6SkBGlpab362b7CUOMjsbyoJRERCUK3u4CCTaVSwWazuZ5
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Training vs. validation loss per epoch (the validation data passed to fit is the test set)\n",
"fig, ax = plt.subplots()\n",
"ax.plot(history.history['loss'], label='train')\n",
"ax.plot(history.history['val_loss'], label='test')\n",
"ax.set(title='LSTM training history', xlabel='epoch', ylabel='MSE loss (scaled data)')\n",
"ax.legend()\n",
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(20831, 1)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"np.shape(lstm_pred)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(20831,)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"np.shape(test_y)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Targets as a column vector; -1 infers the length instead of hard-coding 20831\n",
"test_y1 = test_y.reshape(-1, 1)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"array([[0.42300856],\n",
|
|||
|
" [0.26651022],\n",
|
|||
|
" [0.28093082],\n",
|
|||
|
" ...,\n",
|
|||
|
" [0. ],\n",
|
|||
|
" [0. ],\n",
|
|||
|
" [0. ]])"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"np.asarray(test_y1)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 20,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# inverse_transform expects as many columns as the scaler was fitted on, so the single\n",
"# prediction column is repeated across all of them. Presumably only the target column\n",
"# (index 5, the one read back in the cells below) is meaningful after inverse scaling.\n",
"results1 = np.broadcast_to(lstm_pred, (lstm_pred.shape[0], scaler.n_features_in_))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 21,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Repeat the target column across every feature the scaler was fitted on so that\n",
"# inverse_transform accepts it; sizes are derived instead of hard-coded (20831, 6).\n",
"test_y2 = np.broadcast_to(test_y1, (test_y1.shape[0], scaler.n_features_in_))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 22,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Undo the min-max scaling so errors and plots are in the original units\n",
"inv_forecast_y = scaler.inverse_transform(X=results1)\n",
"inv_test_y = scaler.inverse_transform(X=test_y2)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"array([[ 1.63622638e+01, 4.53556200e+01, 5.96057328e+02,\n",
|
|||
|
" 2.78607105e+02, 1.00676074e+01, 2.18633342e+00],\n",
|
|||
|
" [ 8.42201514e+00, 2.98816195e+01, 3.75644883e+02,\n",
|
|||
|
" 1.75667855e+02, 6.34294548e+00, 1.37746668e+00],\n",
|
|||
|
" [ 9.15367247e+00, 3.13074773e+01, 3.95954874e+02,\n",
|
|||
|
" 1.85153232e+02, 6.68615591e+00, 1.45200002e+00],\n",
|
|||
|
" ...,\n",
|
|||
|
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
|
|||
|
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00],\n",
|
|||
|
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
|
|||
|
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00],\n",
|
|||
|
" [-5.09990072e+00, 3.53003502e+00, 2.91584611e-01,\n",
|
|||
|
" 3.66558254e-01, 0.00000000e+00, 0.00000000e+00]])"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"np.asarray(inv_test_y)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 24,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Test RMSE: 0.221\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABP4AAAKTCAYAAACJusZ+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9eZwkR3Utjt/IrO6eRZrRviLQgkDsZkeYzXgBbLw9288/Aw/jLzbYzzwQhocXwMZgkG2wAQMGHmA2AWJfhUBIRhKS0K7Rvo5mJI1Go9m37pnuqsz4/VEVETejsrrjRNRUdk3f8/nwUU+TWZGdlRlx49xzz1Vaa00CgUAgEAgEAoFAIBAIBAKB4KBC1vQFCAQCgUAgEAgEAoFAIBAIBILhQ4g/gUAgEAgEAoFAIBAIBAKB4CCEEH8CgUAgEAgEAoFAIBAIBALBQQgh/gQCgUAgEAgEAoFAIBAIBIKDEEL8CQQCgUAgEAgEAoFAIBAIBAchhPgTCAQCgUAgEAgEAoFAIBAIDkII8ScQCAQCgUAgEAgEAoFAIBAchGiNesCyLGnjxo106KGHklJq1MMLBAKBQCAQCAQCgUAgEAgEYw2tNe3Zs4dOOOEEyrLBur6RE38bN26kk046adTDCgQCgUAgEAgEAoFAIBAIBAcVHnjgAXrEIx4x8P8fOfF36KGHElH3wlatWjXq4QUCgUAgEAgEAoFAIBAIBIKxxu7du+mkk06yPNsgjJz4M+W9q1atEuJPIBAIBAKBQCAQCAQCgUAgiMRCNnrS3EMgEAgEAoFAIBAIBAKBQCA4CCHEn0AgEAgEAoFAIBAIBAKBQHAQQog/gUAgEAgEAoFAIBAIBAKB4CDEyD3+BAKBQCAQCAQCgUAgEAgESw9FUVC73W76MsYCExMTlOd58ucI8ScQCAQCgUAgEAgEAoFAIDhg0FrTpk2baOfOnU1fyljhsMMOo+OOO27BBh7zQYg/gUAgEAgEAoFAIBAIBALBAYMh/Y455hhasWJFEpG1FKC1ppmZGdq8eTMRER1//PHRnyXEn0AgEAgEAoFAIBAIBAKB4ICgKApL+h155JFNX87YYPny5UREtHnzZjrmmGOiy36luYdAIBAIBAKBQCAQCAQCgeCAwHj6rVixouErGT+Ye5biiyjEn0AgEAgEAoFAIBAIBAKB4IBCyntxDOOeCfEnEAgEAoFAIBAIBAKBQCAQHIQQ4k8gEAgEAoFAIBAIBAKBQCA4CCHEn0AgEAgEAoFAIBAIBAKBQHAQQog/gUAgEAgEAoFAIBAIBAKBwMOLXvQiOuuss5q+jCQI8ScQCAQCgUAgEAgEAoFAIBCA0FpTp9Np+jLmhRB/AoFAIBAIBAKBQCAQCASCkUFrTTNznZH/T2sdfI2vec1r6JJLLqEPf/jDpJQipRR97nOfI6UUnX/++fT0pz+dpqam6LLLLqPXvOY19Du/8zuV88866yx60YteZP9dliWdffbZdMopp9Dy5cvpKU95Cn3jG98Y0h0djNYBH0EgEAgEAoFAIBAIBAKBQCDoYV+7oMf//Y9HPu5t734JrZgMo8I+/OEP01133UVPfOIT6d3vfjcREd16661ERPQ3f/M39IEPfIBOPfVUOvzww4M+7+yzz6ZzzjmHPvGJT9Dpp59Ol156Kb3qVa+io48+ml74whfG/UEBEOJPIBAIBAKBQCAQCAQCgUAgYFi9ejVNTk7SihUr6LjjjiMiojvuuIOIiN797nfTr/7qrwZ/1uzsLL3vfe+jCy+8kM4880wiIjr11FPpsssuo09+8pNC/AkEAoFAIBAIBAKBQCAQCA4OLJ/I6bZ3v6SRcYeBZzzjGdDx99xzD83MzPSRhXNzc/TUpz51KNc0CEL8CQQCgUAgEAgEAoFAIBAIRgalVHDJ7WLEypUrK//OsqzPP7Ddbtuf9+7dS0RE5513Hp144omV46ampg7QVXYxvndZIBAIBAKBQCAQCAQCgU
AgOECYnJykoigWPO7oo4+mW265pfK7NWvW0MTEBBERPf7xj6epqSm6//77D2hZbx2E+BMIBAKBQCAQCAQCgUAgEAg8nHzyyXTVVVfR+vXr6ZBDDqGyLGuPe/GLX0zvf//76Qtf+AKdeeaZdM4559Att9xiy3gPPfRQeutb30pvfvObqSxLet7znke7du2iyy+/nFatWkV//Md/fMD+huyAfbJAIBAIBAKBQCAQCAQCgUAwpnjrW99KeZ7T4x//eDr66KPp/vvvrz3uJS95Cb3zne+kt73tbfTMZz6T9uzZQ69+9asrx7znPe+hd77znXT22WfT4x73OHrpS19K5513Hp1yyikH9G9Q2i9CPsDYvXs3rV69mnbt2kWrVq0a5dACgUAgEAgEAoFAIBAIBIIRYv/+/bRu3To65ZRTaNmyZU1fzlhhvnsXyq+J4k8gEAgEAoFAIBAIBAKBQCA4CCHEn0AgECyAotT0t9+6ib52zQNNX4pAIBAIBAKBQCAQCATBEOJPIBAIFsAFt26ir1z9AL3tmzc1fSkCgUAgEAgEAoFAIBAEQ4g/gUAgWAA797WbvgSBQCAQCAQCgUAgEAhgCPEnEAgEAoFAIBAIBAKBQCAQHIQQ4k8gEAgA7G8XTV+CQCAQCAQCgUAgEAgEQRDiTyAQCAB8/or1TV+CQCAQCAQCgUAgEAgEQRDiTyAQCACI359AIBAIBAKBQCAQCMYFQvwJBAIBAK2bvgKBQCAQCAQCgUAgEAjCIMSfQCAQANAkzJ9AIBAIBAKBQCAQCMYDEPH3rne9i5RSlf+dccYZB+raBAKBQCAQCAQCgUAgEAgEgrHA3Nxc05fQB1jx94QnPIEeeugh+7/LLrvsQFyXQCAQLE6I4E8gEAgEAoFAIBAIlgRe9KIX0Rve8AZ6wxveQKtXr6ajjjqK3vnOd5LueUCdfPLJ9J73vIde/epX06pVq+h1r3sdERFddtll9PznP5+WL19OJ510Er3xjW+k6enpRv4GmPhrtVp03HHH2f8dddRRB+K6BAKBQCAQCAQHObTW9N01D9K9W/Y2fSkCgUAgEAhGCa2J5qZH/78I0/bPf/7z1Gq16Oqrr6YPf/jD9O///u/06U9/2v7/H/jAB+gpT3kK3XDDDfTOd76T1q5dSy996Uvp937v9+imm26ir371q3TZZZfRG97whmHewWC00BPuvvtuOuGEE2jZsmV05pln0tlnn02PfOQjBx4/OztLs7Oz9t+7d++Ou1KBQCBYBBDBn0AgEAwP5938EL3p3DVERLT+n3+j2YsRCAQCgUAwOrRniN53wujH/buNRJMroVNOOukk+uAHP0hKKXrsYx9LN998M33wgx+kP/uzPyMiohe/+MX0lre8xR7/p3/6p/TKV76SzjrrLCIiOv300+k//uM/6IUvfCF9/OMfp2XLlg3tzwkBpPh79rOfTZ/73OfoRz/6EX384x+ndevW0fOf/3zas2fPwHPOPvtsWr16tf3fSSedlHzRAoFAIBAIBILxx/X37Wz6EgQCgUAgEAjmxXOe8xxSStl/n3nmmXT33XdTURRERPSMZzyjcvyNN95In/vc5+iQQw6x/3vJS15CZVnSunXrRnrtRKDi72Uve5n9+clPfjI9+9nPpkc96lH0ta99jV772tfWnvO3f/u39Fd/9Vf237t37xbyTyAQCAQCgUAgEAgEAoFgqWJiRVd918S4Q8bKlVUF4d69e+n1r389vfGNb+w7dr6K2QMFuNSX47DDDqPHPOYxdM899ww8ZmpqiqamplKGEQgEAoFAIBAchGDJc4FAIBAIBEsJSsElt03hqquuqvz7yiuvpNNPP53yPK89/mlPexrddttt9OhHP3oUl7cg4OYeHHv37qW1a9fS8ccfP6zrEQgEgkUNHWEGKxAIBIJ6CO8nEAgEAoFgseP++++nv/qrv6I777yTvvKVr9BHPvIRetOb3jTw+L/+67+mK664gt7whjfQmjVr6O6776
bvfve749Hc461vfSv95m/+Jj3qUY+ijRs30j/8wz9Qnuf0R3/0Rwfq+gQCgYCIiIpS0+u/eB097vhD6S2/9timLwd
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1600x800 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Column of the 6-column inverse-scaled arrays that this notebook treats as the target\n",
"TARGET_COL = 5\n",
"target_true = inv_test_y[:, TARGET_COL]\n",
"target_pred = inv_forecast_y[:, TARGET_COL]\n",
"\n",
"# Root-mean-square error in the original (inverse-scaled) units\n",
"rmse = sqrt(mean_squared_error(target_true, target_pred))\n",
"print('Test RMSE: %.3f' % rmse)\n",
"\n",
"# Ground truth vs. prediction over the test set\n",
"fig, ax = plt.subplots(figsize=(16, 8))\n",
"ax.plot(target_true, label='true')\n",
"ax.plot(target_pred, label='predicted')\n",
"ax.set(title=f'LSTM test-set prediction (RMSE = {rmse:.3f})', xlabel='test sample index', ylabel='target (original units)')\n",
"ax.legend()\n",
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 35,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"mean_squared_error: 0.001827188351681114\n",
|
|||
|
"mean_absolute_error: 0.013142039763722593\n",
|
|||
|
"rmse: 0.04274562377227772\n",
|
|||
|
"r2 score: 0.9903982756173205\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Regression metrics (MSE, MAE, RMSE, R2) for the LSTM forecast\n",
"from math import sqrt\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"\n",
"# Metrics on the min-max scaled target over the whole test set.\n",
"# sklearn metrics take (y_true, y_pred) in that order (matters for r2_score).\n",
"# NOTE: the 'Test RMSE' printed by the plotting cell is in original units, so it differs from 'rmse' here.\n",
"y_true_scaled = np.ravel(test_y)\n",
"y_pred_scaled = np.ravel(lstm_pred)\n",
"mse = mean_squared_error(y_true_scaled, y_pred_scaled)\n",
"print('mean_squared_error:', mse)\n",
"print('mean_absolute_error:', mean_absolute_error(y_true_scaled, y_pred_scaled))\n",
"print('rmse:', sqrt(mse))\n",
"print('r2 score (full test set):', r2_score(y_true_scaled, y_pred_scaled))\n",
"\n",
"# R2 restricted to test samples 5000-9999 in original units (the author describes this window as 50 days).\n",
"# Only target column 5 is used; the other columns are rescaled copies of the same broadcast series.\n",
"R2_WINDOW = slice(5000, 10000)\n",
"print('r2 score (samples 5000:10000):', r2_score(inv_test_y[R2_WINDOW, 5], inv_forecast_y[R2_WINDOW, 5]))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 43,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Keep only the target column (5): the full array has 6 columns, which cannot match a\n",
"# single column name (pandas raises ValueError on a shape mismatch).\n",
"df1 = pd.DataFrame(inv_test_y[:, 5], columns=['column_name'])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 44,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Write the true target values to test.csv without the index column\n",
"df1.to_csv(path_or_buf='test.csv', index=False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 45,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Keep only the target column (5): the full array has 6 columns, which cannot match a\n",
"# single column name (pandas raises ValueError on a shape mismatch).\n",
"df2 = pd.DataFrame(inv_forecast_y[:, 5], columns=['column_name'])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 46,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Write the predicted target values to forecast.csv without the index column\n",
"df2.to_csv(path_or_buf='forecast.csv', index=False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "base",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.9.13"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|