651 KiB
651 KiB
将原始数据的时区进行修改并保存¶
In [18]:
""" 修改时区的数据集,可以作为测试集使用 """
Out[18]:
'\n修改时区的数据集,可以作为测试集使用\n'
In [4]:
import numpy as np
import pandas as pd

# Read the raw station CSV and, in the same chain, replace `date_time` with
# its parsed datetime value moved forward by eight hours.
data = pd.read_csv('./datasets/station08.csv').assign(
    date_time=lambda frame: pd.to_datetime(frame['date_time']) + pd.Timedelta(hours=8)
)
In [17]:
# Write the time-shifted frame to disk without the index column.
# NOTE(review): the file name says "utf8" although the change applied to `data`
# is a +8 h time shift (possibly "utc8" was meant) — confirm. The name is kept
# as-is because a later cell reads this exact path.
data.to_csv('./datasets/station08_utf8.csv',index=False)
验证预测结果与真实结果的差别¶
In [65]:
""" 找到测试集中,排名前十的好结果 """
Out[65]:
'\n找到测试集中,排名前十的好结果\n\n'
In [28]:
import os

# Directory holding the saved Crossformer outputs (pred.npy / true.npy).
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# not run elsewhere without editing it.
data_path = (
    '/home/xiazj/project_test/Crossformer-master/results/'
    'Crossformer_station08_il192_ol96_sl6_win2_fa10_dm256_nh4_el3_itr0'
)
In [29]:
# Load the model predictions saved as pred.npy in the results directory.
pred_data = np.load(os.path.join(data_path,'pred.npy'))
In [30]:
# Load the matching ground-truth array saved as true.npy in the same directory.
true_data = np.load(os.path.join(data_path,'true.npy'))
In [57]:
# Per-window MSE computed on the last feature column only (index -1 on the
# feature axis; presumably `power`, the last CSV column — confirm).
# Negative predictions are clipped to 0 before scoring.
# Vectorized over all windows: no Python loop and no hard-coded horizon of 96,
# so the cell keeps working if the forecast length changes.
pred_last = np.clip(pred_data[:, :, -1], 0, None)
true_last = true_data[:, :, -1]

# Mean over the time axis yields one score per test window. Kept as a list so
# the variable has the same type as before.
mse_list = list(np.mean((pred_last - true_last) ** 2, axis=1))
In [75]:
# Indices of the windows with the SMALLEST MSE, i.e. the best predictions.
# (The previous comments were swapped: partitioning at position 10 and taking
# the first 10 entries selects the ten lowest values, not the largest.)
# The returned indices are not sorted by MSE.
N_BEST = 10
if len(mse_list) > N_BEST:
    top_10_indices = np.argpartition(mse_list, N_BEST)[:N_BEST]
else:
    # argpartition raises when kth is out of bounds; with this few windows
    # every window belongs to the selection.
    top_10_indices = np.arange(len(mse_list))

# For the LARGEST MSE (worst predictions) use instead:
# top_10_indices = np.argpartition(mse_list, -N_BEST)[-N_BEST:]
In [104]:
# Show the dimensions of the ground-truth array.
true_data.shape
Out[104]:
(6529, 96, 14)
In [76]:
# Display the selected window indices.
top_10_indices
Out[76]:
array([4717, 4703, 4735, 4702, 4701, 4740, 4736, 4712, 4711, 4716])
In [77]:
import matplotlib.pyplot as plt
In [78]:
# Plot predicted vs. true curves of the last feature for every selected window.
# Iterating the selected indices directly (in ascending order, matching the
# previous print order) avoids scanning every window and testing membership.
for i in map(int, np.sort(top_10_indices)):
    print(i)

    # Last feature column of this window; negative predictions are clipped to 0.
    pred_data_ = np.clip(pred_data[i, :, -1], 0, None)
    true_data_ = true_data[i, :, -1]

    # 1-based time-step axis, sized from the data instead of a hard-coded 96.
    x = np.arange(1, len(pred_data_) + 1)

    plt.plot(x, pred_data_, label='Predicted Feature ', linestyle='--')
    plt.plot(x, true_data_, label='True Feature ')

    plt.legend()
    plt.title('Comparison of Predicted and True Values for Each Feature')
    plt.xlabel('Time Steps')
    plt.ylabel('Values')
    # Thin horizontal reference line at y = 0.
    plt.axhline(0, color='black', linewidth=0.5, ls='--')
    plt.grid()
    plt.show()
