import torch
import torch.nn as nn
from einops import rearrange


class DSW_embedding(nn.Module):
    def __init__(self, seg_len, d_model):  # e.g. seg_len=6, d_model=256
        super(DSW_embedding, self).__init__()
        self.seg_len = seg_len
        self.linear = nn.Linear(seg_len, d_model)

    def forward(self, x):
        batch, ts_len, ts_dim = x.shape  # e.g. [32, 168, 7]
        # Cut each of the ts_dim series into ts_len / seg_len segments:
        # 168 = 28 segments * 6 steps, so 32 * 7 * 28 = 6272 segments in total.
        x_segment = rearrange(x, 'b (seg_num seg_len) d -> (b d seg_num) seg_len',
                              seg_len=self.seg_len)  # [32, 168, 7] -> [6272, 6]
        # Embed every segment independently with a shared linear layer.
        x_embed = self.linear(x_segment)  # [6272, 6] -> [6272, 256]
        # Restore the batch and feature axes:
        # [6272, 256] -> [32, 7, 28, 256], i.e. [batch, ts_dim, seg_num, d_model].
        x_embed = rearrange(x_embed, '(b d seg_num) d_model -> b d seg_num d_model',
                            b=batch, d=ts_dim)
        return x_embed
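

# Minimal usage sketch (assumed, not part of the original module): feed a random
# batch with the example dimensions from the comments above (batch=32, length=168,
# 7 variables, seg_len=6, d_model=256) and verify the output shape.
if __name__ == '__main__':
    embedding = DSW_embedding(seg_len=6, d_model=256)
    x = torch.randn(32, 168, 7)             # [batch, ts_len, ts_dim]
    out = embedding(x)
    assert out.shape == (32, 7, 28, 256)    # [batch, ts_dim, seg_num, d_model]
    print(out.shape)                        # torch.Size([32, 7, 28, 256])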