# ai-station-code/fenglifadian/cross_models/cross_embed.py
#
# (Original listing metadata: 22 lines, 912 B, Python.)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import math
class DSW_embedding(nn.Module):
    """Embed fixed-length, per-variable time segments with one shared linear layer.

    The input series is cut along the time axis into non-overlapping segments of
    ``seg_len`` steps. This is done separately for every variable (channel). Each
    segment is then projected to a ``d_model``-dimensional vector by a single
    ``nn.Linear(seg_len, d_model)`` shared across batch, variables and segments.
    The name suggests this is the Dimension-Segment-Wise (DSW) embedding used by
    Crossformer. TODO confirm against the model that consumes it.

    Args:
        seg_len: Number of time steps per segment (example value in this
            project: 6).
        d_model: Size of each segment embedding (example value: 256).
    """

    def __init__(self, seg_len, d_model):
        super().__init__()
        self.seg_len = seg_len
        self.linear = nn.Linear(seg_len, d_model)

    def forward(self, x):
        """Segment ``x`` along time and embed every segment.

        Args:
            x: Tensor of shape ``(batch, ts_len, ts_dim)``. ``ts_len`` must be a
                multiple of ``seg_len``. Callers are expected to pad or trim
                beforehand.

        Returns:
            Tensor of shape ``(batch, ts_dim, ts_len // seg_len, d_model)``.
            For example, (32, 168, 7) with seg_len=6 and d_model=256 gives
            (32, 7, 28, 256).

        Raises:
            ValueError: if ``x`` is not 3-D, or if ``ts_len`` is not divisible by
                ``seg_len``.
        """
        if x.dim() != 3:
            raise ValueError(
                "DSW_embedding expects a 3-D input (batch, ts_len, ts_dim), "
                f"got shape {tuple(x.shape)}"
            )
        batch, ts_len, ts_dim = x.shape
        if ts_len % self.seg_len != 0:
            raise ValueError(
                f"ts_len={ts_len} is not divisible by seg_len={self.seg_len}; "
                "pad or trim the series to a multiple of seg_len first"
            )
        seg_num = ts_len // self.seg_len

        # Split time into (seg_num, seg_len). The segment index is the outer
        # factor, matching the einops pattern 'b (seg_num seg_len) d'. Then move
        # the variable axis in front: (b, d, seg_num, seg_len).
        x_segment = x.reshape(batch, seg_num, self.seg_len, ts_dim)
        x_segment = x_segment.permute(0, 3, 1, 2)

        # nn.Linear operates on the last axis, so each segment is projected
        # independently with shared weights: (b, d, seg_num, d_model).
        return self.linear(x_segment)