ICEEMDAN-Solar_power-forecast/iceemdan-筛选-high-ConvBiGruA...

1100 lines
274 KiB
Plaintext
Raw Normal View History

2024-11-21 13:54:50 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\asus\\AppData\\Roaming\\Python\\Python39\\site-packages\\pandas\\core\\computation\\expressions.py:21: UserWarning: Pandas requires version '2.8.4' or newer of 'numexpr' (version '2.8.3' currently installed).\n",
" from pandas.core.computation.check import NUMEXPR_INSTALLED\n",
"C:\\Users\\asus\\AppData\\Roaming\\Python\\Python39\\site-packages\\pandas\\core\\arrays\\masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
" from pandas.core import (\n"
]
}
],
"source": [
"from math import sqrt\n",
"from numpy import concatenate\n",
"from matplotlib import pyplot\n",
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import mean_squared_error\n",
"from tensorflow.keras import Sequential\n",
"\n",
"from tensorflow.keras.layers import Dense\n",
"from tensorflow.keras.layers import LSTM\n",
"from tensorflow.keras.layers import Dropout\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这段代码定义了函数 time_series_to_supervised，它用于将时间序列数据转换为监督学习问题的数据集。下面是该函数各个参数的含义：\n",
"\n",
"- data: 输入的时间序列数据，可以是列表或 2D NumPy 数组。\n",
"- n_in: 作为输入的滞后观察数，即用多少个时间步的观察值作为输入。默认值为 96，表示使用前 96 个时间步的观察值作为输入。\n",
"- n_out: 作为输出的观测数量，即预测多少个时间步的观察值。默认值为 10，表示预测未来 10 个时间步的观察值。\n",
"- dropnan: 布尔值，表示是否删除具有 NaN 值的行。默认为 True，即删除具有 NaN 值的行。\n",
"\n",
"函数首先检查输入数据的维度，并初始化一些变量。然后，它创建一个新的 DataFrame 对象 df 来存储输入数据，并保存原始的列名。接着，它创建了两个空列表 cols 和 names，用于存储新的特征列和列名。\n",
"\n",
"接下来,函数开始构建特征列和对应的列名。首先,它将原始的观察序列添加到 cols 列表中,并将其列名添加到 names 列表中。然后,它依次将滞后的观察序列添加到 cols 列表中,并构建相应的列名,格式为 (原始列名)(t-滞后时间)。这样就创建了输入特征的部分。\n",
"\n",
"接着,函数开始构建输出特征的部分。它依次将未来的观察序列添加到 cols 列表中,并构建相应的列名,格式为 (原始列名)(t+未来时间)。\n",
"\n",
"随后，函数将所有的特征列拼接在一起，构成一个新的 DataFrame 对象 agg。如果 dropnan 参数为 True，则删除具有 NaN 值的行。最后，函数返回处理后的数据集 agg。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def time_series_to_supervised(data, n_in=96, n_out=10, dropnan=True):\n",
"    \"\"\"Frame a time series as a supervised-learning table of shifted copies.\n",
"\n",
"    The result holds, in this column order: the current observation of every\n",
"    variable, then the lagged observations from t-n_in up to t-1, then the\n",
"    future observations from t+1 up to t+n_out.\n",
"\n",
"    :param data: observations as a list, a 2D NumPy array, or anything else\n",
"        accepted by ``pd.DataFrame``; rows are time steps.\n",
"    :param n_in: number of lagged steps to add as inputs (default 96);\n",
"        negative values are treated as 0.\n",
"    :param n_out: number of future steps to add as outputs (default 10);\n",
"        negative values are treated as 0.\n",
"    :param dropnan: when True (default), drop every row that contains NaN,\n",
"        which includes the edge rows left incomplete by shifting.\n",
"    :return: DataFrame whose columns are named ``<col>``, ``<col>(t-k)`` and\n",
"        ``<col>(t+k)``, where ``<col>`` is the original column label.\n",
"    \"\"\"\n",
"    df = pd.DataFrame(data)\n",
"    # Take the variables from the frame itself. The previous rule\n",
"    # (1 for any list, data.shape[1] otherwise) produced too few names for\n",
"    # nested lists and raised on 1-D arrays and Series.\n",
"    base_names = [str(label) for label in df.columns]\n",
"\n",
"    # Current observation first, so the column layout matches the old output.\n",
"    frames = [df.shift(0)]\n",
"    names = list(base_names)\n",
"\n",
"    # Lagged inputs, oldest first: t-n_in, ..., t-1.\n",
"    for lag in range(max(0, n_in), 0, -1):\n",
"        frames.append(df.shift(lag))\n",
"        names.extend('%s(t-%d)' % (name, lag) for name in base_names)\n",
"\n",
"    # Future outputs, nearest first: t+1, ..., t+n_out.\n",
"    for lead in range(1, max(0, n_out) + 1):\n",
"        frames.append(df.shift(-lead))\n",
"        names.extend('%s(t+%d)' % (name, lead) for name in base_names)\n",
"\n",
"    agg = pd.concat(frames, axis=1)\n",
"    agg.columns = names\n",
"    if dropnan:\n",
"        # Non-mutating drop: shifting leaves NaN in the first n_in and last n_out rows.\n",
"        agg = agg.dropna()\n",
"    return agg"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Temp Humidity GHI DHI Rainfall Power\n",
"0 19.779453 40.025826 3.232706 1.690531 0.0 0.0\n",
"1 19.714937 39.605961 3.194991 1.576346 0.0 0.0\n",
"2 19.549330 39.608631 3.070866 1.576157 0.0 0.0\n",
"3 19.405870 39.680702 3.038623 1.482489 0.0 0.0\n",
"4 19.387363 39.319881 2.656474 1.134153 0.0 0.0\n",
"(104256, 6)\n"
]
}
],
"source": [
"# Load the raw measurements from CSV.\n",
"# NOTE(review): absolute, machine-specific path; this cell only runs where that file exists.\n",
"path1 = r\"D:\\project\\小论文1-基于ICEEMDAN分解的时序高维变化的短期光伏功率预测模型\\CEEMAN-PosConv1dbiLSTM-LSTM\\模型代码流程\\完整的模型代码流程 copy\\data66.csv\"\n",
"datas1 = pd.read_csv(path1)\n",
"\n",
"# Fill gaps by interpolation; every column of the file is kept.\n",
"data1 = datas1.interpolate()\n",
"values1 = data1.values\n",
"\n",
"# Quick look at the first rows and the overall shape.\n",
"print(data1.head(), data1.shape, sep='\\n')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Load the pre-computed reconstructed series stored as CSV (the file name marks it as the\n",
"# 'high' ICEEMDAN reconstruction).\n",
"# NOTE(review): absolute, machine-specific path.\n",
"high_re_path = r\"D:\\project\\小论文1-基于ICEEMDAN分解的时序高维变化的短期光伏功率预测模型\\CEEMAN-PosConv1dbiLSTM-LSTM\\模型代码流程\\完整的模型代码流程 copy\\t+3\\iceemdan_reconstructed_data_re_high.csv\"\n",
"high_re = pd.read_csv(high_re_path)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" column_name\n",
"0 -1.460307\n",
"1 -1.460504\n",
"2 -1.460698\n",
"3 -1.460886\n",
"4 -1.461071\n",
"... ...\n",
"104251 -1.663370\n",
"104252 -1.664516\n",
"104253 -1.665650\n",
"104254 -1.666774\n",
"104255 -1.667887\n",
"\n",
"[104256 rows x 1 columns]\n"
]
}
],
"source": [
"reconstructed_data_high= high_re\n",
"# # 打印重构的原始数据\n",
"print(reconstructed_data_high)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0wAAAIjCAYAAAAwSJuMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOx9d7hVxfX2e2n30ouCFBGRYgMbdqNYMGDv3QjWWAjWqBgTSzTYE4yJxtiNRn9RNLE3LKgRC/ZYUBFRUSxI73d/f5xvnztn311mrTUze597530enns456w9c/aePbPe9a41uyYIggAeHh4eHh4eHh4eHh4ejdAi7w54eHh4eHh4eHh4eHgUFZ4weXh4eHh4eHh4eHh4JMATJg8PDw8PDw8PDw8PjwR4wuTh4eHh4eHh4eHh4ZEAT5g8PDw8PDw8PDw8PDwS4AmTh4eHh4eHh4eHh4dHAjxh8vDw8PDw8PDw8PDwSIAnTB4eHh4eHh4eHh4eHgnwhMnDw8PDw8PDw8PDwyMBnjB5eHhooaamBhdeeGHe3YjF2muvjT333DPvbjRr7Ljjjthxxx1ZtmPGjMHaa69ttD9R3HbbbaipqcHnn39utR0KinxPecjx3HPPoaamBs8995z4WEUcvyp23HFHDBkyJPN7n3/+OWpqanDbbbeR2whtr7rqKkYPPTxk8ITJw0MTn376KX75y19inXXWQV1dHTp16oTtttsOEydOxJIlS/LunodBLF68GBdeeKERR6foePjhhzFq1CisttpqqKurw+DBg3HWWWfhhx9+yLtrHs0Af/jDH/Dggw82+z7YxN13340//elPeXfDw6Oq0SrvDnh4VAMeeeQRHHTQQaitrcVRRx2FIUOGYPny5XjxxRfx61//Gu+//z5uvPHGvLtpFUuWLEGrVs1jyli8eDEuuugiAGCrJtWAs846C1dffTU23nhjnHPOOejWrRumTZuG6667Dvfccw+eeeYZrLvuulrHevLJJ9n9+Pvf/476+nq2vUf14g9/+AMOPPBA7Lvvvs26DzZx991347333sNpp52Wd1fQr18/LFmyBK1bt867Kx4eJDQP78fDQ4AZM2bg0EMPRb9+/TB58mT06tWr/Nkpp5yCTz75BI888kiOPbSH+vp6LF++HHV1dairq8u7Ox4G8c9//hNXX301DjnkENx1111o2bJl+bMxY8Zgp512wkEHHYRp06alEuXFixejXbt2aNOmDbsv3nmiYenSpWjTpg1atGheSSKLFi1C+/bt8+6GhwA1NTV+LfGoSjSv2dbDg4ErrrgCCxcuxM0331xBlkIMHDgQp556avn/K1euxO9//3sMGDAAtbW1WHvttXHeeedh2bJlFXZh3c1zzz2HzTffHG3btsXQoUPLaWCTJk3C0KFDUVdXh2HDhuHNN9+ssB8zZgw6dOiAzz77DCNHjkT79u3Ru3dvXHzxxQiCoOK7V111FbbddlusttpqaNu2LYYNG4b77ruv0W+pqanB2LFjcdddd2HDDTdEbW0tHn/88fJnar3FggULcNppp2HttddGbW0tevTogV133RXTpk2rOOa//vUvDBs2DG3btsXqq6+OI488El999VXsb/nqq6+w7777okOHDujevTvOOussrFq1KuHKNMaTTz6JTTbZBHV1ddhggw0wadKkRt/56aefcNppp6Fv376ora3FwIEDcfnll5cVjs8//xzdu3cHAFx00UWoqakp//b//Oc/qKmpwTvvvFM+3v3334+amhrsv//+Fe2sv/76OOSQQyre+8c//lE+F926dcOhhx6KWbNmNerj1KlTMWrUKHTu3Bnt2rXD8OHD8dJLL1V858ILL0RNTQ0++eQTjBkzBl26dEHnzp1x9NFHY/HixZnn6qKLLkLXrl1x4403VpAlANhyyy1xzjnn4N13360YJ2GdwhtvvIEddtgB7dq1w3nnnVf+LKrGzZw5E3vvvTfat2+PHj164P
TTT8cTTzzRqK4jWsOk1irceOON5Xtpiy22wGuvvVbRxjvvvIMxY8aUU2V79uyJY445hp1SqHs8yvlftmwZTj/9dHTv3h0dO3bE3nvvjS+//FKrP2EdzD333IPzzz8fffr0Qbt27TB//nwAemMFAL766isce+yx6N27N2pra9G/f3+cdNJJWL58efk7n332GQ466CB069YN7dq1w9Zbb90oGBT25//+7/9w6aWXYs0110RdXR122WUXfPLJJxXfnT59Og444AD07NkTdXV1WHPNNXHooYdi3rx5AEpzyqJFi3D77beX77MxY8ZUnN///e9/OPzww9G1a1f87Gc/A5BcLxdXC1dfX4+JEyeW59Lu3btj1KhReP311zP7EJ63Y445BmussQZqa2ux4YYb4pZbbmnU9pdffol99923YqxH53xdvP/++9h5553Rtm1brLnmmrjkkktiFdh///vf2GOPPcrXdMCAAfj9739fMWfuuOOOeOSRRzBz5szy7wvP0fLly/G73/0Ow4YNQ+fOndG+fXtsv/32ePbZZ1n9BoD//e9/2GmnndCuXTv06dMHV1xxRcXnSTVM//rXv7DBBhugrq4OQ4YMwQMPPJBa25g1L3h4mIZXmDw8MvDQQw9hnXXWwbbbbqv1/eOOOw633347DjzwQJx55pmYOnUqJkyYgA8++AAPPPBAxXc/+eQTHH744fjlL3+JI488EldddRX22msv3HDDDTjvvPNw8sknAwAmTJiAgw8+GB999FFFVHnVqlUYNWoUtt56a1xxxRV4/PHHccEFF2DlypW4+OKLy9+bOHEi9t57bxxxxBFYvnw57rnnHhx00EF4+OGHsccee1T0afLkyfi///s/jB07FquvvnrignXiiSfivvvuw9ixY7HBBhvghx9+wIsvvogPPvgAm222GYBSofLRRx+NLbbYAhMmTMC3336LiRMn4qWXXsKbb76JLl26VPyWkSNHYquttsJVV12Fp59+GldffTUGDBiAk046KfO8T58+HYcccghOPPFEjB49GrfeeisOOuggPP7449h1110BlNSQ4cOH46uvvsIvf/lLrLXWWnj55Zcxfvx4zJ49G3/605/QvXt3XH/99TjppJOw3377lYnQRhtthDXXXBM1NTV44YUXsNFGGwEApkyZghYtWuDFF18s9+W7777Dhx9+iLFjx5bfu/TSS/Hb3/4WBx98MI477jh89913+POf/4wddtih4lxMnjwZu+22G4YNG4YLLrgALVq0wK233oqdd94ZU6ZMwZZbblnxuw8++GD0798fEyZMwLRp03DTTTehR48euPzyy1PP1UcffYQxY8agU6dOsd856qijcMEFF+Dhhx/GoYceWn7/hx9+wG677YZDDz0URx55JNZYY41Y+0WLFmHnnXfG7Nmzceqpp6Jnz564++67Sc7Y3XffjQULFuCXv/wlampqcMUVV2D//ffHZ599VlalnnrqKXz22Wc4+uij0bNnz3J67Pvvv49XXnkFNTU12u1xjqdz/o877jj84x//wOGHH45tt90WkydPbnTfZeH3v/892rRpg7POOgvLli1DmzZttMfK119/jS233BI//fQTTjjhBKy33nr46quvcN9992Hx4sVo06YNvv32W2y77bZYvHgxxo0bh9VWWw2333479t57b9x3333Yb7/9Kvpz2WWXoUWLFjjrrLMwb948XHHFFTjiiCMwdepUACVnfOTIkVi2bBl+9atfoWfPnvjqq6/w8MMP46effkLnzp1x55134rjjjsOWW26JE044AQAwYMCAinYOOuggDBo0CH/4wx8aBYJ0cOyxx+K2227DbrvthuOOOw4rV67ElClT8Morr2DzzTdP7cO3336LrbfeuhxI6t69Ox577DEce+yxmD9/fjnFbcmSJdhll13wxRdfYNy4cejduzfuvPNOTJ48mdzfb775BjvttBNWrlyJc889F+3bt8eNN96Itm3bNvrubbfdhg4dOu
CMM85Ahw4dMHnyZPzud7/D/PnzceWVVwIAfvOb32DevHn48ssv8cc//hEA0KFDBwDA/PnzcdNNN+Gwww7D8ccfjwU
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Original Power series and a matching integer time axis.\n",
"# NOTE(review): neither is drawn below, although the title still mentions the original data.\n",
"original_data = data1['Power'].values\n",
"time = range(len(original_data))\n",
"\n",
"# Draw only the tail of the reconstructed series (rows 90000 onward), against its row index.\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"ax.plot(reconstructed_data_high[90000:], label='Reconstructed Data', color='red')\n",
"\n",
"# Title, axis labels and legend.\n",
"ax.set_title('Comparison between Original and reconstructed_data_high')\n",
"ax.set_xlabel('Time')\n",
"ax.set_ylabel('Power')\n",
"ax.legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"data3=data1.iloc[:,:5]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Temp Humidity GHI DHI Rainfall column_name\n",
"0 19.779453 40.025826 3.232706 1.690531 0.0 -1.460307\n",
"1 19.714937 39.605961 3.194991 1.576346 0.0 -1.460504\n",
"2 19.549330 39.608631 3.070866 1.576157 0.0 -1.460698\n",
"3 19.405870 39.680702 3.038623 1.482489 0.0 -1.460886\n",
"4 19.387363 39.319881 2.656474 1.134153 0.0 -1.461071\n",
"... ... ... ... ... ... ...\n",
"104251 13.303740 34.212711 1.210789 0.787026 0.0 -1.663370\n",
"104252 13.120920 34.394939 2.142980 1.582670 0.0 -1.664516\n",
"104253 12.879215 35.167400 1.926214 1.545889 0.0 -1.665650\n",
"104254 12.915867 35.359989 1.317695 0.851529 0.0 -1.666774\n",
"104255 13.134816 34.500034 1.043269 0.597816 0.0 -1.667887\n",
"\n",
"[104256 rows x 6 columns]\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"# Wrap the feature columns (data3) and the reconstructed series as DataFrames.\n",
"data3_df, imf1_df = pd.DataFrame(data3), pd.DataFrame(reconstructed_data_high)\n",
"\n",
"# Join them side by side on the row index, then keep at most the first 104256 rows.\n",
"merged_df = pd.concat([data3_df, imf1_df], axis=1).iloc[:104256]\n",
"\n",
"# Show the merged table.\n",
"print(merged_df)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104256, 6)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"merged_df.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(104256, 6)\n"
]
}
],
"source": [
"# 使用MinMaxScaler进行归一化\n",
"scaler = MinMaxScaler(feature_range=(0, 1))\n",
"scaledData1 = scaler.fit_transform(merged_df)\n",
"print(scaledData1.shape)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0 1 2 3 4 5 0(t-96) \\\n",
"96 0.555631 0.349673 0.190042 0.040558 0.0 0.250386 0.490360 \n",
"97 0.564819 0.315350 0.211335 0.044613 0.0 0.268375 0.489088 \n",
"98 0.576854 0.288321 0.229657 0.047549 0.0 0.286165 0.485824 \n",
"99 0.581973 0.268243 0.247775 0.053347 0.0 0.303808 0.482997 \n",
"100 0.586026 0.264586 0.266058 0.057351 0.0 0.321484 0.482632 \n",
"\n",
" 1(t-96) 2(t-96) 3(t-96) ... 2(t+2) 3(t+2) 4(t+2) 5(t+2) \\\n",
"96 0.369105 0.002088 0.002013 ... 0.229657 0.047549 0.0 0.286165 \n",
"97 0.364859 0.002061 0.001839 ... 0.247775 0.053347 0.0 0.303808 \n",
"98 0.364886 0.001973 0.001839 ... 0.266058 0.057351 0.0 0.321484 \n",
"99 0.365615 0.001950 0.001697 ... 0.282900 0.060958 0.0 0.338338 \n",
"100 0.361965 0.001679 0.001167 ... 0.299668 0.065238 0.0 0.355108 \n",
"\n",
" 0(t+3) 1(t+3) 2(t+3) 3(t+3) 4(t+3) 5(t+3) \n",
"96 0.581973 0.268243 0.247775 0.053347 0.0 0.303808 \n",
"97 0.586026 0.264586 0.266058 0.057351 0.0 0.321484 \n",
"98 0.590772 0.258790 0.282900 0.060958 0.0 0.338338 \n",
"99 0.600396 0.249246 0.299668 0.065238 0.0 0.355108 \n",
"100 0.607019 0.247850 0.313694 0.066189 0.0 0.372185 \n",
"\n",
"[5 rows x 600 columns]\n"
]
}
],
"source": [
"# Window sizes: number of past steps fed to the model and number of future steps framed.\n",
"n_steps_in = 96\n",
"n_steps_out = 3\n",
"\n",
"# Turn the scaled array into a table of lagged inputs and future targets.\n",
"processedData1 = time_series_to_supervised(scaledData1, n_in=n_steps_in, n_out=n_steps_out)\n",
"print(processedData1.head())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Column labels are derived from the window sizes and the feature count instead of being\n",
"# hard-coded, so this cell stays correct when n_steps_in, n_steps_out or the features change.\n",
"last_var = scaledData1.shape[1] - 1  # index of the last column, the series being predicted\n",
"first_input_col = '0(t-%d)' % n_steps_in\n",
"last_input_col = '%d(t-1)' % last_var\n",
"target_col = '%d(t+%d)' % (last_var, n_steps_out)\n",
"\n",
"# Inputs: every lagged column from t-n_steps_in to t-1 (label-based slice, both ends included).\n",
"data_x = processedData1.loc[:, first_input_col:last_input_col]\n",
"# Target: the last variable n_steps_out steps ahead.\n",
"data_y = processedData1.loc[:, target_col]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104157, 576)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_x.shape"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"96 0.303808\n",
"97 0.321484\n",
"98 0.338338\n",
"99 0.355108\n",
"100 0.372185\n",
" ... \n",
"104248 0.023869\n",
"104249 0.023687\n",
"104250 0.023507\n",
"104251 0.023329\n",
"104252 0.023153\n",
"Name: 5(t+3), Length: 104157, dtype: float64"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_y"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(104157,)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(83325, 96, 6) (83325,) (10417, 96, 6) (10417,) (10415, 96, 6) (10415,)\n"
]
}
],
"source": [
"# 计算训练集、验证集和测试集的大小\n",
"train_size = int(len(data_x) * 0.8)\n",
"test_size = int(len(data_x) * 0.1)\n",
"val_size = len(data_x) - train_size - test_size\n",
"\n",
"# 计算训练集、验证集和测试集的索引范围\n",
"train_indices = range(train_size)\n",
"val_indices = range(train_size, train_size + val_size)\n",
"test_indices = range(train_size + val_size, len(data_x))\n",
"\n",
"# 根据索引范围划分数据集\n",
"train_X1 = data_x.iloc[train_indices].values.reshape((-1, n_steps_in, scaledData1.shape[1]))\n",
"val_X1 = data_x.iloc[val_indices].values.reshape((-1, n_steps_in, scaledData1.shape[1]))\n",
"test_X1 = data_x.iloc[test_indices].values.reshape((-1, n_steps_in, scaledData1.shape[1]))\n",
"train_y = data_y.iloc[train_indices].values\n",
"val_y = data_y.iloc[val_indices].values\n",
"test_y = data_y.iloc[test_indices].values\n",
"\n",
"# reshape input to be 3D [samples, timesteps, features]\n",
"train_X = train_X1.reshape((train_X1.shape[0], n_steps_in, scaledData1.shape[1]))\n",
"val_X = val_X1.reshape((val_X1.shape[0], n_steps_in, scaledData1.shape[1]))\n",
"test_X = test_X1.reshape((test_X1.shape[0], n_steps_in, scaledData1.shape[1]))\n",
"\n",
"print(train_X.shape, train_y.shape, val_X.shape, val_y.shape, test_X.shape, test_y.shape)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(83325, 96, 6)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_X1.shape"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From d:\\Anaconda3\\lib\\site-packages\\keras\\src\\backend\\tensorflow\\core.py:192: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
"\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">6</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ conv1d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv1D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">832</span> │ input_layer[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ max_pooling1d │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ conv1d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling1D</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ bidirectional │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">95</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">49,920</span> │ max_pooling1d[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ attention_with_imp… │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ <span style=\"color: #00af00; text-decoration-color: #00af00\">66,304</span> │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">AttentionWithImpr…</span> │ <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">8</span>, │ │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)] │ │ bidirectional[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ global_average_poo… │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ attention_with_i… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalAveragePool…</span> │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dense_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">129</span> │ global_average_p… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ conv1d (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m832\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ max_pooling1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mMaxPooling1D\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ bidirectional │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m95\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m49,920\u001b[0m │ max_pooling1d[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ attention_with_imp… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m66,304\u001b[0m │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mAttentionWithImpr…\u001b[0m │ \u001b[38;5;34m128\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m, │ │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)] │ │ bidirectional[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ global_average_poo… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ attention_with_i… │\n",
"│ (\u001b[38;5;33mGlobalAveragePool…\u001b[0m │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m129\u001b[0m │ global_average_p… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,185</span> (457.75 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m117,185\u001b[0m (457.75 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Input, Conv1D, Bidirectional, GlobalAveragePooling1D, Dense, GRU, MaxPooling1D\n",
"from tensorflow.keras.models import Model\n",
"\n",
"class AttentionWithImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
"    \"\"\"Multi-head scaled dot-product attention with a position term on queries and keys.\n",
"\n",
"    The layer projects q, k and v with separate Dense layers, adds the output of\n",
"    ImproveRelativePositionEncoding to the projected q and k, splits the feature\n",
"    axis into num_heads heads, applies softmax attention and merges the heads\n",
"    through a final Dense layer.\n",
"\n",
"    :param d_model: width of the projections and of the output; must be a\n",
"        multiple of num_heads.\n",
"    :param num_heads: number of attention heads.\n",
"    :param max_len: stored on the layer; not used by any computation in this class.\n",
"    :param kwargs: standard Keras Layer arguments such as name or dtype.\n",
"    :raises ValueError: if num_heads is not positive or does not divide d_model.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, d_model, num_heads, max_len=5000, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        # split_heads reshapes the feature axis into (num_heads, d_model // num_heads);\n",
"        # reject sizes that do not divide evenly instead of failing inside the reshape.\n",
"        if num_heads <= 0 or d_model % num_heads != 0:\n",
"            raise ValueError(\n",
"                'd_model (%r) must be a positive multiple of num_heads (%r)' % (d_model, num_heads)\n",
"            )\n",
"        self.num_heads = num_heads\n",
"        self.d_model = d_model\n",
"        self.max_len = max_len\n",
"        self.wq = tf.keras.layers.Dense(d_model)\n",
"        self.wk = tf.keras.layers.Dense(d_model)\n",
"        self.wv = tf.keras.layers.Dense(d_model)\n",
"        self.dense = tf.keras.layers.Dense(d_model)\n",
"        self.position_encoding = ImproveRelativePositionEncoding(d_model)\n",
"\n",
"    def call(self, v, k, q, mask=None):\n",
"        \"\"\"Return (output, attention_weights) for the given value, key and query tensors.\"\"\"\n",
"        batch_size = tf.shape(q)[0]\n",
"        q = self.wq(q)\n",
"        k = self.wk(k)\n",
"        v = self.wv(v)\n",
"\n",
"        # Adding position encoding.\n",
"        # NOTE(review): position_encoding(x) already returns x plus the encoding, so each\n",
"        # tensor ends up as 2 * x + encoding. Kept unchanged to preserve the model's\n",
"        # behavior; confirm whether the doubling is intended.\n",
"        k = k + self.position_encoding(k)\n",
"        q = q + self.position_encoding(q)\n",
"\n",
"        q = self.split_heads(q, batch_size)\n",
"        k = self.split_heads(k, batch_size)\n",
"        v = self.split_heads(v, batch_size)\n",
"\n",
"        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)\n",
"        # Move the head axis back next to the feature axis, then merge the heads.\n",
"        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])\n",
"        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))\n",
"        output = self.dense(concat_attention)\n",
"        return output, attention_weights\n",
"\n",
"    def split_heads(self, x, batch_size):\n",
"        \"\"\"Reshape x to (batch, num_heads, seq, d_model // num_heads).\"\"\"\n",
"        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_model // self.num_heads))\n",
"        return tf.transpose(x, perm=[0, 2, 1, 3])\n",
"\n",
"    def scaled_dot_product_attention(self, q, k, v, mask):\n",
"        \"\"\"Softmax(q k^T / sqrt(depth)) applied to v; returns (output, weights).\"\"\"\n",
"        matmul_qk = tf.matmul(q, k, transpose_b=True)\n",
"        dk = tf.cast(tf.shape(k)[-1], tf.float32)\n",
"        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n",
"\n",
"        if mask is not None:\n",
"            # Positions where mask is 1 get a large negative logit, i.e. ~0 weight.\n",
"            scaled_attention_logits += (mask * -1e9)\n",
"\n",
"        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)\n",
"        output = tf.matmul(attention_weights, v)\n",
"        return output, attention_weights\n",
"\n",
"    def get_config(self):\n",
"        \"\"\"Return the constructor arguments so the layer can be re-created from its config.\"\"\"\n",
"        config = super().get_config()\n",
"        config.update({\n",
"            'd_model': self.d_model,\n",
"            'num_heads': self.num_heads,\n",
"            'max_len': self.max_len,\n",
"        })\n",
"        return config\n",
"\n",
"class ImproveRelativePositionEncoding(tf.keras.layers.Layer):\n",
"    \"\"\"Add a sinusoidal position encoding, scaled by two learnable vectors, to the input.\n",
"\n",
"    The encoding is computed from the position index 0..seq_length-1 of the input's\n",
"    second axis and is multiplied element-wise by the trainable vectors u and v\n",
"    before being added to the input.\n",
"\n",
"    :param d_model: size of the last input axis; must be a positive even number\n",
"        because sine and cosine channels are interleaved in pairs.\n",
"    :param max_len: stored on the layer; not used by any computation in this class.\n",
"    :param kwargs: standard Keras Layer arguments such as name or dtype.\n",
"    :raises ValueError: if d_model is not a positive even number.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, d_model, max_len=5000, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        # relative_positional_encoding stacks equally sized sine and cosine halves;\n",
"        # an odd width would fail there with a shape error.\n",
"        if d_model <= 0 or d_model % 2 != 0:\n",
"            raise ValueError('d_model (%r) must be a positive even number' % (d_model,))\n",
"        self.max_len = max_len\n",
"        self.d_model = d_model\n",
"        # Introduce learnable parameters u and v\n",
"        self.u = self.add_weight(shape=(self.d_model,), initializer=tf.keras.initializers.HeNormal(), trainable=True)\n",
"        self.v = self.add_weight(shape=(self.d_model,), initializer=tf.keras.initializers.HeNormal(), trainable=True)\n",
"\n",
"    def build(self, input_shape):\n",
"        super(ImproveRelativePositionEncoding, self).build(input_shape)\n",
"\n",
"    def call(self, inputs):\n",
"        \"\"\"Return inputs plus the parameter-scaled position encoding.\"\"\"\n",
"        seq_length = tf.shape(inputs)[1]\n",
"        pos_encoding = self.relative_positional_encoding(seq_length, self.d_model)\n",
"\n",
"        # Adjusting the encoding with the learnable parameters.\n",
"        # NOTE(review): both terms scale the same encoding, so this equals\n",
"        # pos_encoding * (u + v); kept as written to preserve the weights' meaning.\n",
"        pe_with_params = pos_encoding * self.u + pos_encoding * self.v\n",
"        return inputs + pe_with_params\n",
"\n",
"    def relative_positional_encoding(self, position, d_model):\n",
"        \"\"\"Build a (1, position, d_model) tensor of interleaved sine/cosine values.\"\"\"\n",
"        pos = tf.range(position, dtype=tf.float32)\n",
"        i = tf.range(d_model, dtype=tf.float32)\n",
"\n",
"        # One frequency per channel pair: 1 / 10000^(2 * (i // 2) / d_model).\n",
"        angles = 1 / tf.pow(10000.0, (2 * (i // 2)) / tf.cast(d_model, tf.float32))\n",
"        angle_rads = tf.einsum('i,j->ij', pos, angles)\n",
"\n",
"        # Sine on even channels, cosine on odd channels.\n",
"        angle_rads_sin = tf.sin(angle_rads[:, 0::2])\n",
"        angle_rads_cos = tf.cos(angle_rads[:, 1::2])\n",
"\n",
"        # Stacking on a new last axis and flattening it interleaves sin/cos per pair.\n",
"        pos_encoding = tf.stack([angle_rads_sin, angle_rads_cos], axis=2)\n",
"        pos_encoding = tf.reshape(pos_encoding, [1, position, d_model])\n",
"\n",
"        return pos_encoding\n",
"\n",
"    def get_config(self):\n",
"        \"\"\"Return the constructor arguments so the layer can be re-created from its config.\"\"\"\n",
"        config = super().get_config()\n",
"        config.update({'d_model': self.d_model, 'max_len': self.max_len})\n",
"        return config\n",
"\n",
"def PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads,\n",
"                                   conv_filters=64, kernel_size=2, pool_size=1):\n",
"    \"\"\"Build the Conv1D -> BiGRU -> self-attention -> pooled regression model.\n",
"\n",
"    :param input_shape: shape of one sample, (timesteps, features).\n",
"    :param gru_units: units per GRU direction; the attention width is 2 * gru_units.\n",
"    :param num_heads: number of attention heads.\n",
"    :param conv_filters: number of Conv1D filters (default 64, the previous fixed value).\n",
"    :param kernel_size: Conv1D kernel size (default 2, the previous fixed value).\n",
"    :param pool_size: MaxPooling1D window (default 1, the previous fixed value).\n",
"    :return: an uncompiled Keras Model with a single-unit Dense output.\n",
"    \"\"\"\n",
"    inputs = Input(shape=input_shape)\n",
"\n",
"    # CNN layer\n",
"    cnn_layer = Conv1D(filters=conv_filters, kernel_size=kernel_size, activation='relu')(inputs)\n",
"    cnn_layer = MaxPooling1D(pool_size=pool_size)(cnn_layer)\n",
"\n",
"    # Bidirectional GRU keeps the full sequence so attention can look at every step.\n",
"    gru_output = Bidirectional(GRU(gru_units, return_sequences=True))(cnn_layer)\n",
"\n",
"    # Apply Self-Attention: the same tensor is used as value, key and query.\n",
"    self_attention = AttentionWithImproveRelativePositionEncoding(d_model=gru_units * 2, num_heads=num_heads)\n",
"    gru_output, _ = self_attention(gru_output, gru_output, gru_output, mask=None)\n",
"\n",
"    # Average over the time axis, then regress a single value.\n",
"    pool1 = GlobalAveragePooling1D()(gru_output)\n",
"    output = Dense(1)(pool1)\n",
"\n",
"    return Model(inputs=inputs, outputs=output)\n",
"\n",
"# Fix the Python, NumPy and backend random seeds so weight initialisation is repeatable.\n",
"SEED = 42\n",
"tf.keras.utils.set_random_seed(SEED)\n",
"\n",
"# Take the input shape from the prepared training windows instead of hard-coding (96, 6).\n",
"input_shape = tuple(train_X.shape[1:])\n",
"gru_units = 64\n",
"num_heads = 8\n",
"\n",
"# Create model\n",
"model = PosConv1biGRUWithSelfAttention(input_shape, gru_units, num_heads)\n",
"model.compile(optimizer='adam', loss='mse')\n",
"model.summary()\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m114s\u001b[0m 86ms/step - loss: 0.0116 - val_loss: 0.0025\n",
"Epoch 2/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m109s\u001b[0m 84ms/step - loss: 0.0016 - val_loss: 0.0024\n",
"Epoch 3/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m125s\u001b[0m 96ms/step - loss: 0.0016 - val_loss: 0.0023\n",
"Epoch 4/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m114s\u001b[0m 87ms/step - loss: 0.0016 - val_loss: 0.0025\n",
"Epoch 5/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m103s\u001b[0m 79ms/step - loss: 0.0015 - val_loss: 0.0025\n",
"Epoch 6/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m111s\u001b[0m 85ms/step - loss: 0.0015 - val_loss: 0.0025\n",
"Epoch 7/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m109s\u001b[0m 84ms/step - loss: 0.0014 - val_loss: 0.0027\n",
"Epoch 8/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m108s\u001b[0m 83ms/step - loss: 0.0015 - val_loss: 0.0024\n",
"Epoch 9/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m142s\u001b[0m 109ms/step - loss: 0.0014 - val_loss: 0.0023\n",
"Epoch 10/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m182s\u001b[0m 140ms/step - loss: 0.0014 - val_loss: 0.0025\n",
"Epoch 11/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m143s\u001b[0m 110ms/step - loss: 0.0014 - val_loss: 0.0026\n",
"Epoch 12/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m116s\u001b[0m 89ms/step - loss: 0.0014 - val_loss: 0.0023\n",
"Epoch 13/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m116s\u001b[0m 89ms/step - loss: 0.0014 - val_loss: 0.0023\n",
"Epoch 14/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m135s\u001b[0m 104ms/step - loss: 0.0014 - val_loss: 0.0024\n",
"Epoch 15/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m112s\u001b[0m 86ms/step - loss: 0.0014 - val_loss: 0.0024\n",
"Epoch 16/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 81ms/step - loss: 0.0013 - val_loss: 0.0024\n",
"Epoch 17/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m109s\u001b[0m 84ms/step - loss: 0.0013 - val_loss: 0.0024\n",
"Epoch 18/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m125s\u001b[0m 96ms/step - loss: 0.0013 - val_loss: 0.0024\n",
"Epoch 19/100\n",
"\u001b[1m1302/1302\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m179s\u001b[0m 137ms/step - loss: 0.0013 - val_loss: 0.0025\n"
]
}
],
"source": [
"# Compile and train the model\n",
"model.compile(optimizer='adam', loss='mean_squared_error')\n",
"from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
"\n",
"# 定义早停机制\n",
"early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=0, mode='min')\n",
"\n",
"# 拟合模型,并添加早停机制和模型检查点\n",
"history = model.fit(train_X, train_y, epochs=100, batch_size=64, validation_data=(val_X, val_y), \n",
" callbacks=[early_stopping])\n",
"\n",
"# 将预测结果的形状修改为与原始数据相同的形状"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj0AAAGdCAYAAAD5ZcJyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAABaGUlEQVR4nO3deXzT5eEH8E/u9EyPQA8oUKECCgIWqEW8O4ui2Dnl0MkhA+dwyq9zKo5DJ7MTxIHIxAvBORTZEJ26jlJFNykol4gKApaWK4VeSZu2SZt8f388SdrQM6W52s/79cqryTdPvnm+DW0/PKdMkiQJRERERN2c3N8VICIiIvIFhh4iIiLqERh6iIiIqEdg6CEiIqIegaGHiIiIegSGHiIiIuoRGHqIiIioR2DoISIioh5B6e8KBBK73Y4zZ84gIiICMpnM39UhIiKiDpAkCVVVVUhMTIRc3np7DkNPE2fOnEFSUpK/q0FERESdcPLkSfTt27fV5xl6moiIiAAgvmmRkZF+rg0RERF1hMlkQlJSkuvveGsYeppwdmlFRkYy9BAREQWZ9oamcCAzERER9QgMPURERNQjMPQQERFRj8AxPURERD5gs9lQX1/v72oEJYVCAaVSedHLyTD0EBEReVl1dTVOnToFSZL8XZWgFRoaioSEBKjV6k6fg6GHiIjIi2w2G06dOoXQ0FD06tWLi996SJIkWK1WnD9/HoWFhUhJSWlzAcK2MPQQERF5UX19PSRJQq9evRASEuLv6gSlkJAQqFQqFBUVwWq1QqvVduo8HMhMRETkA2zhuTidbd1xO0cX1IOIiIgo4DH0EBERUY/A0ENEREReNWDAAKxcudLf1eBAZiIiImru+uuvx8iRI7skrHz99dcICwu7+EpdJIYeH8j/oQT/PVqKqy6JxYRh8f6uDhER0UWTJAk2mw1KZftRolevXj6oUfvYveUDX5+owPqdJ7DrpzJ/V4WIiPxMkiTUWBv8cuvo4ogzZ87E559/jlWrVkEmk0Emk2H9+vWQyWT497//jdTUVGg0Gvzvf//D8ePHcccddyAuLg7h4eEYM2YMtm/f7na+C7u3ZDIZXn/9dfz85z9HaGgoUlJS8OGHH3blt7lFbOnxAX24WD2y3Gz1c02IiMjfauttuGzxf/zy3t//MROh6vb/9K9atQo//vgjhg0bhj/+8Y8AgO+++w4A8MQTT+D555/HJZdcgujoaJw8eRK33nor/vSnP0Gj0eCtt97C7bffjiNHjqBfv36tvsfTTz+NZcuWYfny5Vi9ejXuvfdeFBUVISYmpmsutgVs6fGBWEfoKTNb/FwTIiKi9ul0OqjVaoSGhiI+Ph7x8fFQKBQAgD/+8Y/42c9+hoEDByImJgYjRozAAw88gGHDhiElJQXPPPMMBg4c2G7LzcyZMzFt2jQMGjQIzz77LKqrq/HVV1959brY0uMDMWEaAEBZNVt6iIh6uhCVAt//MdNv732xRo8e7fa4uroaTz31FD7++GOcPXsWDQ0NqK2tRXFxcZvnueKKK1z3w8LCEBkZiXPnzl10/drC0OMDsWHOlh6GHiKink4mk3WoiylQXTgL69FHH0VeXh6ef/55DBo0CCEhIbjrrrtgtbb9N0+lUrk9lslksNvtXV7fpoL3ux5E9OGipafcbIXdLkEu51LkREQU2NRqNWw2W7vlvvzyS8ycORM///nPAYiWnxMnTni5dp3DMT0+EB0m0qzNLsFYW+/n2hAREbVvwIAB2L17N06cOIHS0tJWW2FSUlKwZcsWHDhwAN988w3uuecer7fYdBZDjw9olApEaEWjGru4iIgoGDz66KNQKBS47LLL0KtXr1bH6LzwwguIjo7GuHHjcPvttyMzMxNXXnmlj2vbMTKpo5P2ewCTyQSdTgej0YjIyMguPfcNz+9AYakZm+ZehbRLYrv03EREFLjq6upQWFiI5ORkaLVaf1cnaLX1fezo32+29PhIDA
czExER+RVDj49wBhcREZF/MfT4SGy4c60eLlBIRETkDww9PuJq6eEChURERH7B0OMjsdx/i4iIyK8YenzE2b1Vyu4tIiIiv2Do8REOZCYiIvIvhh4fYfcWERGRfzH0+EisY6f1ihorGmyBuTw3ERFRd8bQ4yPRoWL/LUkCKmq4/xYREQW266+/HvPnz++y882cORNZWVlddr7OYOjxEaVC7go+7OIiIiLyPYYeH+IChUREFAxmzpyJzz//HKtWrYJMJoNMJsOJEydw6NAh3HLLLQgPD0dcXBzuu+8+lJaWul73j3/8A8OHD0dISAhiY2ORkZEBs9mMp556Chs2bMAHH3zgOt+OHTt8fl2dCj1r1qzBgAEDoNVqkZaWhq+++qrN8ps3b8aQIUOg1WoxfPhwfPLJJ27PS5KExYsXIyEhASEhIcjIyMDRo0dbPJfFYsHIkSMhk8lw4MABt+cOHjyIa665BlqtFklJSVi2bFlnLs9rnPtvlbKlh4io55IkwGr2z62De4yvWrUK6enpmDNnDs6ePYuzZ88iIiICN954I0aNGoU9e/YgNzcXJSUlmDx5MgDg7NmzmDZtGu6//3788MMP2LFjB+68805IkoRHH30UkydPxoQJE1znGzdunDe/yy1SevqCTZs2ITs7G2vXrkVaWhpWrlyJzMxMHDlyBL17925WfufOnZg2bRpycnJw2223YePGjcjKysK+ffswbNgwAMCyZcvw4osvYsOGDUhOTsaiRYuQmZmJ77//vtlOqo899hgSExPxzTffuB03mUy4+eabkZGRgbVr1+Lbb7/F/fffj6ioKMydO9fTy/QKvXMGF1t6iIh6rvoa4NlE/7z3k2cAdVi7xXQ6HdRqNUJDQxEfHw8AWLp0KUaNGoVnn33WVW7dunVISkrCjz/+iOrqajQ0NODOO+9E//79AQDDhw93lQ0JCYHFYnGdzx88bul54YUXMGfOHMyaNQuXXXYZ1q5di9DQUKxbt67F8qtWrcKECRPw+9//HkOHDsUzzzyDK6+8Ei+99BIA0cqzcuVKLFy4EHfccQeuuOIKvPXWWzhz5gy2bt3qdq5///vf2LZtG55//vlm7/P3v/8dVqsV69atw+WXX46pU6fi4YcfxgsvvODpJXqNcwYX1+ohIqJg88033+Czzz5DeHi46zZkyBAAwPHjxzFixAjcdNNNGD58OO6++2689tprqKio8HOt3XnU0mO1WrF3714sWLDAdUwulyMjIwMFBQUtvqagoADZ2dluxzIzM12BprCwEAaDARkZGa7ndTod0tLSUFBQgKlTpwIASkpKMGfOHGzduhWhoaEtvs+1114LtVrt9j7PPfccKioqEB0d3ew1FosFFktjq4vJZOrAd6HzXN1b3H+LiKjnUoWKFhd/vXcnVVdX4/bbb8dzzz3X7LmEhAQoFArk5eVh586d2LZtG1avXo0//OEP2L17N5KTky+m1l3Go5ae0tJS2Gw2xMXFuR2Pi4uDwWBo8TUGg6HN8s6vbZWRJAkzZ87Er3/9a4wePdqj92n6HhfKycmBTqdz3ZKSklos11Vc3Vtmdm8REfVYMpnoYvLHTSbrcDXVajVsNpvr8ZVXXonvvvsOAwYMwKBBg9xuYWFhjkuT4eqrr8bTTz+N/fv3Q61W4/3332/xfP4QFLO3Vq9ejaqqKrcWpq6wYMECGI1G1+3kyZNdev4LNc7eYksPEREFtgEDBmD37t04ceIESktLMW/ePJSXl2PatGn4+uuvcfz4cfznP//BrFmzYLPZsHv3bjz77LPYs2cPiouLsWXLFpw/fx5Dhw51ne/gwYM4cuQISktLUV/v+zXrPAo9er0eCoUCJSUlbsdLSkpaHZgUHx/fZnnn17bKfPrppygoKIBGo4FSqcSgQYMAAKNHj8aMGTPafJ+m73EhjUaDyMhIt5s3cf8tIiIKFo8++igUCgUuu+wy9OrVC1arFV9++SVsNhtuvvlmDB8+HPPnz0dUVB
TkcjkiIyPxxRdf4NZbb8Wll16KhQsXYsWKFbjlllsAAHPmzMHgwYMxevRo9OrVC19++aXPr8mjMT1qtRqpqanIz89
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(history.history['loss'], label='train')\n",
"plt.plot(history.history['val_loss'], label='test')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m326/326\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m9s\u001b[0m 25ms/step\n"
]
}
],
"source": [
"# 预测\n",
"lstm_pred = model.predict(test_X)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10415, 1)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lstm_pred.shape"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10415,)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"test_y1=test_y.reshape(10415,1)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.06037087],\n",
" [0.06032172],\n",
" [0.06027242],\n",
" ...,\n",
" [0.02350742],\n",
" [0.0233294 ],\n",
" [0.02315312]])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_y1"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"results1 = np.broadcast_to(lstm_pred, (10415, 6))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"test_y2 = np.broadcast_to(test_y1, (10415, 6))"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# 反归一化\n",
"inv_forecast_y = scaler.inverse_transform(results1)\n",
"inv_test_y = scaler.inverse_transform(test_y2)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-2.03686661, 9.49929284, 85.31799419, 40.07645259, 1.43682734,\n",
" -1.43294754],\n",
" [-2.03936077, 9.49443222, 85.2487593 , 40.04411781, 1.43565736,\n",
" -1.43325785],\n",
" [-2.04186187, 9.48955805, 85.17933142, 40.0116929 , 1.43448413,\n",
" -1.43356904],\n",
" ...,\n",
" [-3.90720611, 5.85436487, 33.39945635, 15.82893159, 0.5594767 ,\n",
" -1.66565038],\n",
" [-3.91623795, 5.83676359, 33.14874276, 15.71184079, 0.55523999,\n",
" -1.6667741 ],\n",
" [-3.92518186, 5.81933364, 32.90046971, 15.5958898 , 0.55104453,\n",
" -1.66788688]])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv_test_y"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test RMSE: 0.217\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABQoAAAKTCAYAAABRkzVdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd5xseV3n/9epXNXdVdU53Jzm3skMMwMMcQARMCysmFZZQFdBBAETOir6E9Rh1cVlXUV2DSAo4KKCCgIiEgYYBibfiXfm5s6pcq46vz++p6rD7Ru7q0511fv5eMyj4u3+3HnMVJ3zOZ9g2bZtIyIiIiIiIiIiIl3N43YAIiIiIiIiIiIi4j4lCkVERERERERERESJQhEREREREREREVGiUERERERERERERFCiUERERERERERERFCiUERERERERERERFCiUERERERERERERACf2wFcSK1WY2pqir6+PizLcjscERERERERERGRbcW2bdLpNBMTE3g8F64ZbOtE4dTUFLt27XI7DBERERERERERkW3tzJkz7Ny584LvaetEYV9fH2D+ItFo1OVoREREREREREREtpdUKsWuXbsaebYLaetEYb3dOBqNKlEoIiIiIiIiIiJyhS5lrJ+WmYiIiIiIiIiIiIgShSIiIiIiIiIiIqJEoYiIiIiIiIiIiNDmMwpFRERERERERKQ7VatVyuWy22FsC4FAAI9n8/WAShSKiIiIiIiIiEjbsG2bmZkZEomE26FsGx6Ph3379hEIBDb1c5QoFBERERERERGRtlFPEo6MjBCJRC5pW283q9VqTE1NMT09ze7duzf170uJQhERERERERERaQvVarWRJBwcHHQ7nG1jeHiYqakpKpUKfr//in+OlpmIiIiIiIiIiEhbqM8kjEQiLkeyvdRbjqvV6qZ+jhKFIiIiIiIiIiLSVtRufHm26t+XEoUiIiIiIiIiIiKiRKGIiIiIiIiIiIgoUSgiIiIiIiIiIiIoUSgiIiIiIiIiIrJpt99+O+94xzvcDmNTlCgUERERERERERFpMtu2qVQqbodxQUoUioiIiIiIiIhIW7Jtm1yp4so/tm1fcpxveMMb+MpXvsL73/9+LMvCsiw+9KEPYVkW//qv/8rNN99MMBjkrrvu4g1veAOvfvWr1/z5d7zjHdx+++2Nx7VajTvvvJN9+/YRDoe58cYb+eQnP7lF/1bPz9f03yAiIiIiIiIiInIF8uUq1/zm51353Y++++VEApeWOnv/+9/Pk08+yXXXXce73/1uAB555BEAfvVXf5U//MM/ZP/+/fT391/Sz7vzzjv56Ec/yp/92Z9x6NAhvvrVr/La176W4eFhXvSiF13ZX+gSKFEoIiIiIiIiIiKyCbFYjEAgQCQSYWxsDIDHH38cgHe/+9287GUvu+SfVSwW+b3f+z2++MUvcttttwGwf/9+7rrrLj74wQ8qUSgiIiIiIiIiIt0n7Pfy6Ltf7trv3gq33HLLZb3/qaeeIpfLnZNcLJVK3HTTTVsS0/koUSgiIiIiIiIiIm3JsqxLbv9tVz09PWseezyec+Yflsvlxv1MJgPAZz7zGXbs2LHmfcFgsElRGtv737SIiIiIiIiIiEgbCAQCVKvVi75veHiYo0ePrnnugQcewO/3A3DNNdcQDAY5ffp0U9uMN6JEoYiIiIiIiIiIyCbt3buXb33rW5w8eZLe3l5qtdqG73vJS17CH/zBH/DXf/3X3HbbbXz0ox/l6NGjjbbivr4+fumXfomf//mfp1ar8fznP59kMsnXv/51otEor3/965v2d/A07SeLiIiIiIiIiIh0iV/6pV/C6/VyzTXXMDw8zOnTpzd838tf/nLe9a538c53vpNbb72VdDrN6173ujXvec973sO73vUu7rzzTq6++mpe8YpX8JnPfIZ9+/Y19e9g2eubottIKpUiFouRTCaJRqNuhyMiIiIiIiIiIk1UKBQ4ceIE+/
btIxQKuR3OtnGhf2+Xk19TRaGIiIiIiIiIiIgoUSjSre49tczPfORezizl3A5FRERERERERNqAlpmIdKmf+vC3Wc6VmUzk+eefe77b4YiIiIiIiIiIy1RRKNKllnNlAB6eTLociYiIiIiIiIi0AyUKRbqU32s17hfKVRcjEREREREREZF2oEShSJcK+b2N+6oqFBERERERERElCkW6UKFcJV2oNB4/Pp1yMRoRERERERERaQdKFIp0ocVsac3jhDOvUERERERERES6lxKFIl1oPl1c83hZiUIRERERERGRrqdEoUgXWp8oTORK53mniIiIiIiIiHQLJQpFutA5icK8KgpFREREREREWqVUas+CHSUKRbpQPVE42BMAYFkVhSIiIiIiIiJX7Pbbb+etb30rb33rW4nFYgwNDfGud70L27YB2Lt3L+95z3t43eteRzQa5Y1vfCMAd911Fy94wQsIh8Ps2rWLt73tbWSzWdf+HkoUinShxaxJFB4Y6QUgqRmFIiIiIiIi0o5sG0pZd/5xknyX6sMf/jA+n4977rmH97///bzvfe/jz//8zxuv/+Ef/iE33ngj999/P+9617t4+umnecUrXsFrXvMaHnroIT7xiU9w11138da3vnWr/y1eMp9rv1lEXFPfcrxvsId7TiypolBERERERETaUzkHvzfhzu/+tSkI9Fzy23ft2sUf/dEfYVkWhw8f5uGHH+aP/uiP+Omf/mkAXvKSl/CLv/iLjff/1E/9FD/+4z/OO97xDgAOHTrE//pf/4sXvehFfOADHyAUCm3pX+dSqKJQpAslnZmEuwcjjce12uVdKRERERERERGRFc95znOwLKvx+LbbbuPYsWNUq1UAbrnlljXvf/DBB/nQhz5Eb29v45+Xv/zl1Go1Tpw40dLY61RRKNKF6onCPU6isGZDulAhFvG7GZaIiIiIiIjIWv6Iqexz63dvoZ6etdWJmUyGN73pTbztbW875727d+/e0t99qZQoFOlCKSdRONQbpCfgJVuqspwrKVEoIiIiIiIi7cWyLqv9103f+ta31jy+++67OXToEF6vd8P3P/OZz+TRRx/l4MGDrQjvkqj1WKQL1SsKY2E/8Yg2H4uIiIiIiDvShTIPnU2QLVbcDkVk006fPs0v/MIv8MQTT/Cxj32MP/7jP+btb3/7ed//K7/yK3zjG9/grW99Kw888ADHjh3j05/+tJaZiEjr2LbdSBTGI37iET+TiTyJvDYfi4iIiIhIa73uL+/h/tMJhnoDfPWdLyYSUJpCtq/Xve515PN5nvWsZ+H1enn729/OG9/4xvO+/4YbbuArX/kKv/7rv84LXvACbNvmwIED/MiP/EgLo15L/weKdJlcqUrFWVwSC/vpdyoKE6ooFBERERGRFsoUK9x/OgHAQqbEmaU8h8f63A1KZBP8fj//83/+Tz7wgQ+c89rJkyc3/DO33norX/jCF5oc2aVT67FIl6lXE/q9FmG/tzGXMJFTRaGIiIiIiLTOEzOpNY9TBZ2TiLhNiUKRLpAvVXnnJx/kK0/Or5lPaFkW/U6icFmJQhERERERaaFHp9NrHid1TiLiOrUei3SocrXGOz7+APuGeggHvPzdd87yd98523g9GjYJwnhYrcciIiKyNWzbplqz8XlVjyAiF/folCoKpXN8+ctfdjuELaFEoUiH+sbTi3zm4WkAfviWnee8HqsnCtV6LCIiIlvg7+89y7s+fRSAf/zZ52nOmIhc0NefWuBj95xe81xSCxZFXKdLfSId6um5TOP+XccWznnd2WfSWGayrIpCERER2YRPfOcMuVKVXKnKvz0643Y4ItLm/uDzTwDQH/HzfTeMA5DKV9wMSdqMbdtuh7CtbNW/LyUKRTrU0alk4/5UsnDO6w+fTQArFYW6eiciIiKXKl0o8xufephvn1wCzMnJk7Mrs8YeWddOKCKy3tnlPAB//vpb2NkfAXROIobfb85Rc7mcy5FsL6WSKf7xer
2b+jlqPRbpUI9MXvgA/XtvmAAgropCERERuUzv+7cn+ejdp/no3ac5+d7vZS5dXDPGRIlCEbmQYqXKQqYIwP6hXqJ
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 计算均方根误差\n",
"rmse = sqrt(mean_squared_error(inv_test_y[:,5], inv_forecast_y[:,5]))\n",
"print('Test RMSE: %.3f' % rmse)\n",
"#画图\n",
"plt.figure(figsize=(16,8))\n",
"plt.plot(inv_test_y[900:2100,5], label='true')\n",
"plt.plot(inv_forecast_y[900:2100,5], label='pre')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean_squared_error: 0.0011780920849826654\n",
"mean_absolute_error: 0.013530156512489254\n",
"rmse: 0.03432334606332351\n",
"r2 score: 0.9966738024269023\n"
]
}
],
"source": [
"from sklearn.metrics import mean_squared_error, mean_absolute_error # 评价指标\n",
"# 使用sklearn调用衡量线性回归的MSE 、 RMSE、 MAE、r2\n",
"from math import sqrt\n",
"from sklearn.metrics import mean_absolute_error\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.metrics import r2_score\n",
"print('mean_squared_error:', mean_squared_error(lstm_pred, test_y)) # mse)\n",
"print(\"mean_absolute_error:\", mean_absolute_error(lstm_pred, test_y)) # mae\n",
"print(\"rmse:\", sqrt(mean_squared_error(lstm_pred,test_y)))\n",
"print(\"r2 score:\", r2_score(inv_test_y[900:2100], inv_forecast_y[900:2100]))"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"df1 = pd.DataFrame(inv_test_y[:,5], columns=['column_name'])"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# 指定文件路径和文件名保存DataFrame到CSV文件中\n",
"df1.to_csv('xin99939高频re_test(t+3).csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"df2 = pd.DataFrame(inv_forecast_y[:,5], columns=['column_name'])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# 指定文件路径和文件名保存DataFrame到CSV文件中\n",
"df2.to_csv('xin99939高频re_forecast(t+3).csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}