111 lines
4.9 KiB
Python
111 lines
4.9 KiB
Python
from guangfufadian.cross_exp.exp_crossformer import Exp_crossformer # 根据实际路径导入
|
||
from guangfufadian.utils.tools import StandardScaler
|
||
import pickle,os,sys
|
||
import torch
|
||
import pandas as pd
|
||
import torch
|
||
import numpy as np
|
||
# Resolve the package directory from the process's current working directory.
# NOTE(review): this only points at the right place when the process is started
# from the parent directory of the `guangfufadian` package — confirm the launch setup.
current_directory = os.path.join(os.getcwd(), 'guangfufadian')
|
||
|
||
# Model / runtime parameters
class guangfufadian_Args:
    """Hyperparameters and file locations for the station08 Crossformer model.

    Used as a plain attribute bag; a shared instance is created below as
    ``guangfufadian_args``. Defining this class also loads the training-time
    scale statistics from the checkpoint directory, so importing this module
    requires that file to exist.
    """

    # Number of input features. It matches the 14 non-``date_time`` columns
    # of the upload template checked in ModelInference.load_pred_data.
    data_dim = 14
    in_len = 192  # input window length (rows fed to the model)
    out_len = 96  # output (forecast) length — presumably time steps

    # Crossformer architecture hyperparameters, presumably consumed by Exp_crossformer.
    seg_len = 6
    win_size = 2
    factor = 10
    d_model = 256
    d_ff = 512
    n_heads = 4
    e_layers = 3
    dropout = 0.2

    # Device selection
    use_multi_gpu = False
    use_gpu = True
    device_ids = [0]  # device IDs to use when running on multiple GPUs
    gpu = 0

    # Training settings
    batch_size = 32
    train_epochs = 100
    patience = 10
    num_workers = 4
    learning_rate = 0.001  # learning rate
    data_split = 0.8  # data split ratio
    baseline = False  # NOTE(review): flag read by the experiment code — confirm its meaning

    # Paths
    data_path = os.path.join(current_directory, 'datasets/station08_utf8.csv')  # full path to the data CSV
    checkpoints = os.path.join(current_directory, 'checkpoints')  # checkpoint root directory

    # The checkpoint sub-directory name encodes the architecture hyperparameters
    # above. Build it from them so the two cannot drift apart.
    _setting = (f'Crossformer_station08_il{in_len}_ol{out_len}_sl{seg_len}'
                f'_win{win_size}_fa{factor}_dm{d_model}_nh{n_heads}_el{e_layers}_itr0')
    # Training-time normalization statistics, indexed by 'mean' and 'std' in
    # ModelInference. The `with` block closes the file handle, which a bare
    # open() would leak. pickle can execute code on load, so only point this
    # at trusted checkpoints.
    with open(os.path.join(checkpoints, _setting, 'scale_statistic.pkl'), 'rb') as _fh:
        scale_statistic = pickle.load(_fh)
    # Drop the temporaries so they do not become class attributes.
    del _setting, _fh


