# ai-station-code/fenglifadian/cross_models/cross_encoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from fenglifadian.cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer
from math import ceil
class SegMerging(nn.Module):
    '''
    Segment Merging Layer.
    The adjacent `win_size` segments in each dimension are merged into one segment
    to obtain a representation at a coarser scale (win_size = 2 in the paper).
    '''
    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        self.linear_trans = nn.Linear(win_size * d_model, d_model)  # input dim win_size * d_model, output dim d_model
        self.norm = norm_layer(win_size * d_model)  # layer normalization over the concatenated feature dimension

    def forward(self, x):
        """
        x: [B, ts_d, L, d_model]
        """
        batch_size, ts_d, seg_num, d_model = x.shape
        pad_num = seg_num % self.win_size
        if pad_num != 0:
            pad_num = self.win_size - pad_num
            x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2)  # pad along L so that seg_num is divisible by win_size

        seg_to_merge = []
        for i in range(self.win_size):
            seg_to_merge.append(x[:, :, i::self.win_size, :])  # every win_size-th segment, starting at offset i
        x = torch.cat(seg_to_merge, -1)  # [B, ts_d, seg_num/win_size, win_size*d_model]: adjacent segments concatenated along the feature dimension

        x = self.norm(x)
        x = self.linear_trans(x)  # linear projection back to d_model; halves the data volume when win_size = 2

        return x
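

# Shape walkthrough for SegMerging (a sketch, assuming win_size = 2 and the
# example segment count of 28 used elsewhere in this file):
#   input : [B, ts_d, 28, d_model]
#   slice : two interleaved views of [B, ts_d, 14, d_model] each
#   concat: [B, ts_d, 14, 2 * d_model]
#   norm + linear_trans: [B, ts_d, 14, d_model]
# When seg_num is not divisible by win_size, the trailing segments are repeated
# as padding first, so the output segment count is ceil(seg_num / win_size).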
class scale_block(nn.Module):
    '''
    One segment merging layer followed by multiple TSA layers in each scale.
    The parameter `depth` determines the number of TSA layers used in each scale
    (depth = 1 in the paper).
    '''
    # e.g. win_size = 1 or 2, seg_num = 28, depth = 1, factor = 10
    def __init__(self, win_size, d_model, n_heads, d_ff, depth, dropout,
                 seg_num=10, factor=10):
        super(scale_block, self).__init__()

        if win_size > 1:
            self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
        else:
            self.merge_layer = None

        self.encode_layers = nn.ModuleList()
        for i in range(depth):
            self.encode_layers.append(TwoStageAttentionLayer(seg_num, factor, d_model, n_heads,
                                                             d_ff, dropout))

    def forward(self, x):
        _, ts_dim, _, _ = x.shape

        if self.merge_layer is not None:  # every scale except the first applies Segment Merging to coarsen the sequence
            x = self.merge_layer(x)

        for layer in self.encode_layers:
            x = layer(x)

        return x
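

# In a scale_block, the optional SegMerging step reduces seg_num by a factor of
# win_size, while each TwoStageAttentionLayer keeps the shape unchanged:
#   [B, ts_d, seg_num, d_model] -> [B, ts_d, ceil(seg_num / win_size), d_model]
# (a sketch of the expected shapes, assuming the TSA layer is shape-preserving
# as in Crossformer).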
class Encoder(nn.Module):
    '''
    The Encoder of Crossformer.
    Example configuration (defaults shown where they differ):
        e_blocks    = 3
        win_size    = 2
        d_model     = 256
        n_heads     = 4
        d_ff        = 512
        block_depth = 1
        dropout     = 0.2
        in_seg_num  = 16 (default 10)
        factor      = 10
    '''
    def __init__(self, e_blocks, win_size, d_model, n_heads, d_ff, block_depth, dropout,
                 in_seg_num=10, factor=10):
        super(Encoder, self).__init__()
        self.encode_blocks = nn.ModuleList()

        # first scale: no segment merging (win_size = 1)
        self.encode_blocks.append(scale_block(1, d_model, n_heads, d_ff, block_depth, dropout,
                                              in_seg_num, factor))
        # remaining scales: merge win_size segments, so seg_num shrinks by win_size at each scale
        for i in range(1, e_blocks):
            self.encode_blocks.append(scale_block(win_size, d_model, n_heads, d_ff, block_depth, dropout,
                                                  ceil(in_seg_num / win_size**i), factor))

    def forward(self, x):  # e.g. x: [32, 7, 28, 256]
        encode_x = []
        encode_x.append(x)

        for block in self.encode_blocks:  # e_blocks scale blocks in total
            x = block(x)
            encode_x.append(x)

        return encode_x
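

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It assumes
# fenglifadian.cross_models.attn provides TwoStageAttentionLayer as imported
# above and that it follows the Crossformer shape conventions. Hyperparameters
# mirror the example configuration in the docstrings, with a small batch size
# for speed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = Encoder(e_blocks=3, win_size=2, d_model=256, n_heads=4, d_ff=512,
                      block_depth=1, dropout=0.2, in_seg_num=28, factor=10)
    x = torch.randn(2, 7, 28, 256)  # [B, ts_d, seg_num, d_model]
    outputs = encoder(x)
    # Expected: 4 tensors with seg_num 28 -> 28 -> 14 -> 7 (d_model unchanged)
    for i, out in enumerate(outputs):
        print(f"scale {i}: {tuple(out.shape)}")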