ai-station-code/guangfufadian/cross_models/cross_former.py

81 lines
3.7 KiB
Python
Raw Normal View History

2025-05-06 11:18:48 +08:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from ..cross_models.cross_encoder import Encoder
from ..cross_models.cross_decoder import Decoder
from ..cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer
from ..cross_models.cross_embed import DSW_embedding
from math import ceil
class Crossformer(nn.Module):
    """Segment-based encoder/decoder forecaster assembled from project modules.

    The model embeds the (front-padded) input with ``DSW_embedding``, adds a
    learned positional embedding, applies ``LayerNorm``, runs ``Encoder``, and
    feeds the result together with a second learned positional embedding into
    ``Decoder``. The decoder output is cut to ``out_len`` steps.

    Args:
        data_dim: Size of the second axis of both positional embeddings
            (number of series/variables per the original notes).
        in_len: Input sequence length.
        out_len: Number of time steps kept from the decoder output.
        seg_len: Segment length; input/output lengths are rounded up to a
            multiple of it. Also passed to ``DSW_embedding`` and ``Decoder``.
        win_size: Stored as ``merge_win`` and passed to ``Encoder``.
        factor: Forwarded to ``Encoder`` and ``Decoder``.
        d_model: Embedding width; forwarded to the embedding, the layer norm,
            ``Encoder`` and ``Decoder``.
        d_ff: Forwarded to ``Encoder`` and ``Decoder``.
        n_heads: Forwarded to ``Encoder`` and ``Decoder``.
        e_layers: ``Encoder`` receives ``e_layers``; ``Decoder`` receives
            ``e_layers + 1``.
        dropout: Forwarded to ``Encoder`` and ``Decoder``.
        baseline: When truthy, the mean of the input over axis 1 is added to
            the prediction.
        device: Stored on the instance; not otherwise used in this class.

    Example configuration from the original author's notes: data_dim=7,
    in_len=96, out_len=24, seg_len=6, win_size=2, factor=10, d_model=256,
    d_ff=512, n_heads=4, e_layers=3, dropout=0.2, baseline=True.
    """

    def __init__(self, data_dim, in_len, out_len, seg_len, win_size=4,
                 factor=10, d_model=512, d_ff=1024, n_heads=8, e_layers=3,
                 dropout=0.0, baseline=False, device=torch.device('cuda:0')):
        super().__init__()

        def _round_up_to_segment(length):
            # Smallest multiple of seg_len that is >= length.
            return ceil(1.0 * length / seg_len) * seg_len

        self.data_dim = data_dim
        self.in_len = in_len
        self.out_len = out_len
        self.seg_len = seg_len
        self.merge_win = win_size
        self.baseline = baseline
        self.device = device

        # Padding bookkeeping for lengths that are not divisible by seg_len.
        # (Original note: with input 168 and output 24 no padding is needed.)
        self.pad_in_len = _round_up_to_segment(in_len)
        self.pad_out_len = _round_up_to_segment(out_len)
        self.in_len_add = self.pad_in_len - self.in_len

        enc_seg_count = self.pad_in_len // seg_len
        dec_seg_count = self.pad_out_len // seg_len

        # Embedding. Original note on shapes: [3584, 256] -> [32, 7, 16, 256].
        self.enc_value_embedding = DSW_embedding(seg_len, d_model)
        # Leading axis of size 1 so the embedding is shared by every sample in
        # the batch; initialised from a standard normal distribution.
        # Original example shape: [1, 7, 28, 256].
        self.enc_pos_embedding = nn.Parameter(
            torch.randn(1, data_dim, enc_seg_count, d_model)
        )
        # Layer normalisation over the d_model axis.
        self.pre_norm = nn.LayerNorm(d_model)

        # Encoder
        self.encoder = Encoder(
            e_layers, win_size, d_model, n_heads, d_ff,
            block_depth=1,
            dropout=dropout,
            in_seg_num=enc_seg_count,
            factor=factor,
        )

        # Decoder (has its own positional embedding, separate from the
        # encoder's).
        self.dec_pos_embedding = nn.Parameter(
            torch.randn(1, data_dim, dec_seg_count, d_model)
        )
        self.decoder = Decoder(
            seg_len, e_layers + 1, d_model, n_heads, d_ff, dropout,
            out_seg_num=dec_seg_count,
            factor=factor,
        )

    def forward(self, x_seq):
        """Run the model on ``x_seq`` and return the first ``out_len`` steps.

        ``x_seq`` is indexed as a 3-D tensor with time on axis 1; the original
        notes give ``[32, 168, 7]`` (batch, in_len, data_dim) as an example.
        Returns the decoder output truncated to ``out_len`` along axis 1, plus
        the input's mean over axis 1 when ``baseline`` is enabled.
        """
        # Optional baseline: per-sample mean over the time axis.
        base = x_seq.mean(dim=1, keepdim=True) if self.baseline else 0
        batch_size = x_seq.shape[0]

        if self.in_len_add != 0:
            # Pad at the front by replicating the first time step so that the
            # length becomes a multiple of seg_len.
            front_pad = x_seq[:, :1, :].expand(-1, self.in_len_add, -1)
            x_seq = torch.cat((front_pad, x_seq), dim=1)

        # Original note on shapes: [32, 168, 7] -> [32, 7, 28, 256].
        x_seq = self.enc_value_embedding(x_seq)
        # Add the positional embedding in place (broadcast over the batch).
        x_seq += self.enc_pos_embedding
        x_seq = self.pre_norm(x_seq)

        enc_out = self.encoder(x_seq)

        # Tile the decoder positional embedding once per sample in the batch.
        dec_in = repeat(
            self.dec_pos_embedding,
            'b ts_d l d -> (repeat b) ts_d l d',
            repeat=batch_size,
        )
        predict_y = self.decoder(dec_in, enc_out)

        return base + predict_y[:, :self.out_len, :]