# Shared parameter instance used by the rest of the module.
guangfufadian_args = guangfufadian_Args()
|
||
|
||
|
||
class ModelInference:
    """Run Crossformer inference on one uploaded window of station data.

    Wraps an ``Exp_crossformer`` experiment and loads trained weights from
    disk. Input is normalized with the training-time scale statistics, and
    the de-normalized last output column is returned.
    """

    # Exact column layout (names and order) an uploaded DataFrame must have.
    # The first column is a timestamp and is dropped before inference.
    EXPECTED_COLUMNS = (
        'date_time', 'nwp_globalirrad', 'nwp_directirrad', 'nwp_temperature',
        'nwp_humidity', 'nwp_windspeed', 'nwp_winddirection', 'nwp_pressure',
        'lmd_totalirrad', 'lmd_diffuseirrad', 'lmd_temperature', 'lmd_pressure',
        'lmd_winddirection', 'lmd_windspeed', 'power',
    )

    def __init__(self, model_path, args):
        """Build the model, load its weights and set up the scaler.

        Args:
            model_path: path to a state_dict saved with ``torch.save``.
            args: parameter object (e.g. ``guangfufadian_Args``). It must
                provide the model hyperparameters, ``in_len`` and
                ``scale_statistic`` (a mapping with 'mean' and 'std').
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Take the window length from the same args the model is built with,
        # so the sliced input always matches what the model expects.
        self.in_len = args.in_len
        self.model_experiment = Exp_crossformer(args)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model_experiment.model.load_state_dict(state_dict)
        self.model_experiment.model.eval()  # inference mode (dropout etc. disabled)
        self.model_experiment.model.to(self.device)
        self.scaler = StandardScaler(mean=args.scale_statistic['mean'], std=args.scale_statistic['std'])

    def load_pred_data(self, data):
        """Validate an uploaded DataFrame and convert it into a model input tensor.

        The columns must equal ``EXPECTED_COLUMNS`` exactly; otherwise the
        upload is rejected. The leading ``date_time`` column is dropped.

        - If there are at least ``in_len`` rows, the first ``in_len`` rows
          are used.
        - If there are fewer, the first row is repeated at the front until
          the window has ``in_len`` rows.

        Args:
            data: pandas DataFrame with the uploaded rows.

        Returns:
            dict: one of

            - ``{'status': True, 'reason': tensor}``, where ``tensor`` is
              float32 with shape ``(1, in_len, n_columns - 1)``.
            - ``{'status': False, 'reason': message}`` when the columns do
              not match or the DataFrame has no rows.
        """
        if tuple(data.columns) != self.EXPECTED_COLUMNS:
            print("文件不匹配,请检查上传文件与模版是否一致")
            return {'status': False, 'reason': '文件不匹配,请检查上传文件与模版是否一致'}

        # Drop the timestamp column; the remaining columns are model features.
        features = data.values[:, 1:]
        length = features.shape[0]
        if length == 0:
            # There is no first row to pad from.
            return {'status': False, 'reason': '数据为空,请检查上传文件'}
        if length < self.in_len:
            # Left-pad by repeating the first row until the window is full.
            padding = np.repeat(features[:1], self.in_len - length, axis=0)
            features = np.concatenate([padding, features], axis=0)
        # NOTE(review): keeps the *first* in_len rows, as originally documented.
        # Forecasting usually wants the most recent rows — confirm the intent.
        window = features[:self.in_len].astype(np.float32)
        input_tensor = torch.tensor(window, dtype=torch.float32).unsqueeze(0)  # add batch dimension
        return {'status': True, 'reason': input_tensor}

    def preprocess_data(self, input_tensor):
        """Normalize ``input_tensor`` with the training scaler and move it to the model device."""
        normalized = self.scaler.transform(input_tensor)
        return normalized.to(self.device)

    def predict(self, input_tensor):
        """Run the model without gradient tracking.

        Returns whatever ``Exp_crossformer._predict_batch`` returns, which is
        presumably a normalized prediction tensor.
        """
        with torch.no_grad():
            return self.model_experiment._predict_batch(input_tensor)

    def run_inference(self, data):
        """Validate, normalize, predict and de-normalize one uploaded window.

        Returns:
            dict: on failure, the ``{'status': False, 'reason': message}``
            produced by ``load_pred_data``. On success,
            ``{'status': True, 'reason': ndarray}`` holding the last column
            of the de-normalized predictions. This assumes predictions come
            back as ``(1, out_len, n_features)``, with the last feature being
            'power' — confirm.
        """
        load_result = self.load_pred_data(data)
        if not load_result['status']:
            return load_result

        input_tensor = self.preprocess_data(load_result['reason'])
        predictions = self.predict(input_tensor)
        predictions = self.scaler.inverse_transform(predictions)
        predictions = predictions.squeeze(0).cpu().numpy()  # drop batch dim, move to host
        return {'status': True, 'reason': predictions[:, -1:]}