22 lines
912 B
Python
22 lines
912 B
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
from einops import rearrange, repeat
|
||
|
|
||
|
import math
|
||
|
|
||
|
class DSW_embedding(nn.Module):
    """Embed fixed-length segments of each series dimension with one linear map.

    The time axis of the input is cut into consecutive, non-overlapping
    segments of ``seg_len`` steps. Every segment of every dimension is then
    projected from ``seg_len`` values to a ``d_model``-sized vector by a
    single shared ``nn.Linear`` layer.

    Args:
        seg_len: Number of time steps per segment (e.g. 6).
        d_model: Size of the embedding produced for each segment (e.g. 256).

    Raises:
        ValueError: If ``seg_len`` is not positive.
    """

    def __init__(self, seg_len: int, d_model: int) -> None:
        super().__init__()
        if seg_len <= 0:
            raise ValueError(f"seg_len must be positive, got {seg_len}")
        self.seg_len = seg_len
        self.linear = nn.Linear(seg_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Segment ``x`` along time and embed every segment.

        Args:
            x: Tensor of shape ``(batch, ts_len, ts_dim)``; ``ts_len`` must be
                a multiple of ``seg_len``.

        Returns:
            Tensor of shape ``(batch, ts_dim, ts_len // seg_len, d_model)``.

        Raises:
            ValueError: If ``ts_len`` is not divisible by ``seg_len``.
        """
        batch, ts_len, ts_dim = x.shape  # e.g. 32, 168, 7
        # Fail early with a readable message instead of an opaque shape error
        # raised from inside the rearrange call below.
        if ts_len % self.seg_len != 0:
            raise ValueError(
                f"time length {ts_len} is not divisible by "
                f"seg_len {self.seg_len}"
            )

        # Flatten every (batch, dimension, segment) triple into its own row so
        # the shared linear layer sees one segment per row.
        # e.g. [32, 168, 7] -> [32 * 7 * 28, 6] = [6272, 6]
        x_segment = rearrange(
            x,
            'b (seg_num seg_len) d -> (b d seg_num) seg_len',
            seg_len=self.seg_len,
        )
        x_embed = self.linear(x_segment)  # e.g. [6272, 6] -> [6272, 256]
        # Unflatten the rows back into (batch, dimension, segment) axes.
        # e.g. [6272, 256] -> [32, 7, 28, 256]
        x_embed = rearrange(
            x_embed,
            '(b d seg_num) d_model -> b d seg_num d_model',
            b=batch,
            d=ts_dim,
        )

        return x_embed
|