import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from ..cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer
from math import ceil


class SegMerging(nn.Module):
    '''
    Segment Merging Layer.
    The adjacent `win_size' segments in each dimension will be merged into one segment to
    get representation of a coarser scale
    we set win_size = 2 in our paper

    Concretely: within each dimension (ts_d), every `win_size` consecutive segments are
    concatenated along the feature axis, layer-normalized, and linearly projected back to
    `d_model` features, so the number of segments shrinks to ceil(L / win_size).
    '''
    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        # Input features: win_size * d_model (one merged window); output features: d_model.
        self.linear_trans = nn.Linear(win_size * d_model, d_model)
        # Normalization over the concatenated window features, applied before the projection.
        self.norm = norm_layer(win_size * d_model)

    def forward(self, x):
        """
        x: B, ts_d, L, d_model
        returns: B, ts_d, ceil(L / win_size), d_model
        """
        # Unpacking (rather than indexing) also enforces a rank-4 input.
        batch_size, ts_d, seg_num, d_model = x.shape

        pad_num = seg_num % self.win_size
        if pad_num != 0:
            pad_num = self.win_size - pad_num
            # Pad along the segment axis (L) by repeating trailing segments so that L becomes
            # divisible by win_size. When pad_num <= seg_num this selects exactly the last
            # pad_num segments; the modulo lets a padding longer than the sequence itself
            # cycle through the available segments instead of yielding a ragged tensor.
            pad_idx = torch.arange(seg_num - pad_num, seg_num, device=x.device) % seg_num
            x = torch.cat((x, x.index_select(-2, pad_idx)), dim=-2)

        # Take every win_size-th segment starting at each offset i and concatenate them on
        # the feature axis: out[..., j, i*d_model:(i+1)*d_model] == x[..., j*win_size + i, :].
        x = torch.cat([x[:, :, i::self.win_size, :] for i in range(self.win_size)], dim=-1)
        # x: [B, ts_d, seg_num / win_size, win_size * d_model]

        x = self.norm(x)
        x = self.linear_trans(x)  # project back to d_model features (win_size-fold reduction)
        return x


class scale_block(nn.Module):
    '''
    We can use one segment merging layer followed by multiple TSA layers in each scale
    the parameter `depth' determines the number of TSA layers used in each scale
    We set depth = 1 in the paper
    '''
    # Example configuration noted by the original author:
    # win_size = 1 or 2, seg_num = 28, depth = 1, factor = 10
    def __init__(self, win_size, d_model, n_heads, d_ff, depth, dropout,
                 seg_num=10, factor=10):
        super(scale_block, self).__init__()

        # Only coarser scales (win_size > 1) merge segments; win_size == 1 keeps the input scale.
        if win_size > 1:
            self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
        else:
            self.merge_layer = None

        # `seg_num` is forwarded unchanged to every TwoStageAttentionLayer (its use is defined
        # in cross_models.attn). Encoder passes the segment count this block sees after merging.
        self.encode_layers = nn.ModuleList()
        for i in range(depth):
            self.encode_layers.append(TwoStageAttentionLayer(seg_num, factor, d_model, n_heads,
                                                             d_ff, dropout))

    def forward(self, x):
        """
        x: B, ts_d, seg_num, d_model (e.g. [32, 7, 28, 256] per the original author's note).

        Optionally merges segments (when win_size > 1), then applies the TSA layers in order.
        The output shape is whatever the last TwoStageAttentionLayer returns — presumably the
        same layout as its input; confirm against cross_models.attn.
        """
        # Blocks after the first one reduce the number of segments via segment merging.
        if self.merge_layer is not None:
            x = self.merge_layer(x)

        for layer in self.encode_layers:
            x = layer(x)

        return x


class Encoder(nn.Module):
    '''
    The Encoder of Crossformer.

    Stacks `e_blocks` scale blocks. The first block works at the input scale (no segment
    merging); each later block merges `win_size` adjacent segments before its TSA layers.
    Assuming the input has `in_seg_num` segments, block i (0-based) therefore sees
    ceil(in_seg_num / win_size**i) segments, which is the `seg_num` it is built with.

    Example configuration noted by the original author:
        e_blocks=3, win_size=2, d_model=256, n_heads=4, d_ff=512,
        block_depth=1, dropout=0.2, in_seg_num=16 (default 10), factor=10
    '''
    def __init__(self, e_blocks, win_size, d_model, n_heads, d_ff, block_depth, dropout,
                 in_seg_num=10, factor=10):
        super(Encoder, self).__init__()
        self.encode_blocks = nn.ModuleList()

        # First scale: no merging (win_size = 1), full input segment count.
        self.encode_blocks.append(scale_block(1, d_model, n_heads, d_ff, block_depth, dropout,
                                              in_seg_num, factor))
        # Coarser scales: each merges win_size segments relative to the previous block.
        for i in range(1, e_blocks):
            self.encode_blocks.append(scale_block(win_size, d_model, n_heads, d_ff, block_depth,
                                                  dropout, ceil(in_seg_num / win_size ** i),
                                                  factor))

    def forward(self, x):
        """
        x: B, ts_d, in_seg_num, d_model (e.g. [32, 7, 28, 256] per the original author's note).

        Returns a list of e_blocks + 1 tensors: the input itself followed by the output of
        every scale block, ordered from the finest to the coarsest scale.
        """
        encode_x = [x]
        for block in self.encode_blocks:
            x = block(x)
            encode_x.append(x)
        return encode_x