ai-station-code/guangfufadian/model_base.py

111 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from guangfufadian.cross_exp.exp_crossformer import Exp_crossformer # 根据实际路径导入
from guangfufadian.utils.tools import StandardScaler
import pickle,os,sys
import torch
import pandas as pd
import torch
import numpy as np
# Directory of this package (guangfufadian/). Resolving it from __file__ rather
# than os.getcwd() keeps the dataset and checkpoint paths below valid no matter
# which working directory the process was started from.
current_directory = os.path.dirname(os.path.abspath(__file__))
# 设置参数
class guangfufadian_Args:
    """Configuration for the Crossformer model used by ``ModelInference``.

    Every setting is a class attribute, so it can be read from the class or
    from any instance. ``scale_statistic`` is loaded from disk while this class
    body executes, i.e. at module import time.
    """

    # Model / data-shape hyperparameters (passed to Exp_crossformer).
    data_dim = 14  # numeric input columns: every template column except 'date_time'
    in_len = 192  # rows per input window fed to the model
    out_len = 96  # presumably the number of forecast steps -- confirm in Exp_crossformer
    seg_len = 6
    win_size = 2
    factor = 10
    d_model = 256
    d_ff = 512
    n_heads = 4
    e_layers = 3
    dropout = 0.2
    # Device selection.
    use_multi_gpu = False
    use_gpu = True
    device_ids = [0]  # GPU device ids, relevant when use_multi_gpu is enabled
    # Training settings (presumably unused at inference time -- confirm).
    batch_size = 32
    train_epochs = 100
    patience = 10
    num_workers = 4
    data_path = os.path.join(current_directory, 'datasets/station08_utf8.csv')  # dataset CSV file
    checkpoints = os.path.join(current_directory, 'checkpoints')  # checkpoint root directory
    learning_rate = 0.001  # learning rate
    data_split = 0.8  # data split ratio
    # Normalization statistics saved at training time; ModelInference reads
    # its 'mean' and 'std' entries. `with` guarantees the file handle is closed.
    # NOTE: pickle can execute arbitrary code on load -- this path must only
    # ever point at a trusted, locally produced file.
    with open(os.path.join(
            checkpoints,
            'Crossformer_station08_il192_ol96_sl6_win2_fa10_dm256_nh4_el3_itr0/scale_statistic.pkl'),
            'rb') as _scale_file:
        scale_statistic = pickle.load(_scale_file)
    del _scale_file  # don't keep the closed file object around as a class attribute
    baseline = False
    gpu = 0


# Module-level configuration instance, importable as ``guangfufadian_args``.
guangfufadian_args = guangfufadian_Args()
class ModelInference:
    """Run Crossformer inference on an uploaded DataFrame.

    ``load_pred_data`` and ``run_inference`` report input-validation problems
    through a ``{'status': bool, 'reason': ...}`` dict rather than raising.
    """

    # Exact column names and order an upload must have.
    _EXPECTED_COLUMNS = ['date_time', 'nwp_globalirrad', 'nwp_directirrad', 'nwp_temperature',
                         'nwp_humidity', 'nwp_windspeed', 'nwp_winddirection', 'nwp_pressure',
                         'lmd_totalirrad', 'lmd_diffuseirrad', 'lmd_temperature', 'lmd_pressure',
                         'lmd_winddirection', 'lmd_windspeed', 'power']
    # User-facing messages returned in the 'reason' field on failure.
    _MISMATCH_MSG = '文件不匹配,请检查上传文件与模版是否一致'
    _EMPTY_MSG = '上传文件没有数据,请检查上传文件'
    _NON_NUMERIC_MSG = '上传文件包含非数值数据,请检查上传文件'

    def __init__(self, model_path, args):
        """Build the Crossformer experiment and load its trained weights.

        Args:
            model_path: path to a ``state_dict`` checkpoint readable by ``torch.load``.
            args: configuration object such as ``guangfufadian_args``; ``in_len``
                and ``scale_statistic`` ('mean'/'std') are read from it.
        """
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_experiment = Exp_crossformer(args)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model_experiment.model.load_state_dict(state_dict)
        self.model_experiment.model.eval()  # inference mode (disables dropout etc.)
        self.model_experiment.model.to(self.device)
        self.scaler = StandardScaler(mean=args.scale_statistic['mean'], std=args.scale_statistic['std'])

    def load_pred_data(self, data):
        """Validate an uploaded DataFrame and convert it into a model input tensor.

        The columns must equal ``_EXPECTED_COLUMNS`` exactly (names and order).
        The leading 'date_time' column is dropped and the rest are cast to
        float32. With ``in_len = self.args.in_len``:

        * more than ``in_len`` rows: only the first ``in_len`` rows are kept;
        * fewer rows: the first row is repeated in front until there are
          ``in_len`` rows.

        Args:
            data: pandas DataFrame uploaded by the user.

        Returns:
            ``{'status': True, 'reason': tensor}`` with a float32 tensor of shape
            ``(1, in_len, n_features)`` on success, otherwise
            ``{'status': False, 'reason': <message>}``.
        """
        in_len = self.args.in_len
        if list(data.columns) != self._EXPECTED_COLUMNS:
            print(self._MISMATCH_MSG)
            return {'status': False, 'reason': self._MISMATCH_MSG}
        raw_data = data.values
        if raw_data.shape[0] == 0:
            # Nothing to pad from: the original code crashed on raw_data[0].
            return {'status': False, 'reason': self._EMPTY_MSG}
        try:
            # Drop 'date_time'; the model only consumes the numeric features.
            features = raw_data[:in_len, 1:].astype(np.float32)
        except (TypeError, ValueError):
            return {'status': False, 'reason': self._NON_NUMERIC_MSG}
        rows_to_add = in_len - features.shape[0]
        if rows_to_add > 0:
            # Too short: left-pad by repeating the first row to fill the window.
            padding = np.repeat(features[:1], rows_to_add, axis=0)
            features = np.concatenate([padding, features], axis=0)
        input_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # add batch dim
        return {'status': True, 'reason': input_tensor}

    def preprocess_data(self, input_tensor):
        """Normalize with the training-time mean/std and move to the inference device."""
        input_tensor = self.scaler.transform(input_tensor)
        return input_tensor.to(self.device)

    def predict(self, input_tensor):
        """Run the experiment's batch prediction without tracking gradients.

        Returns the raw (still normalized) predictions; ``run_inference``
        applies the inverse transform. ``_predict_batch`` is a private helper
        of Exp_crossformer -- its output device/shape is assumed, not shown here.
        """
        with torch.no_grad():
            return self.model_experiment._predict_batch(input_tensor)

    def run_inference(self, data):
        """Validate ``data``, predict, and return the last output column.

        Returns:
            ``{'status': True, 'reason': ndarray}`` whose array has shape
            ``(n_steps, 1)``, or ``{'status': False, 'reason': <message>}`` with the
            message produced by ``load_pred_data``.
        """
        loaded = self.load_pred_data(data)
        if not loaded['status']:
            # Propagate the specific validation failure instead of a fixed message.
            return {'status': False, 'reason': loaded['reason']}
        input_tensor = self.preprocess_data(loaded['reason'])
        predictions = self.predict(input_tensor)
        predictions = self.scaler.inverse_transform(predictions)
        predictions = predictions.squeeze(0).cpu().numpy()  # drop batch dim
        # Last column -- presumably 'power', the last input column; confirm that
        # the model's output columns follow the input column order.
        return {'status': True, 'reason': predictions[:, -1:]}