ICEEMDAN-Solar_power-forecast/iceemdan-筛选-high-ConvBiGruA...

1103 lines
267 KiB
Plaintext
Raw Normal View History

2024-08-01 10:47:40 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\asus\\AppData\\Roaming\\Python\\Python39\\site-packages\\pandas\\core\\computation\\expressions.py:21: UserWarning: Pandas requires version '2.8.4' or newer of 'numexpr' (version '2.8.3' currently installed).\n",
" from pandas.core.computation.check import NUMEXPR_INSTALLED\n",
"C:\\Users\\asus\\AppData\\Roaming\\Python\\Python39\\site-packages\\pandas\\core\\arrays\\masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
" from pandas.core import (\n"
]
}
],
"source": [
"from math import sqrt\n",
"from numpy import concatenate\n",
"from matplotlib import pyplot\n",
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import mean_squared_error\n",
"from tensorflow.keras import Sequential\n",
"\n",
"from tensorflow.keras.layers import Dense\n",
"from tensorflow.keras.layers import LSTM\n",
"from tensorflow.keras.layers import Dropout\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这段代码定义了函数 time_series_to_supervised,用于将时间序列数据转换为监督学习问题的数据集。下面是该函数各个参数和处理步骤的含义:\n",
"\n",
"- data:输入的时间序列数据,可以是列表或 2D NumPy 数组。\n",
"- n_in:作为输入的滞后观测数,即用多少个时间步的观测值作为输入。默认值为 96,表示使用前 96 个时间步的观测值作为输入。\n",
"- n_out:作为输出的观测数量,即预测未来多少个时间步的观测值。默认值为 10,表示预测未来 10 个时间步的观测值。\n",
"- dropnan:布尔值,表示是否删除包含 NaN 值的行。默认为 True,即删除包含 NaN 值的行。\n",
"\n",
"函数首先检查输入数据的维度并初始化一些变量。然后,它创建一个新的 DataFrame 对象 df 来存储输入数据,并保存原始的列名。接着,它创建了两个空列表 cols 和 names,用于存储新的特征列和列名。\n",
"\n",
"接下来,函数开始构建特征列和对应的列名。首先,它将当前时刻 t 的原始观测序列添加到 cols 列表中,列名保持原始列名不变(例如 `5`)。然后,它按滞后步数从大到小(t-n_in 到 t-1)依次将滞后的观测序列添加到 cols 列表中,并构建相应的列名,格式为“原始列名(t-滞后步数)”,例如 `0(t-96)`。这样就构建了输入特征部分。\n",
"\n",
"接着,函数开始构建输出特征部分。它依次将未来的观测序列(t+1 到 t+n_out)添加到 cols 列表中,并构建相应的列名,格式为“原始列名(t+未来步数)”,例如 `5(t+1)`。\n",
"\n",
"随后,函数将所有的特征列按列拼接在一起,构成一个新的 DataFrame 对象 agg。如果 dropnan 参数为 True,则删除包含 NaN 值的行(移位会使序列开头的 n_in 行和末尾的 n_out 行出现 NaN)。最后,函数返回处理后的数据集 agg。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def time_series_to_supervised(data, n_in=96, n_out=10, dropnan=True):\n",
"    \"\"\"Frame a (multivariate) time series as a supervised-learning table.\n",
"\n",
"    Output columns, in order:\n",
"      * the current (time t) value of every variable, named by its original\n",
"        column label as a string, e.g. ``5``;\n",
"      * lagged copies ``NAME(t-i)`` for i = n_in, ..., 1 (oldest first);\n",
"      * lead copies ``NAME(t+i)`` for i = 1, ..., n_out.\n",
"\n",
"    :param data: observations as a list (flat or nested), a 1-D/2-D NumPy array\n",
"        or a DataFrame; each column is one variable. Required.\n",
"    :param n_in: number of lag steps used as inputs; negative values are treated\n",
"        as 0. Default 96.\n",
"    :param n_out: number of lead steps used as outputs; negative values are\n",
"        treated as 0. Default 10.\n",
"    :param dropnan: if True (default), drop rows containing NaN, including the\n",
"        first n_in and last n_out rows that shifting leaves incomplete.\n",
"    :return: DataFrame with n_vars * (1 + n_in + n_out) columns.\n",
"    \"\"\"\n",
"    df = pd.DataFrame(data)\n",
"    # Count variables from the frame itself: data.shape[1] fails for 1-D arrays,\n",
"    # and treating every list as one variable breaks nested (multi-column) lists.\n",
"    n_vars = df.shape[1]\n",
"    orig_names = df.columns\n",
"    n_in = max(0, n_in)\n",
"    n_out = max(0, n_out)\n",
"\n",
"    # Current values keep their original names.\n",
"    cols = [df.shift(0)]\n",
"    names = ['%s' % orig_names[j] for j in range(n_vars)]\n",
"    # Lagged inputs, oldest first: t-n_in, ..., t-1.\n",
"    for i in range(n_in, 0, -1):\n",
"        cols.append(df.shift(i))\n",
"        names += ['%s(t-%d)' % (orig_names[j], i) for j in range(n_vars)]\n",
"    # Future outputs: t+1, ..., t+n_out.\n",
"    for i in range(1, n_out + 1):\n",
"        cols.append(df.shift(-i))\n",
"        names += ['%s(t+%d)' % (orig_names[j], i) for j in range(n_vars)]\n",
"\n",
"    agg = pd.concat(cols, axis=1)\n",
"    agg.columns = names\n",
"    if dropnan:\n",
"        agg = agg.dropna()\n",
"    return agg"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Temp Humidity GHI DHI Rainfall Power\n",
"0 19.779453 40.025826 3.232706 1.690531 0.0 0.0\n",
"1 19.714937 39.605961 3.194991 1.576346 0.0 0.0\n",
"2 19.549330 39.608631 3.070866 1.576157 0.0 0.0\n",
"3 19.405870 39.680702 3.038623 1.482489 0.0 0.0\n",
"4 19.387363 39.319881 2.656474 1.134153 0.0 0.0\n",
"(104256, 6)\n"
]
}
],
"source": [
"# Load the raw data (weather features and Power) from CSV.\n",
"# NOTE(review): absolute, machine-specific path; consider a configurable data directory.\n",
"path1 = r\"D:\\project\\小论文1-基于ICEEMDAN分解的时序高维变化的短期光伏功率预测模型\\CEEMAN-PosConv1dbiLSTM-LSTM\\模型代码流程\\data6.csv\"\n",
"datas1 = pd.read_csv(path1)\n",
"# Disabled column subset. If re-enabled as written, its result would be\n",
"# overwritten by the interpolate() line below.\n",
"# data1 = datas1.iloc[:,np.r_[3,23,16:22,27]]\n",
"# Fill missing values by interpolation (pandas default: linear, along the rows).\n",
"data1 = datas1.interpolate()\n",
"values1 = data1.values\n",
"print(data1.head())\n",
"print(data1.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# data2= data1.drop(['date','Air_P','RH'], axis = 1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Load the reconstructed 'high' component (high_re.csv, one column named 'column_name').\n",
"# Keep the path and the loaded frame under separate names.\n",
"# NOTE(review): absolute, machine-specific path.\n",
"high_re_path = r\"D:\\project\\小论文1-基于ICEEMDAN分解的时序高维变化的短期光伏功率预测模型\\CEEMAN-PosConv1dbiLSTM-LSTM\\模型代码流程\\完整的模型代码流程\\high_re.csv\"\n",
"high_re = pd.read_csv(high_re_path)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" column_name\n",
"0 -1.426824\n",
"1 -1.426819\n",
"2 -1.426815\n",
"3 -1.426812\n",
"4 -1.426810\n",
"... ...\n",
"104251 -1.629381\n",
"104252 -1.629328\n",
"104253 -1.629271\n",
"104254 -1.629213\n",
"104255 -1.629152\n",
"\n",
"[104256 rows x 1 columns]\n"
]
}
],
"source": [
"# Expose the loaded component under the name the later cells use, then show it.\n",
"reconstructed_data_high = high_re\n",
"print(reconstructed_data_high)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0wAAAIjCAYAAAAwSJuMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAACNU0lEQVR4nO3dd5hTVf4G8DfTe2M6DEPvSBUEpClKE8QCFlTAisIidlEXdBXBgru4rrp27K4KqKAUpYqASAcB6R2GNr1Pzu+P87vJhGnp597k/TzPPDdkMpnvXHKT+552TUIIASIiIiIiIqoiQHUBREREREREesXAREREREREVAMGJiIiIiIiohowMBEREREREdWAgYmIiIiIiKgGDExEREREREQ1YGAiIiIiIiKqAQMTERERERFRDRiYiIiIiIiIasDARER2MZlMePbZZ1WXUa1GjRrhmmuuUV2GX+vXrx/69evn1M+OHTsWjRo1cms9F/voo49gMplw6NAhj/4eR+j5mCLXrVixAiaTCStWrHD5ufT4+q2sX79+aNeuXZ2PO3ToEEwmEz766COHf4f2s6+++qoTFRK5hoGJyE779+/HfffdhyZNmiAsLAwxMTHo1asXZs+ejaKiItXlkRsVFhbi2WefdcuJjt4tWLAAgwYNQr169RAWFoYWLVrg0Ucfxblz51SXRn7gxRdfxPz58/2+Bk/6/PPP8a9//Ut1GUSGFqS6ACIjWLhwIUaOHInQ0FDccccdaNeuHUpLS/Hrr7/isccew86dO/HOO++oLtOjioqKEBTkH28ZhYWFeO655wDA6V4TI3j00Ucxa9YsdOjQAU888QQSEhKwadMmvPHGG/jyyy/xyy+/oGXLlnY915IlS5yu491334XZbHb658m4XnzxRdx4440YMWKEX9fgSZ9//jl27NiByZMnqy4FmZmZKCoqQnBwsOpSiBziH2c/RC44ePAgbr75ZmRmZmLZsmVIS0uzfG/ChAnYt28fFi5cqLBCzzGbzSgtLUVYWBjCwsJUl0Nu9MUXX2DWrFm46aab8NlnnyEwMNDyvbFjx6J///4YOXIkNm3aVGtQLiwsREREBEJCQpyuhSdPjikuLkZISAgCAvxrkEhBQQEiIyNVl0EuMJlM/CwhQ/Kvd1siJ7z88svIz8/H+++/bxOWNM2aNcODDz5o+Xd5eTmef/55NG3aFKGhoWjUqBGeeuoplJSU2PycNu9mxYoV6Nq1K8LDw9G+fXvLMLC5c+eiffv2CAsLQ5cuXbB582abnx87diyioqJw4MABDBw4EJGRkUhPT8c//vEPCCFsHvvqq6+iZ8+eqFevHsLDw9GlSxd88803Vf4Wk8mEiRMn4rPPPkPbtm0RGhqKRYsWWb5Xeb5FXl4eJk+ejEaNGiE0NBTJycm46qqrsGnTJpvn/Prrr9GlSxeEh4cjMTERt912G44fP17t33L8+HGMGDECUVFRSEpKwqOPPoqKiooa/meqWrJkCTp27IiwsDC0adMGc+fOrfKY7OxsTJ48GRkZGQgNDUWzZs3w0ksvWXo4Dh06hKSkJADAc889B5PJZPnbv//+e5hMJmzbts3yfN9++y1MJhOuv/56m9/TunVr3HTTTTb3ffrpp5Z9kZCQgJtvvhlHjx6tUuP69esxaNAgxMbGIiIiAn379sWaNWtsHvPss8/CZDJh3759GDt2LOLi4hAbG4tx48ahsLCwzn313HPPIT4+Hu+8845NWAKAbt264YknnsD27dttXifaPIWNGzeiT58+iIiIwFNPPWX53sW9cYcPH8bw4cMRGRmJ5ORkPPTQQ1i8eHGVeR0Xz2GqPFfhnXfesRxLl156KTZs2GDzO7Zt24axY8dahsqmpqbizjvvdHpIob3P58j+LykpwUMPPYSkpCRER0dj+PDhOHbsmF31aPNgvvzySzzzzDOoX78+IiIikJubC8C+1woAHD9+HHfddRfS09MRGhqKxo0b4/
7770dpaanlMQcOHMDIkSORkJCAiIgIXHbZZVUag7R6/ve//2H69Olo0KABwsLCcOWVV2Lfvn02j927dy9uuOEGpKamIiwsDA0aNMDNN9+MnJwcAPI9paCgAHPmzLEcZ2PHjrXZv3/++SduvfVWxMfH4/LLLwdQ83y56ubCmc1mzJ492/JempSUhEGDBuGPP/6oswZtv915551ISUlBaGgo2rZtiw8++KDK7z527BhGjBhh81q/+D3fXjt37sQVV1yB8PBwNGjQAC+88EK1PbDfffcdhg4davk/bdq0KZ5//nmb98x+/fph4cKFOHz4sOXv0/ZRaWkppk6dii5duiA2NhaRkZHo3bs3li9f7lTdAPDnn3+if//+iIiIQP369fHyyy/bfL+mOUxff/012rRpg7CwMLRr1w7z5s2rdW5jXe8LRO7GHiaiOvzwww9o0qQJevbsadfj7777bsyZMwc33ngjHnnkEaxfvx4zZszArl27MG/ePJvH7tu3D7feeivuu+8+3HbbbXj11VcxbNgwvP3223jqqafwwAMPAABmzJiBUaNGYc+ePTatyhUVFRg0aBAuu+wyvPzyy1i0aBGmTZuG8vJy/OMf/7A8bvbs2Rg+fDhGjx6N0tJSfPnllxg5ciQWLFiAoUOH2tS0bNky/O9//8PEiRORmJhY4wfW+PHj8c0332DixIlo06YNzp07h19//RW7du1C586dAciJyuPGjcOll16KGTNm4PTp05g9ezbWrFmDzZs3Iy4uzuZvGThwILp3745XX30VP//8M2bNmoWmTZvi/vvvr3O/7927FzfddBPGjx+PMWPG4MMPP8TIkSOxaNEiXHXVVQBkb0jfvn1x/Phx3HfffWjYsCF+++03TJkyBSdPnsS//vUvJCUl4a233sL999+P6667zhKELrnkEjRo0AAmkwmrVq3CJZdcAgBYvXo1AgIC8Ouvv1pqOXPmDHbv3o2JEyda7ps+fTr+/ve/Y9SoUbj77rtx5swZ/Pvf/0afPn1s9sWyZcswePBgdOnSBdOmTUNAQAA+/PBDXHHFFVi9ejW6detm83ePGjUKjRs3xowZM7Bp0ya89957SE5OxksvvVTrvtqzZw/Gjh2LmJiYah9zxx13YNq0aViwYAFuvvlmy/3nzp3D4MGDcfPNN+O2225DSkpKtT9fUFCAK664AidPnsSDDz6I1NRUfP755w6djH3++efIy8vDfffdB5PJhJdffhnXX389Dhw4YOmVWrp0KQ4cOIBx48YhNTXVMjx2586dWLduHUwmk92/z5nns2f/33333fj0009x6623omfPnli2bFmV464uzz//PEJCQvDoo4+ipKQEISEhdr9WTpw4gW7duiE7Oxv33nsvWrVqhePHj+Obb75BYWEhQkJCcPr0afTs2ROFhYWYNGkS6tWrhzlz5mD48OH45ptvcN1119nUM3PmTAQEBODRRx9FTk4OXn75ZYwePRrr168HIE/GBw4ciJKSEvztb39Damoqjh8/jgULFiA7OxuxsbH45JNPcPfdd6Nbt2649957AQBNmza1+T0jR45E8+bN8eKLL1ZpCLLHXXfdhY8++giDBw/G3XffjfLycqxevRrr1q1D165da63h9OnTuOyyyywNSUlJSfjpp59w1113ITc31zLEraioCFdeeSWOHDmCSZMmIT09HZ988gmWLVvmcL2nTp1C//79UV5ejieffBKRkZF45513EB4eXuWxH330EaKiovDwww8jKioKy5Ytw9SpU5Gbm4tXXnkFAPD0008jJycHx44dwz//+U8AQFRUFAAgNzcX7733Hm655Rbcc889yMvLw/vvv4+BAwfi999/R8eOHR2q/cKFCxg0aBCuv/56jBo1Ct988w2eeOIJtG/fHoMHD67x5xYuXIibbroJ7du3x4wZM3DhwgXcddddqF+/frWPt+d9gcjtBBHVKCcnRwAQ1157rV2P37JliwAg7r77bpv7H330UQFALF
u2zHJfZmamACB+++03y32LFy8WAER4eLg4fPiw5f7//ve/AoBYvny55b4xY8YIAOJvf/ub5T6z2SyGDh0qQkJCxJk
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot rows 200-999 of the reconstructed high component.\n",
"# original_data / time are kept for the optional Power overlay (disabled below).\n",
"original_data = data1['Power'].values\n",
"time = range(len(original_data))\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"# ax.plot(time, original_data, label='Original Data', color='blue')\n",
"ax.plot(reconstructed_data_high[200:1000], label='Reconstructed Data', color='red')\n",
"# The title describes what is actually drawn: only the reconstructed series.\n",
"ax.set_title('Reconstructed high component (rows 200-999)')\n",
"ax.set_xlabel('Time')\n",
"ax.set_ylabel('Power')\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Weather features only: drop the raw Power column by name rather than\n",
"# relying on it being the last of the six columns.\n",
"data3 = data1.drop(columns=['Power'])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Temp Humidity GHI DHI Rainfall column_name\n",
"0 19.779453 40.025826 3.232706 1.690531 0.0 -1.426824\n",
"1 19.714937 39.605961 3.194991 1.576346 0.0 -1.426819\n",
"2 19.549330 39.608631 3.070866 1.576157 0.0 -1.426815\n",
"3 19.405870 39.680702 3.038623 1.482489 0.0 -1.426812\n",
"4 19.387363 39.319881 2.656474 1.134153 0.0 -1.426810\n",
"... ... ... ... ... ... ...\n",
"104251 13.303740 34.212711 1.210789 0.787026 0.0 -1.629381\n",
"104252 13.120920 34.394939 2.142980 1.582670 0.0 -1.629328\n",
"104253 12.879215 35.167400 1.926214 1.545889 0.0 -1.629271\n",
"104254 12.915867 35.359989 1.317695 0.851529 0.0 -1.629213\n",
"104255 13.134816 34.500034 1.043269 0.597816 0.0 -1.629152\n",
"\n",
"[104256 rows x 6 columns]\n"
]
}
],
"source": [
"# Combine the weather features with the reconstructed component. pd.concat(axis=1)\n",
"# aligns rows on the index (both frames come from read_csv with a 0..n-1 index).\n",
"data3_df = pd.DataFrame(data3)\n",
"imf1_df = pd.DataFrame(reconstructed_data_high)\n",
"\n",
"# Keep only the rows both inputs cover, instead of a hard-coded row count, so a\n",
"# length mismatch cannot leave NaN-padded rows at the end.\n",
"n_rows = min(len(data3_df), len(imf1_df))\n",
"merged_df = pd.concat([data3_df, imf1_df], axis=1).iloc[:n_rows]\n",
"\n",
"print(merged_df)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104256, 6)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shape (rows, columns) of the merged feature table.\n",
"merged_df.shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(104256, 6)\n"
]
}
],
"source": [
"# Scale every column to [0, 1] with MinMaxScaler.\n",
"# Fit the scaler on the leading 80% of rows only, so min/max statistics from the\n",
"# chronologically later test period (the last 20% held out in the split cell)\n",
"# do not leak into training. Test-period values may therefore fall slightly\n",
"# outside [0, 1].\n",
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
"n_scaler_fit_rows = len(merged_df) - int(len(merged_df) * 0.2)\n",
"scaler.fit(merged_df.iloc[:n_scaler_fit_rows])\n",
"scaledData1 = scaler.transform(merged_df)\n",
"print(scaledData1.shape)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0 1 2 3 4 5 0(t-96) \\\n",
"96 0.555631 0.349673 0.190042 0.040558 0.0 0.245160 0.490360 \n",
"97 0.564819 0.315350 0.211335 0.044613 0.0 0.264683 0.489088 \n",
"98 0.576854 0.288321 0.229657 0.047549 0.0 0.283988 0.485824 \n",
"99 0.581973 0.268243 0.247775 0.053347 0.0 0.303131 0.482997 \n",
"100 0.586026 0.264586 0.266058 0.057351 0.0 0.322308 0.482632 \n",
"\n",
" 1(t-96) 2(t-96) 3(t-96) ... 2(t-1) 3(t-1) 4(t-1) 5(t-1) \\\n",
"96 0.369105 0.002088 0.002013 ... 0.166009 0.036794 0.0 0.225396 \n",
"97 0.364859 0.002061 0.001839 ... 0.190042 0.040558 0.0 0.245160 \n",
"98 0.364886 0.001973 0.001839 ... 0.211335 0.044613 0.0 0.264683 \n",
"99 0.365615 0.001950 0.001697 ... 0.229657 0.047549 0.0 0.283988 \n",
"100 0.361965 0.001679 0.001167 ... 0.247775 0.053347 0.0 0.303131 \n",
"\n",
" 0(t+1) 1(t+1) 2(t+1) 3(t+1) 4(t+1) 5(t+1) \n",
"96 0.564819 0.315350 0.211335 0.044613 0.0 0.264683 \n",
"97 0.576854 0.288321 0.229657 0.047549 0.0 0.283988 \n",
"98 0.581973 0.268243 0.247775 0.053347 0.0 0.303131 \n",
"99 0.586026 0.264586 0.266058 0.057351 0.0 0.322308 \n",
"100 0.590772 0.258790 0.282900 0.060958 0.0 0.340588 \n",
"\n",
"[5 rows x 588 columns]\n"
]
}
],
"source": [
"# Reframe the scaled series as a supervised table: n_steps_in lagged steps as\n",
"# inputs and n_steps_out future steps as outputs.\n",
"n_steps_in = 96  # history window length (time steps)\n",
"n_steps_out = 1  # forecast horizon (time steps)\n",
"processedData1 = time_series_to_supervised(scaledData1, n_in=n_steps_in, n_out=n_steps_out)\n",
"print(processedData1.head())"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# processedData1.to_csv('processedData1.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Inputs: every lagged column, from '0(t-n_steps_in)' through '{last}(t-1)'.\n",
"# Target: the un-shifted last column (the reconstructed component) at time t.\n",
"# Labels are derived from n_steps_in and the feature count instead of hard-coded.\n",
"target_col = str(scaledData1.shape[1] - 1)\n",
"data_x = processedData1.loc[:, '0(t-%d)' % n_steps_in:'%s(t-1)' % target_col]\n",
"data_y = processedData1.loc[:, target_col]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104159, 576)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Flattened input windows: (samples, n_steps_in * n_features).\n",
"data_x.shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"96 0.245160\n",
"97 0.264683\n",
"98 0.283988\n",
"99 0.303131\n",
"100 0.322308\n",
" ... \n",
"104250 0.000090\n",
"104251 0.000099\n",
"104252 0.000109\n",
"104253 0.000118\n",
"104254 0.000128\n",
"Name: 5, Length: 104159, dtype: float64"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Target series: the scaled last column at time t.\n",
"data_y"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104159,)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of target samples; matches the number of rows in data_x.\n",
"data_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(83328, 96, 6) (83328,) (20831, 96, 6) (20831,)\n"
]
}
],
"source": [
"# 7. Chronological train/test split (no shuffling): the last 20% of the\n",
"#    supervised samples form the test set.\n",
"n_features = scaledData1.shape[1]\n",
"test_size = int(len(data_x) * 0.2)\n",
"n_train = len(data_x) - test_size\n",
"train_indices = range(n_train)\n",
"test_indices = range(n_train, len(data_x))\n",
"\n",
"# Each row of data_x holds n_steps_in consecutive time steps (oldest first),\n",
"# each with n_features values, so it reshapes to [samples, timesteps, features].\n",
"train_X1 = data_x.iloc[:n_train].to_numpy().reshape((-1, n_steps_in, n_features))\n",
"test_X1 = data_x.iloc[n_train:].to_numpy().reshape((-1, n_steps_in, n_features))\n",
"train_y = data_y.iloc[:n_train].to_numpy()\n",
"test_y = data_y.iloc[n_train:].to_numpy()\n",
"\n",
"# train_X1/test_X1 are already 3-D; expose them under the names used by the model cells.\n",
"train_X = train_X1\n",
"test_X = test_X1\n",
"print(train_X.shape, train_y.shape, test_X.shape, test_y.shape)\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(83328, 96, 6)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 3-D training input: (samples, timesteps, features).\n",
"train_X1.shape"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From d:\\Anaconda3\\lib\\site-packages\\keras\\src\\backend\\tensorflow\\core.py:192: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
"\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">6</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ conv1d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv1D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">832</span> │ input_layer[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ max_pooling1d │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv1d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling1D</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ bidirectional │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">49,920</span> │ max_pooling1d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ attention_with_imp… │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ <span style=\"color: #00af00; text-decoration-color: #00af00\">66,304</span> │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">AttentionWithImpr…</span> │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, │ │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)] │ │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ global_average_poo… │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ attention_with_i… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalAveragePool…</span> │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dense_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">129</span> │ global_average_p… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ conv1d (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m832\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ max_pooling1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mMaxPooling1D\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ bidirectional │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m49,920\u001b[0m │ max_pooling1d[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ attention_with_imp… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m66,304\u001b[0m │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mAttentionWithImpr…\u001b[0m │ \u001b[38;5;34m128\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m, │ │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)] │ │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ global_average_poo… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ attention_with_i… │\n",
"│ (\u001b[38;5;33mGlobalAveragePool…\u001b[0m │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m129\u001b[0m │ global_average_p… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Input, Conv1D, Bidirectional, GlobalAveragePooling1D, Dense, GRU, MaxPooling1D\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.initializers import RandomUniform\n",
"\n",
"\n",
"class AttentionWithImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
"    \"\"\"Multi-head self-attention that adds a learnable sinusoidal position encoding to queries and keys.\n",
"\n",
"    :param d_model: model width; must be divisible by num_heads (and even,\n",
"        because the position encoding interleaves sin/cos channels).\n",
"    :param num_heads: number of attention heads.\n",
"    :param max_len: stored on the layer but not used by the computation.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, d_model, num_heads, max_len=5000, **kwargs):\n",
"        # Forward Keras keyword arguments (name, dtype, ...) to the base Layer.\n",
"        super(AttentionWithImproveRelativePositionEncoding, self).__init__(**kwargs)\n",
"        if d_model % num_heads != 0:\n",
"            raise ValueError('d_model (%d) must be divisible by num_heads (%d)' % (d_model, num_heads))\n",
"        self.num_heads = num_heads\n",
"        self.d_model = d_model\n",
"        self.max_len = max_len\n",
"        self.wq = tf.keras.layers.Dense(d_model)\n",
"        self.wk = tf.keras.layers.Dense(d_model)\n",
"        self.wv = tf.keras.layers.Dense(d_model)\n",
"        self.dense = tf.keras.layers.Dense(d_model)\n",
"        # A single encoding layer (one set of u/v weights) is shared by queries and keys.\n",
"        self.position_encoding = ImproveRelativePositionEncoding(d_model)\n",
"\n",
"    def call(self, v, k, q, mask=None):\n",
"        \"\"\"Attend from q over k/v and return (output, attention_weights).\n",
"\n",
"        output is (batch, seq_len_q, d_model); attention_weights is\n",
"        (batch, num_heads, seq_len_q, seq_len_k). An optional mask is multiplied\n",
"        by -1e9 and added to the scaled logits.\n",
"        \"\"\"\n",
"        batch_size = tf.shape(q)[0]\n",
"        q = self.wq(q)\n",
"        k = self.wk(k)\n",
"        v = self.wv(v)\n",
"\n",
"        # Add the position encoding. The encoding layer already returns\n",
"        # inputs + pe, so assign its result; using += would double the\n",
"        # projected keys/queries (2 * x + pe).\n",
"        k = self.position_encoding(k)\n",
"        q = self.position_encoding(q)\n",
"\n",
"        q = self.split_heads(q, batch_size)\n",
"        k = self.split_heads(k, batch_size)\n",
"        v = self.split_heads(v, batch_size)\n",
"\n",
"        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)\n",
"        # (batch, heads, seq, depth) -> (batch, seq, heads, depth) -> (batch, seq, d_model)\n",
"        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])\n",
"        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))\n",
"        output = self.dense(concat_attention)\n",
"        return output, attention_weights\n",
"\n",
"    def split_heads(self, x, batch_size):\n",
"        \"\"\"Reshape (batch, seq, d_model) to (batch, num_heads, seq, d_model // num_heads).\"\"\"\n",
"        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_model // self.num_heads))\n",
"        return tf.transpose(x, perm=[0, 2, 1, 3])\n",
"\n",
"    def scaled_dot_product_attention(self, q, k, v, mask):\n",
"        \"\"\"Return softmax(q k^T / sqrt(depth) + mask * -1e9) v together with the softmax weights.\"\"\"\n",
"        matmul_qk = tf.matmul(q, k, transpose_b=True)\n",
"        dk = tf.cast(tf.shape(k)[-1], tf.float32)\n",
"        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n",
"\n",
"        if mask is not None:\n",
"            scaled_attention_logits += (mask * -1e9)\n",
"\n",
"        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)\n",
"        output = tf.matmul(attention_weights, v)\n",
"        return output, attention_weights\n",
"\n",
"    def get_config(self):\n",
"        config = super(AttentionWithImproveRelativePositionEncoding, self).get_config()\n",
"        config.update({'d_model': self.d_model, 'num_heads': self.num_heads, 'max_len': self.max_len})\n",
"        return config\n",
"\n",
"\n",
"class ImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
"    \"\"\"Add a sinusoidal position table, scaled per channel by learnable weights, to the input.\n",
"\n",
"    The sinusoid is indexed by the absolute position 0..seq_len-1. The result is\n",
"    inputs + pe * u + pe * v, with u and v of shape (d_model,).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, d_model, max_len=5000, **kwargs):\n",
"        super(ImproveRelativePositionEncoding, self).__init__(**kwargs)\n",
"        self.max_len = max_len\n",
"        self.d_model = d_model\n",
"        # Learnable per-channel scales u and v applied to the position table.\n",
"        self.u = self.add_weight(shape=(self.d_model,),\n",
"                                 initializer=RandomUniform(),\n",
"                                 trainable=True)\n",
"        self.v = self.add_weight(shape=(self.d_model,),\n",
"                                 initializer=RandomUniform(),\n",
"                                 trainable=True)\n",
"\n",
"    def call(self, inputs):\n",
"        # Use the static sequence length when known, otherwise the dynamic one.\n",
"        seq_length = inputs.shape[1]\n",
"        if seq_length is None:\n",
"            seq_length = tf.shape(inputs)[1]\n",
"        pos_encoding = self.relative_positional_encoding(seq_length, self.d_model)\n",
"        # NOTE(review): pe * u + pe * v == pe * (u + v), so u and v are not separately identifiable.\n",
"        pe_with_params = pos_encoding * self.u + pos_encoding * self.v\n",
"        return inputs + pe_with_params\n",
"\n",
"    def relative_positional_encoding(self, position, d_model):\n",
"        \"\"\"Return a (1, position, d_model) table: sin on even channels, cos on odd channels.\"\"\"\n",
"        pos = tf.cast(tf.range(position), tf.float32)\n",
"        i = tf.range(d_model, dtype=tf.float32)\n",
"        angles = 1 / tf.pow(10000.0, (2 * (i // 2)) / tf.cast(d_model, tf.float32))\n",
"        angle_rads = tf.einsum('i,j->ij', pos, angles)\n",
"        angle_rads_sin = tf.sin(angle_rads[:, 0::2])\n",
"        angle_rads_cos = tf.cos(angle_rads[:, 1::2])\n",
"        # Interleave so channel 2j holds sin and channel 2j+1 holds cos.\n",
"        pos_encoding = tf.stack([angle_rads_sin, angle_rads_cos], axis=2)\n",
"        return tf.reshape(pos_encoding, [1, position, d_model])\n",
"\n",
"    def get_config(self):\n",
"        config = super(ImproveRelativePositionEncoding, self).get_config()\n",
"        config.update({'d_model': self.d_model, 'max_len': self.max_len})\n",
"        return config\n",
"\n",
"\n",
"def PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads):\n",
"    \"\"\"Build a Conv1D -> BiGRU -> self-attention -> global average pool -> Dense(1) model.\n",
"\n",
"    :param input_shape: (timesteps, features) of one input window.\n",
"    :param gru_units: units per GRU direction; the attention width is 2 * gru_units.\n",
"    :param num_heads: number of attention heads; must divide 2 * gru_units.\n",
"    :return: an uncompiled tf.keras Model with a single regression output.\n",
"    \"\"\"\n",
"    inputs = Input(shape=input_shape)\n",
"    # kernel_size=2 with the default 'valid' padding shortens the sequence by one step.\n",
"    cnn_layer = Conv1D(filters=64, kernel_size=2, activation='relu')(inputs)\n",
"    # NOTE(review): pool_size=1 leaves the sequence unchanged; kept as in the original model.\n",
"    cnn_layer = MaxPooling1D(pool_size=1)(cnn_layer)\n",
"    gru_output = Bidirectional(GRU(gru_units, return_sequences=True))(cnn_layer)\n",
"\n",
"    # Self-attention: the BiGRU output serves as values, keys and queries.\n",
"    self_attention = AttentionWithImproveRelativePositionEncoding(d_model=gru_units * 2, num_heads=num_heads)\n",
"    attention_output, _ = self_attention(gru_output, gru_output, gru_output, mask=None)\n",
"\n",
"    pool1 = GlobalAveragePooling1D()(attention_output)\n",
"    output = Dense(1)(pool1)\n",
"    return Model(inputs=inputs, outputs=output)\n",
"\n",
"\n",
"# (timesteps, features) taken from the prepared training windows.\n",
"input_shape = train_X.shape[1:]\n",
"gru_units = 64\n",
"num_heads = 8\n",
"\n",
"model = PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads)\n",
"model.compile(optimizer='adam', loss='mse')\n",
"model.summary()\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m88s\u001b[0m 65ms/step - loss: 0.0178 - val_loss: 0.0018\n",
"Epoch 2/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m83s\u001b[0m 64ms/step - loss: 0.0011 - val_loss: 0.0016\n",
"Epoch 3/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m93s\u001b[0m 71ms/step - loss: 0.0010 - val_loss: 0.0024\n",
"Epoch 4/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m84s\u001b[0m 64ms/step - loss: 9.7998e-04 - val_loss: 0.0015\n",
"Epoch 5/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m76s\u001b[0m 59ms/step - loss: 0.0010 - val_loss: 0.0015\n",
"Epoch 6/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m70s\u001b[0m 54ms/step - loss: 0.0010 - val_loss: 0.0016\n",
"Epoch 7/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m76s\u001b[0m 58ms/step - loss: 9.6638e-04 - val_loss: 0.0015\n",
"Epoch 8/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m75s\u001b[0m 58ms/step - loss: 8.8641e-04 - val_loss: 0.0017\n",
"Epoch 9/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m65s\u001b[0m 50ms/step - loss: 9.5932e-04 - val_loss: 0.0015\n",
"Epoch 10/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m67s\u001b[0m 51ms/step - loss: 9.3643e-04 - val_loss: 0.0015\n",
"Epoch 11/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m67s\u001b[0m 52ms/step - loss: 9.2035e-04 - val_loss: 0.0017\n",
"Epoch 12/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m68s\u001b[0m 52ms/step - loss: 8.8128e-04 - val_loss: 0.0017\n",
"Epoch 13/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m74s\u001b[0m 57ms/step - loss: 8.7290e-04 - val_loss: 0.0016\n",
"Epoch 14/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m77s\u001b[0m 59ms/step - loss: 8.5652e-04 - val_loss: 0.0016\n",
"Epoch 15/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m81s\u001b[0m 62ms/step - loss: 8.6573e-04 - val_loss: 0.0018\n",
"Epoch 16/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m79s\u001b[0m 61ms/step - loss: 9.3113e-04 - val_loss: 0.0015\n",
"Epoch 17/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m79s\u001b[0m 61ms/step - loss: 8.6217e-04 - val_loss: 0.0015\n",
"\u001b[1m651/651\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 17ms/step\n"
]
}
],
"source": [
"# Compile and train the model\n",
"model.compile(optimizer='adam', loss='mean_squared_error')\n",
"from keras.callbacks import EarlyStopping\n",
"\n",
"# Validate on the last 10% of train_X/train_y (Keras takes this slice before shuffling),\n",
"# so the test set never influences when training stops or which weights are kept.\n",
"# NOTE(review): this slice is chronological only if train_X is ordered in time - confirm.\n",
"VAL_SPLIT = 0.1\n",
"\n",
"# Early stopping: stop after 10 epochs without val_loss improvement and restore the best weights\n",
"early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=0, mode='min',\n",
"                               restore_best_weights=True)\n",
"\n",
"history = model.fit(train_X, train_y, epochs=100, batch_size=64, validation_split=VAL_SPLIT,\n",
"                    callbacks=[early_stopping])\n",
"\n",
"# Predict on the test set; the Dense(1) head gives shape (n_samples, 1)\n",
"lstm_pred = model.predict(test_X)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGdCAYAAADqsoKGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAABN8klEQVR4nO3de3xT9f0/8NdJmkubXqBJ6QWK7aCIQLlD5bKhs7NOplanIlO5zK+6/VCpFVQYF6doFUW5zo5t3rYx0E3ReUFZVbxQi1Au4oWbIAi0tECbNr2kTc7vj08uDbSlaZOcJH09H488kpx8cvo5RZNX35/P+RxJlmUZRERERCFOpXQHiIiIiHyBoYaIiIjCAkMNERERhQWGGiIiIgoLDDVEREQUFhhqiIiIKCww1BAREVFYYKghIiKisBChdAcCxW6348SJE4iJiYEkSUp3h4iIiDpAlmXU1NQgJSUFKlX7tZhuE2pOnDiB1NRUpbtBREREnXDs2DH06dOn3TbdJtTExMQAEL+U2NhYhXtDREREHWE2m5Gamur6Hm9Ptwk1ziGn2NhYhhoiIqIQ05GpI5woTERERGGBoYaIiIjCAkMNERERhYVuM6eGiIjIX2RZRnNzM2w2m9JdCTlqtRoRERE+WW6FoYaIiKgLrFYrTp48ibq6OqW7ErKioqKQnJwMrVbbpf0w1BAREXWS3W7H4cOHoVarkZKSAq1WywVevSDLMqxWKyoqKnD48GFkZGRccIG99jDUEBERdZLVaoXdbkdqaiqioqKU7k5IioyMhEajwQ8//ACr1Qq9Xt/pfXGiMBERURd1pbpAvvv98V+BiIiIwgJDDREREYUFhhoiIiLqkrS0NCxfvlzpbnCiMBERUXd02WWXYfjw4T4JI19++SUMBkPXO9VFDDVdtL+8Bq9tP4Z4gw6/v6yf0t0hIiLyCVmWYbPZEBFx4aiQkJAQgB5dGIefuuhkdQP+8ulhvLnruNJdISIihcmyjDprsyI3WZY73M8ZM2Zgy5YtWLFiBSRJgiRJeOmllyBJEt577z2MGjUKOp0On332GQ4dOoTrrrsOiYmJiI6OxpgxY/C///3PY3/nDj9JkoS//vWvuP766xEVFYWMjAy89dZbvvo1t4mVmi4yRYvVD09brAr3hIiIlFbfZMOgRe8r8rO/eTQHUdqOfa2vWLEC+/fvx5AhQ/Doo48CAL7++msAwMMPP4xnnnkGP/nJT9CzZ08cO3YMV199NR5//HHodDq88soruOaaa7Bv3z707du3zZ/xxz/+EUuXLsXTTz+NVatW4dZbb8UPP/yA+Pj4rh9sG1ip6SJTtA4AcMZihd3e8ZRMRESklLi4OGi1WkRFRSEpKQlJSUlQq9UAgEcffRS/+MUv0K9fP8THx2PYsGG4++67MWTIEGRkZOCxxx5Dv379Llh5mTFjBqZOnYr+/fvjiSeeQG1tLbZt2+bX42KlpoviDaJSY7PLqKpvcj0nIqLuJ1KjxjeP5ij2s31h9OjRHs9ra2vxyCOP4J133sHJkyfR3NyM+vp6HD16tN39DB061PXYYDAgNjYWp06d8kkf28JQ00UatQo9ojSoqmtCZW0jQw0RUTcmSVKHh4CC1blnMc2ZMwebN2/GM888g/79+yMyMhI33ngjrNb2p11oNBqP55IkwW63+7y/LYX2bz5IGA1aV6gZkBijdHeIiIguSKvVwmazXbDd559/jhkzZuD6668HICo3R44c8XPvOodzanzAOa+mspaThYmIKDSkpaWhpKQER44cQWVlZZtVlIyMDLz++uvYtWsXdu/ejd/85jd+r7h0FkONDzhDzenaRoV7QkRE1DFz5syBWq3GoEGDkJCQ0OYcmWeffRY9e/bE+PHjcc011yAnJwcjR44McG87hsNPPuA8rbuSoYaIiELEgAEDUFxc7LFtxowZ57VLS0vDhx9+6LFt1qxZHs/PHY5qbc2cqqqqTvXTG6zU+IDRVa
nh8BMREZFSGGp8wD2nhpUaIiIipTDU+IDRNfzESg0REZFSGGp8gJUaIiIi5THU+IDr+k+s1BARESmGocYHnJWa+iYbLI3NCveGiIioe2Ko8YEorRp6jfhVslpDRESkDIYaH5AkyVWtqeC8GiIiIkUw1PiIkasKExERKYqhxkcSeFo3ERGFkMsuuwx5eXk+29+MGTOQm5vrs/11BkONjxgNrNQQEREpiaHGR0wxjtO6LazUEBFRcJsxYwa2bNmCFStWQJIkSJKEI0eOYO/evfjlL3+J6OhoJCYm4vbbb0dlZaXrff/+97+RmZmJyMhIGI1GZGdnw2Kx4JFHHsHLL7+MN99807W/jz/+OODHxQta+oizUsOJwkRE3ZgsA011yvxsTRQgSR1qumLFCuzfvx9DhgzBo48+Kt6u0WDs2LH4v//7Pzz33HOor6/HQw89hJtvvhkffvghTp48ialTp2Lp0qW4/vrrUVNTg08//RSyLGPOnDn49ttvYTab8eKLLwIA4uPj/XaobWGo8RFTDIefiIi6vaY64IkUZX72/BOA1tChpnFxcdBqtYiKikJSUhIAYMmSJRgxYgSeeOIJV7sXXngBqamp2L9/P2pra9Hc3IwbbrgBF110EQAgMzPT1TYyMhKNjY2u/SmBocZHTAZOFCYiotC1e/dufPTRR4iOjj7vtUOHDuHKK6/EFVdcgczMTOTk5ODKK6/EjTfeiJ49eyrQ29Yx1PgIKzVERARNlKiYKPWzu6C2thbXXHMNnnrqqfNeS05OhlqtxubNm7F161Z88MEHWLVqFf7whz+gpKQE6enpXfrZvsJQ4yNGR6XmbF0Tmmx2aNScg01E1O1IUoeHgJSm1Wphs9lcz0eOHIn//Oc/SEtLQ0RE6/FAkiRMmDABEyZMwKJFi3DRRRfhjTfeQH5+/nn7UwK/eX2kZ5QWKsf8rLM8A4qIiIJcWloaSkpKcOTIEVRWVmLWrFk4c+YMpk6dii+//BKHDh3C+++/j5kzZ8Jms6GkpARPPPEEtm/fjqNHj+L1119HRUUFLrnkEtf+9uzZg3379qGyshJNTU0BPyaGGh9RqSTE8wwoIiIKEXPmzIFarcagQYOQkJAAq9WKzz//HDabDVdeeSUyMzORl5eHHj16QKVSITY2Fp988gmuvvpqDBgwAAsWLMCyZcvwy1/+EgBw55134uKLL8bo0aORkJCAzz//PODHxOEnHzJFa1FZ28iLWhIRUdAbMGAAiouLz9v++uuvt9r+kksuwaZNm9rcX0JCAj744AOf9a8zWKnxIedFLStZqSEiIgo4hhofMjmu/8RKDRERUeAx1PiQkZUaIiIixTDU+JB7+ImVGiIiokBjqPEhY7RzVWFWaoiIiAKNocaHEhyVmtMWhhoiou5ElmWluxDSfPX7Y6jxIVelpobDT0RE3YFGowEA1NUpdGXuMOH8/Tl/n53FdWp8yNSiUiPLMqQOXgKeiIhCk1qtRo8ePXDq1CkAQFRUFD/7vSDLMurq6nDq1Cn06NEDarW6S/tjqPGheMf1n5psMsz1zYiL6lriJCKi4JeUlAQArmBD3uvRo4fr99gVDDU+pNeoEaOPQE1DMyotjQw1RETdgCRJSE5ORq9evRS53lGo02g0Xa7QODHU+JgpWidCTU0j+iVEK90dIiIKELVa7bMvZ+qcTk0UXrNmDdLS0qDX65GVlYVt27a12/61117DwIEDodfrkZmZiXfffdfjdVmWsWjRIiQnJyMyMhLZ2dk4cOCAR5u0tDRIkuRxe/LJJzvTfb9yrSrMK3UTEREFlNehZsOGDcjPz8fixYtRWlqKYcOGIScnp82xxK1bt2Lq1Km44447sHPnTuTm5iI3Nxd79+51tVm6dClWrlyJwsJClJSUwGAwICcnBw0NDR77evTRR3Hy5EnX7d577/W2+35nNHBVYSIiIiV4HWqeffZZ3HnnnZg5cyYGDRqEwsJCREVF4YUXXmi1/YoVK3DVVVdh7t
y5uOSSS/DYY49h5MiRWL16NQBRpVm+fDkWLFiA6667DkOHDsUrr7yCEydOYOPGjR77iomJQVJSkutmMBi8P2I/M8U
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training vs. validation loss per epoch (val_loss is computed on whatever fit() used for validation)\n",
"fig, ax = plt.subplots()\n",
"ax.plot(history.history['loss'], label='train')\n",
"ax.plot(history.history['val_loss'], label='validation')\n",
"ax.set(title='Training history', xlabel='epoch', ylabel='MSE loss (scaled)')\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(20831, 1)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.shape(lstm_pred)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(20831,)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.shape(test_y)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# Column vector of targets; -1 infers the number of test samples instead of hard-coding 20831\n",
"test_y1 = test_y.reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.52189913e-01],\n",
" [3.12516873e-01],\n",
" [3.25310588e-01],\n",
" ...,\n",
" [1.08522631e-04],\n",
" [1.18219088e-04],\n",
" [1.28327022e-04]])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_y1"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# Repeat the prediction column across the 6 columns the scaler expects so inverse_transform accepts it;\n",
"# only column 5 is used afterwards. Row count comes from the data instead of a hard-coded 20831.\n",
"results1 = np.broadcast_to(lstm_pred, (lstm_pred.shape[0], 6))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# Same 6-column broadcast for the true values; row count comes from the data instead of a hard-coded 20831\n",
"test_y2 = np.broadcast_to(test_y1, (test_y1.shape[0], 6))"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# Invert the feature scaling to get values back in original units\n",
"inv_test_y = scaler.inverse_transform(test_y2)\n",
"inv_forecast_y = scaler.inverse_transform(results1)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.78428369e+01, 4.82409691e+01, 6.37156385e+02,\n",
" 2.97801603e+02, 1.07621239e+01, 9.90052500e-01],\n",
" [ 1.07562527e+01, 3.44305945e+01, 4.40440713e+02,\n",
" 2.05929459e+02, 7.43790432e+00, 1.80780551e-01],\n",
" [ 1.14053667e+01, 3.56955916e+01, 4.58459395e+02,\n",
" 2.14344726e+02, 7.74239484e+00, 2.54907916e-01],\n",
" ...,\n",
" [-5.09439462e+00, 3.54076535e+00, 4.44428011e-01,\n",
" 4.37940726e-01, 2.58283957e-03, -1.62932764e+00],\n",
" [-5.09390265e+00, 3.54172410e+00, 4.58084512e-01,\n",
" 4.44318723e-01, 2.81361533e-03, -1.62927146e+00],\n",
" [-5.09338980e+00, 3.54272354e+00, 4.72320538e-01,\n",
" 4.50967376e-01, 3.05418424e-03, -1.62921289e+00]])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"display(inv_test_y)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test RMSE: 0.223\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABQoAAAKTCAYAAABRkzVdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9d9w0V13+f52Z2XrXpyZPeqUETIAQIHRRiiBF0a8iiqiABaSIPxX98lVBCYqCIBBBwKggIkhR6TUYIARCgIT0Xp48/bnLtqnn98fM7s7MzuzO3jtld+/r/XqFZ3d2duZwn52Zc65zfT4fIaWUIIQQQgghhBBCCCGEbGuUohtACCGEEEIIIYQQQggpHgqFhBBCCCGEEEIIIYQQCoWEEEIIIYQQQgghhBAKhYQQQgghhBBCCCGEEFAoJIQQQgghhBBCCCGEgEIhIYQQQgghhBBCCCEEFAoJIYQQQgghhBBCCCEAtKIbMAzHcbB//34sLS1BCFF0cwghhBBCCCGEEEIImSmklNjc3MRJJ50ERRnuGZxqoXD//v049dRTi24GIYQQQgghhBBCCCEzzT333INTTjll6D5TLRQuLS0BcP+PLC8vF9waQgghhBBCCCGEEEJmi42NDZx66qk9nW0YUy0UdsONl5eXKRQSQgghhBBCCCGEELJFkqT1YzETQgghhBBCCCGEEEIIhUJCCCGEEEIIIYQQQgiFQkIIIYQQQgghhBBCCKY8RyEhhBBCCCGEEEII2Z7Ytg3TNItuxkxQLpehKJP7ASkUEkIIIYQQQgghhJCpQUqJAwcOYG1treimzAyKouDMM89EuVye6DgUCgkhhBBCCCGEEELI1NAVCffu3Yt6vZ6oWu92xnEc7N+/H/fffz9OO+20if5eFAoJIYQQQgghhBBCyFRg23ZPJNy1a1fRzZkZ9uzZg/3798OyLJRKpS0fh8VMCCGEEEIIIYQQQshU0M1JWK/XC27JbNENObZte6LjUCgkhBBCCCGEEEIIIVMFw43HI62/F4VCQgghhBBCCCGEEEIIhUJCCCGEEEIIIYQQQgiFQkIIIYQQQgghhBBCCCgUEkIIIYQQQgghhBAyMU9+8pPx6le/uuhmTASFQkIIIYQQQgghhBBCMkZKCcuyim7GUCgUEkIIIYQQQgghhJCpREqJlmEV8p+UMnE7X/ziF+Pyyy/H29/+dgghIITAZZddBiEEPvvZz+LCCy9EpVLBFVdcgRe/+MV43vOeF/j+q1/9ajz5yU/uvXccB5dccgnOPPNM1Go1XHDBBfjYxz6W0l81Hi3zMxBCCCGEEEIIIYQQsgXapo3z/t/nCzn39W94OurlZNLZ29/+dtx888146EMfije84Q0AgB/96EcAgD/6oz/C3/zN3+Css87Cjh07Eh3vkksuwQc/+EH8wz/8A84991x8/etfxy//8i9jz549eNKTnrS1/0MJoFBICCGEEEIIIYQQQsgErKysoFwuo16v48QTTwQA3HjjjQCAN7zhDXjqU5+a+Fi6ruNNb3oTvvSlL+Hiiy8GAJx11lm44oor8J73vIdCISGEEEIIIYQQQgjZftRKKq5/w9MLO3caPPKRjxxr/1tvvRWtVmtAXDQMAw9/+MNTaVMcFAoJIYQQQgghhBBCyFQihEgc/jutLCwsBN4rijKQ/9A0zd7rRqMBAPj0pz+Nk08+ObBfpVLJqJUus/2XJoQQQgghhBBCCCFkCiiXy7Bte+R+e/bswXXXXRfY9v3vfx+lUgkAcN5556FSqeDuu+/ONMw4CgqFhBBCCCGEEEIIIYRMyBlnnIFvf/vbuPPOO7G4uAjHcSL3e8pTnoK3vOUt+Jd/+RdcfPHF+OAHP4jrrruuF1a8tLSE3//938drXvMaOI6Dxz/+8VhfX8c3vvENLC8v41d/9Vcz+/+gZHZkQgghhBBCCCGEEEK2Cb//+78PVVVx3nnnYc+ePbj77rsj93v605+O17/+9f
iDP/gDXHTRRdjc3MSLXvSiwD5vfOMb8frXvx6XXHIJHvzgB+MZz3gGPv3pT+PMM8/M9P+DkOGg6CliY2MDKysrWF9fx/LyctHNIYQQQgghhBBCCCEZ0ul0cMcdd+DMM89EtVotujkzw7C/2zj6Gh2FhBBCCCGEEEIIIYQQCoWEEEIIIYSQybnnWAsv/9D38P171opuCiGEEEK2CIVCQgghhBBCyMS8+iPfx6evvR/Pe9c3En/n1kObeNm/fBfX3beeYcsIIYQQkhQKhYQQQgghhJCJuetoc+zv/PL7rsIXrj+I51/6zQxaRAghhJBxoVBICCGEEEIImZiSOv7U4sBGBwCgW07azSGEEELIFqBQSAghhBBCCJkYTRVFN4EQQgghE0KhkBBCCCGEEDIxJYVTC0IIIWTW4dOcEEIIIYQQMjFbCT0mhBBCyHTBpzkhhBBCZoJ/+sYdeP0nr4OUsuimEEIiYOjxNucHHwFu+O+iW0EIIWRCtKIbQAghhBCShD//7+sBAD99/j48+qxdBbeGEBJGo6Nw+7J5EPjEy9zX/+84wDB0QgiZWXgHJ4QQQsj08fk/AT7084AzWAl1o2MV0CBCyCjKdBRuXzpr/deSFawJISQJhmEU3YRIKBQSQgghZPr41juBW74A3P3NgY8sm5NQQqYR5ijczvhFYqaHIIRsT5785CfjFa94BV7xildgZWUFu3fvxutf//pe2pwzzjgDb3zjG/GiF70Iy8vLeNnLXCf2FVdcgSc84Qmo1Wo49dRT8cpXvhLNZrOw/x98mhNCCCFkerH0wU0OJ6GETCN+oZC5RLcx7HtCSNpICRjNYv4b8572z//8z9A0DVdddRXe/va3461vfSve97739T7/m7/5G1xwwQW45ppr8PrXvx633XYbnvGMZ+D5z38+fvjDH+IjH/kIrrjiCrziFa9I+6+YGOYoJIQQQsj0IiWuvXcd7/3f23ubbAqFhEwlJV/osWE7qGhqga0hhcHQY0JI2pgt4E0nFXPuP94PlBcS737qqafibW97G4QQeOADH4hrr70Wb3vb2/DSl74UAPCUpzwFr33ta3v7v+QlL8ELX/hCvPrVrwYAnHvuuXjHO96BJz3pSbj00ktRrVZT/b+TBDoKCSGEEDLFSDz7nVfgv3+wv7eFjkJCphO/6SIivSjZNvAeTQjZvjzmMY+BEP2Fs4svvhi33HILbNsGADzykY8M7P+DH/wAl112GRYXF3v/Pf3pT4fjOLjjjjtybXsXOgoJIYQQMr1EOFOYo5CQ6cRmyOn2xTcppqOQEJI6pbrr7Cvq3CmysBB0JzYaDfzmb/4mXvnKVw7se9ppp6V67qRQKCSEEELI9BIlFNJRSMhU4k8LIOkq275QMCaEpI0QY4X/Fsm3v/3twPsrr7wS5557LlQ1Oh3HIx7xCFx//fU455xz8mheIhh6TAghhJDpJUIodDgJJWQqCQiFvEy3Mex8Qsj25e6778bv/d7v4aabbsKHP/xh/P3f/z1e9apXxe7/h3/4h/jmN7+JV7ziFfj+97+PW265BZ/61KdYzIQQQgghJArbcRBe1zRtTkIJmUaCjkKyvWDoMSGEAMCLXvQitNttPOpRj4KqqnjVq16Fl73sZbH7n3/++bj88svxJ3/yJ3jCE54AKSXOPvts/MIv/EKOrQ5CoZAQQgghU8utBzcArAa22aySQMhU4nf7SloKty/se0LINqZUKuHv/u7vcOmllw58duedd0Z+56KLLsIXvvCFjFuWHIYeE0IIIWR6YY5CQmaGSRyFihi9D5liWMyEEELmBgqFhBBCZpvOOvDhXwJ+9ImiW0IyQETIDTcd2MQz3/6/+Nx1BwpoESEkjkmyAiiCSiEhhBAyDTD0mBBCyGzz9bcAN33a/e8hP1N0a0jKRK1ofur7+wEAv/XBq3Hnm5+Vb4MIIbE4ExQzoVA4R9BRSAjZpnzta18rugmpQEchIYSQ2aZ1rOgWkAwRcKAyJpGQmc
Afejxu7DF1wjmCOQoJIWSmoVBICCFkxuHscp5R4KCkso8JmQUCxUzGVArpKJxtvneXb9GOjkJCSEqwMNZ4pPX3olB
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Target column in the 6-column inverse-transformed arrays, and the test-sample window to plot\n",
"TARGET_COL = 5\n",
"PLOT_START, PLOT_END = 300, 3000\n",
"\n",
"# Root-mean-square error in original (inverse-transformed) units\n",
"rmse = sqrt(mean_squared_error(inv_test_y[:, TARGET_COL], inv_forecast_y[:, TARGET_COL]))\n",
"print('Test RMSE: %.3f' % rmse)\n",
"\n",
"# True vs. predicted values over a window of the test set\n",
"fig, ax = plt.subplots(figsize=(16, 8))\n",
"ax.plot(inv_test_y[PLOT_START:PLOT_END, TARGET_COL], label='true')\n",
"ax.plot(inv_forecast_y[PLOT_START:PLOT_END, TARGET_COL], label='pre')\n",
"ax.set(title='Test set: true vs. predicted (samples %d-%d)' % (PLOT_START, PLOT_END),\n",
"       xlabel='test sample (offset from %d)' % PLOT_START, ylabel='target value')\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean_squared_error: 0.0014791752952266549\n",
"mean_absolute_error: 0.013799955472387545\n",
"rmse: 0.0384600480398381\n",
"r2 score: 0.9904178817149276\n"
]
}
],
"source": [
"# Evaluation metrics on the scaled test set: MSE, MAE, RMSE and R^2\n",
"from math import sqrt\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"\n",
"# sklearn metrics take (y_true, y_pred); flatten the (n, 1) predictions to match test_y's (n,) shape\n",
"y_pred = lstm_pred.ravel()\n",
"mse = mean_squared_error(test_y, y_pred)\n",
"print('mean_squared_error:', mse)\n",
"print(\"mean_absolute_error:\", mean_absolute_error(test_y, y_pred))\n",
"print(\"rmse:\", sqrt(mse))\n",
"# R^2 over the full test set like the other metrics (previously only samples 5000:10000).\n",
"# Assuming `scaler` is an affine per-column scaler (e.g. MinMaxScaler), R^2 is the same in\n",
"# scaled and original units, so this matches R^2 on the inverse-transformed target column.\n",
"print(\"r2 score:\", r2_score(test_y, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"# True target values (column 5) in original units, as a one-column frame\n",
"df1 = pd.DataFrame({'column_name': inv_test_y[:, 5]})"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"# Write the true values to CSV without the index column\n",
"df1.to_csv(path_or_buf='高频re_test.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"# Predicted target values (column 5) in original units, as a one-column frame\n",
"df2 = pd.DataFrame({'column_name': inv_forecast_y[:, 5]})"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"# Write the predicted values to CSV without the index column\n",
"df2.to_csv(path_or_buf='高频re_forecast.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}