import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from fenglifadian.cross_models.cross_encoder import Encoder
from fenglifadian.cross_models.cross_decoder import Decoder
from fenglifadian.cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer
from fenglifadian.cross_models.cross_embed import DSW_embedding
from math import ceil


class Crossformer(nn.Module):
    """Crossformer forecaster: segment embedding -> encoder -> decoder.

    The input window is split into segments of ``seg_len`` steps and embedded
    by ``DSW_embedding``. A learnable positional embedding is added, and the
    result is layer-normalised and passed through ``Encoder``. The decoder
    starts from a separate learnable positional embedding (no input values),
    cross-attends to the encoder output, and the result is truncated to
    ``out_len`` steps.

    Example configuration noted by the original author (not enforced here):
    data_dim=7, in_len=96, out_len=24, seg_len=6, win_size=2, factor=10,
    d_model=256, d_ff=512, n_heads=4, e_layers=3, dropout=0.2, baseline=True.
    """

    def __init__(self, data_dim, in_len, out_len, seg_len, win_size=4,
                 factor=10, d_model=512, d_ff=1024, n_heads=8, e_layers=3,
                 dropout=0.0, baseline=False, device=torch.device('cuda:0')):
        """Build the embedding, encoder and decoder sub-modules.

        Args:
            data_dim: Number of input channels; sizes the channel axis of both
                positional embeddings.
            in_len: Length of the input window (time steps).
            out_len: Number of future steps returned by ``forward``.
            seg_len: Segment length used to chunk input/output sequences.
            win_size: Forwarded to ``Encoder``; also stored as ``merge_win``.
            factor: Forwarded to ``Encoder`` and ``Decoder``.
            d_model: Embedding width.
            d_ff: Feed-forward width forwarded to ``Encoder``/``Decoder``.
            n_heads: Number of attention heads.
            e_layers: Number of encoder layers; the decoder gets ``e_layers + 1``.
            dropout: Dropout rate forwarded to ``Encoder``/``Decoder``.
            baseline: If true, ``forward`` adds the input-window mean to the
                prediction.
            device: Stored as ``self.device``; not used inside this class.
        """
        super().__init__()
        self.data_dim = data_dim
        self.in_len = in_len
        self.out_len = out_len
        self.seg_len = seg_len
        self.merge_win = win_size
        self.baseline = baseline
        self.device = device

        # Round input/output lengths up to the next multiple of seg_len so that
        # they split into whole segments. forward() left-pads the input by
        # in_len_add steps; the decoder output is truncated back to out_len.
        self.pad_in_len = ceil(1.0 * in_len / seg_len) * seg_len
        self.pad_out_len = ceil(1.0 * out_len / seg_len) * seg_len
        self.in_len_add = self.pad_in_len - self.in_len

        in_seg_num = self.pad_in_len // seg_len
        out_seg_num = self.pad_out_len // seg_len

        # Embedding
        self.enc_value_embedding = DSW_embedding(seg_len, d_model)
        # One learnable d_model vector per (channel, input segment). The leading
        # dim of 1 broadcasts the same embedding across the batch; it is
        # initialised from a standard normal.
        self.enc_pos_embedding = nn.Parameter(
            torch.randn(1, data_dim, in_seg_num, d_model))
        # LayerNorm over the last (d_model) dimension.
        self.pre_norm = nn.LayerNorm(d_model)

        # Encoder
        self.encoder = Encoder(e_layers, win_size, d_model, n_heads, d_ff,
                               block_depth=1, dropout=dropout,
                               in_seg_num=in_seg_num, factor=factor)

        # Decoder: its input is a separate learnable positional embedding (not
        # shared with the encoder's), one vector per (channel, output segment).
        self.dec_pos_embedding = nn.Parameter(
            torch.randn(1, data_dim, out_seg_num, d_model))
        # e_layers + 1 decoder layers -- presumably one per encoder output scale;
        # confirm against Decoder.
        self.decoder = Decoder(seg_len, e_layers + 1, d_model, n_heads, d_ff,
                               dropout, out_seg_num=out_seg_num, factor=factor)

    def forward(self, x_seq):
        """Forecast ``out_len`` steps from an input window.

        Args:
            x_seq: Rank-3 tensor indexed as ``(batch, time, channel)``; the
                code pads along dim 1 and averages over dim 1. Time length is
                assumed to be ``in_len`` and channel count ``data_dim`` --
                TODO confirm against the data loader.

        Returns:
            The decoder output truncated to the first ``out_len`` entries along
            dim 1. When ``baseline`` is enabled, the per-channel mean of the
            (unpadded) input window is added to it.
        """
        # With baseline enabled the network predicts a residual on top of the
        # input-window mean (taken before any padding).
        base = x_seq.mean(dim=1, keepdim=True) if self.baseline else 0
        batch_size = x_seq.shape[0]

        if self.in_len_add != 0:
            # Left-pad by repeating the first time step so the sequence length
            # becomes pad_in_len, a multiple of seg_len.
            pad = x_seq[:, :1, :].expand(-1, self.in_len_add, -1)
            x_seq = torch.cat((pad, x_seq), dim=1)

        x_seq = self.enc_value_embedding(x_seq)
        # Out-of-place add: an in-place `+=` on this autograd intermediate
        # would break backward if the embedding's last op ever needs its own
        # output for gradient computation. Numerically identical either way.
        x_seq = x_seq + self.enc_pos_embedding
        x_seq = self.pre_norm(x_seq)

        enc_out = self.encoder(x_seq)

        # Tile the learned decoder positional embedding across the batch; it is
        # the decoder's only input besides the encoder output.
        dec_in = repeat(self.dec_pos_embedding,
                        'b ts_d l d -> (repeat b) ts_d l d', repeat=batch_size)
        predict_y = self.decoder(dec_in, enc_out)

        return base + predict_y[:, :self.out_len, :]