4701
4702
4703
4711
4712
4716
4717
4735
4736
4740
推理结果可视化¶
In [72]:
# Hard-coded sample of one inference run: observed values and the model's raw
# predictions, each given as a list of single-element rows.
real = [
    [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
    [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
    [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
    [0.1170799999999999], [0.226425], [0.556652], [1.090384], [1.735596],
    [2.418046], [2.505163], [3.456932], [1.469682], [4.7706230000000005],
    [5.965847], [6.104462], [7.312792], [6.203876], [4.957963],
    [3.701485], [4.493279], [5.98692], [6.851398], [6.898161],
    [11.08569], [10.24586], [9.231687], [12.13076], [11.30536],
    [12.62044], [13.02087], [11.04955], [13.152729999999998], [8.404561],
    [11.72761], [9.320998], [4.060809], [3.344008], [4.563134],
    [5.036594], [9.309683], [7.437724], [5.168396], [4.341616999999999],
    [3.589081], [1.63901], [1.266235], [1.411258], [1.906713],
    [2.103406], [1.811167], [1.492544], [1.77341], [1.549872],
    [1.45219], [0.498169], [0.423753], [0.120082], [0.2112409999999999],
    [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
    [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
    [0.0], [0.0],
]
Predictions = [
    [-2.86124706e-01], [-1.85412407e-01], [-1.23522758e-01], [-2.36801863e-01],
    [-1.44170523e-01], [-1.88258648e-01], [-2.88140297e-01], [-3.10372591e-01],
    [-3.25312853e-01], [-3.01664591e-01], [-3.51567268e-01], [-3.99164915e-01],
    [-1.98129177e-01], [-2.40688801e-01], [-1.31866693e-01], [-1.35723829e-01],
    [-1.25133514e-01], [-5.73484898e-02], [-4.16159391e-01], [-3.20464611e-01],
    [-1.97598696e-01], [3.75599861e-02], [1.71803951e-01], [2.54544497e-01],
    [-8.46397877e-02], [2.03914165e-01], [4.66168642e-01], [9.69486117e-01],
    [1.44746566e+00], [1.93339133e+00], [3.05528426e+00], [3.98479152e+00],
    [4.61557865e+00], [5.74502182e+00], [6.33362198e+00], [7.41310310e+00],
    [7.41493511e+00], [8.36355972e+00], [8.93667889e+00], [9.44896412e+00],
    [1.04542589e+01], [1.08695698e+01], [1.06164064e+01], [1.12903938e+01],
    [1.15168724e+01], [1.20199852e+01], [1.21539640e+01], [1.24546432e+01],
    [1.18046064e+01], [1.19759865e+01], [1.18456869e+01], [1.18761501e+01],
    [1.18272791e+01], [1.17170534e+01], [1.13529377e+01], [1.10521955e+01],
    [1.06786232e+01], [1.05224686e+01], [9.81320190e+00], [9.53697395e+00],
    [8.66638947e+00], [7.97475529e+00], [7.41621113e+00], [6.65187836e+00],
    [5.92018986e+00], [5.26671886e+00], [4.15544462e+00], [3.40540767e+00],
    [2.59274578e+00], [2.10804391e+00], [1.37592804e+00], [1.06501186e+00],
    [1.13409483e+00], [6.22718215e-01], [1.52500153e-01], [-7.96604156e-03],
    [-3.04882288e-01], [-6.02218390e-01], [1.50711298e-01], [-1.30321026e-01],
    [-3.54492903e-01], [-4.41823721e-01], [-3.99953127e-01], [-5.66465855e-01],
    [-2.37069845e-01], [-3.53769541e-01], [-3.37545872e-01], [-4.05787230e-01],
    [-3.33610773e-01], [-3.95320892e-01], [-4.33113575e-01], [-3.90184402e-01],
    [-3.61282110e-01], [-2.99006939e-01], [-2.87155390e-01], [-2.05773354e-01],
]

# Flatten both series to 1-D arrays; predictions below zero are raised to 0.
real = np.array(real).flatten()
Predictions = np.clip(np.array(Predictions).flatten(), 0, None)

# Positional index used as the x axis.
x = np.arange(len(real))

# Line chart of both series on a single 12x6 figure.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(x, real, label='Real', marker='o', linestyle='-', color='blue')
ax.plot(x, Predictions, label='Predictions', marker='x', linestyle='--', color='red')

ax.set_title('Real vs Predictions', fontsize=16)
ax.set_xlabel('Index', fontsize=12)
ax.set_ylabel('Value', fontsize=12)
ax.legend()
ax.grid(True)

plt.show()
定位好的测试集的开始时间¶
In [84]:
# Take the ground-truth window at index 4717, one of the indices listed in the
# `top_10_indices` output earlier in the notebook.
good_data = true_data[4717]
In [85]:
# First time step of window 4717: one vector of feature values, used further
# down to find the matching row in the CSV.
good_data_first = true_data[4717][0]
In [86]:
# Reload the time-shifted CSV written earlier in this notebook.
# No date parsing is requested here, so `date_time` is read with pandas' default type inference.
ori_data = pd.read_csv('./datasets/station08_utf8.csv')
In [97]:
# List the CSV column names.
ori_data.columns
Out[97]:
Index(['date_time', 'nwp_globalirrad', 'nwp_directirrad', 'nwp_temperature', 'nwp_humidity', 'nwp_windspeed', 'nwp_winddirection', 'nwp_pressure', 'lmd_totalirrad', 'lmd_diffuseirrad', 'lmd_temperature', 'lmd_pressure', 'lmd_winddirection', 'lmd_windspeed', 'power'], dtype='object')
In [93]:
# Display the feature vector of the window's first time step.
good_data_first
Out[93]:
array([ 0.0000000e+00, -7.6293945e-06, 2.0690001e+01, 5.8049999e+01, 5.9600000e+00, 2.8688000e+02, 9.4463000e+02, -1.5258789e-05, 0.0000000e+00, 2.2600000e+01, 9.4320001e+02, 2.1600000e+02, 1.1000000e+00, 0.0000000e+00], dtype=float32)
In [109]:
# Find the CSV row(s) whose feature values match the first time step of the
# chosen window.
#
# `good_data_first` is float32 while the CSV columns are parsed as float64/int,
# so exact `==` fails for values such as 22.6 or 943.2, and for tiny float32
# residues like -7.6e-06 where the CSV holds 0. Every feature column is
# therefore compared with a tolerance instead of hand-picking the few columns
# that happen to compare equal and hard-coding their values.
#
# Assumes the feature axis of true.npy follows the CSV column order with
# `date_time` removed — consistent with the displayed values, but confirm.
feature_cols = [col for col in ori_data.columns if col != 'date_time']
assert len(feature_cols) == len(good_data_first), \
    'feature count differs between the CSV and the ground-truth array'

row_matches = np.isclose(
    ori_data[feature_cols].to_numpy(dtype=float),
    np.asarray(good_data_first, dtype=float),
    rtol=1e-5,
    atol=1e-3,  # absorbs float32 round-off around zero
).all(axis=1)

ori_data[row_matches]
Out[109]:
date_time | nwp_globalirrad | nwp_directirrad | nwp_temperature | nwp_humidity | nwp_windspeed | nwp_winddirection | nwp_pressure | lmd_totalirrad | lmd_diffuseirrad | lmd_temperature | lmd_pressure | lmd_winddirection | lmd_windspeed | power | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
31213 | 2019-05-25 03:15:00 | 0.0 | 0.0 | 20.69 | 58.05 | 5.96 | 286.88 | 944.63 | 0 | 0 | 22.6 | 943.200012 | 216 | 1.1 | 0.0 |
In [129]:
# Cut a window around row label 31213 (the row found by the lookup above):
# the 192 rows before it plus rows 31213 through 31213 + 95.
# `.loc` slicing includes BOTH end labels, so this yields 192 + 96 = 288 rows.
# NOTE(review): 192 / 96 presumably mirror the il192 / ol96 in the results
# directory name (input / output lengths) — confirm.
data_1 = ori_data.loc[31213 - 192 : 31213 + 95]
In [130]:
# Check the size of the extracted window.
data_1.shape
Out[130]:
(288, 15)
In [131]:
# Plot the real `power` series of the extracted window.
real = data_1['power']

# Positional x axis sized from the data (previously hard-coded to 288, which
# breaks as soon as the window length changes).
x = np.arange(len(real))

plt.figure(figsize=(12, 6))
plt.plot(x, real, label='Real', marker='o', linestyle='-', color='blue')

# Only the real series is drawn here, so the title no longer claims a
# comparison with predictions.
plt.title('Real Power', fontsize=16)
plt.xlabel('Index', fontsize=12)
plt.ylabel('Value', fontsize=12)

plt.legend()
plt.grid(True)
plt.show()
In [133]:
# Save the extracted window (all columns, no index column) to run_test.csv.
data_1.to_csv('./datasets/run_test.csv',index=False)
In [ ]: