# ai-station-code/guangfufadian/cross_models/cross_encoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from ..cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer
from math import ceil

class SegMerging(nn.Module):
    '''
    Segment Merging Layer.
    The adjacent `win_size` segments in each dimension will be merged into one
    segment to get a representation of a coarser scale.
    We set win_size = 2 in our paper.
    '''

    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        # input dim is win_size * d_model, output dim is d_model
        self.linear_trans = nn.Linear(win_size * d_model, d_model)
        self.norm = norm_layer(win_size * d_model)  # layer norm over the merged feature dim

    def forward(self, x):
        """
        x: B, ts_d, L, d_model
        """
        batch_size, ts_d, seg_num, d_model = x.shape
        pad_num = seg_num % self.win_size
        if pad_num != 0:
            pad_num = self.win_size - pad_num
            # pad along the segment (L) dimension so seg_num divides evenly by win_size
            x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2)

        seg_to_merge = []
        for i in range(self.win_size):
            # split the segment dimension into win_size interleaved groups
            seg_to_merge.append(x[:, :, i::self.win_size, :])
        # keep the segment dimension, concatenate the groups on the feature dimension:
        # [B, ts_d, seg_num/win_size, win_size*d_model]
        x = torch.cat(seg_to_merge, -1)

        x = self.norm(x)
        x = self.linear_trans(x)  # linear projection back to d_model, halving the data for win_size = 2

        return x
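
# Shape walkthrough for SegMerging (a sketch, assuming win_size = 2 and 28 input segments):
#   input              : [B, ts_d, 28, d_model]
#   interleaved slices : 2 tensors of shape [B, ts_d, 14, d_model]
#   concat on last dim : [B, ts_d, 14, 2 * d_model]
#   norm + linear      : [B, ts_d, 14, d_model]  (half as many, coarser segments)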

class scale_block(nn.Module):
    '''
    We can use one segment merging layer followed by multiple TSA layers in each scale.
    The parameter `depth` determines the number of TSA layers used in each scale.
    We set depth = 1 in the paper.
    '''
    # win_size = 1 or 2, seg_num = 28, depth = 1, factor = 10

    def __init__(self, win_size, d_model, n_heads, d_ff, depth, dropout,
                 seg_num=10, factor=10):
        super(scale_block, self).__init__()

        if win_size > 1:
            self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
        else:
            self.merge_layer = None

        self.encode_layers = nn.ModuleList()
        for i in range(depth):
            self.encode_layers.append(TwoStageAttentionLayer(seg_num, factor, d_model,
                                                             n_heads, d_ff, dropout))

    def forward(self, x):
        _, ts_dim, _, _ = x.shape

        # every scale except the first needs the Segment Merging Layer to coarsen the data
        if self.merge_layer is not None:
            x = self.merge_layer(x)

        for layer in self.encode_layers:
            x = layer(x)

        return x
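
# Construction sketch for one scale (hypothetical values, matching the comments above):
#   block = scale_block(win_size=2, d_model=256, n_heads=4, d_ff=512,
#                       depth=1, dropout=0.2, seg_num=14, factor=10)
#   For x of shape [B, ts_d, 28, 256], block(x) first merges segments to
#   [B, ts_d, 14, 256], then applies one TwoStageAttentionLayer at that scale.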

class Encoder(nn.Module):
    '''
    The Encoder of Crossformer.
    Configuration used here (parameter = value):
        e_blocks    = 3
        win_size    = 2
        d_model     = 256
        n_heads     = 4
        d_ff        = 512
        block_depth = 1
        dropout     = 0.2
        in_seg_num  = 16 (default 10)
        factor      = 10
    '''

    def __init__(self, e_blocks, win_size, d_model, n_heads, d_ff, block_depth, dropout,
                 in_seg_num=10, factor=10):
        super(Encoder, self).__init__()
        self.encode_blocks = nn.ModuleList()

        # the first scale keeps the input resolution (win_size = 1, no merging)
        self.encode_blocks.append(scale_block(1, d_model, n_heads, d_ff, block_depth,
                                              dropout, in_seg_num, factor))
        # each later scale shrinks the segment count via SegMerging before its TSA layers
        for i in range(1, e_blocks):
            self.encode_blocks.append(scale_block(win_size, d_model, n_heads, d_ff, block_depth,
                                                  dropout, ceil(in_seg_num / win_size**i), factor))

    def forward(self, x):  # x: [32, 7, 28, 256]
        encode_x = []
        encode_x.append(x)

        for block in self.encode_blocks:  # e_blocks scale blocks in total (3 here)
            x = block(x)
            encode_x.append(x)

        return encode_x
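
# Usage sketch (a minimal, hypothetical example; kept as a comment because this
# module uses relative imports and cannot be run directly as a script):
#   enc = Encoder(e_blocks=3, win_size=2, d_model=256, n_heads=4, d_ff=512,
#                 block_depth=1, dropout=0.2, in_seg_num=28, factor=10)
#   x = torch.randn(32, 7, 28, 256)   # [batch, ts_d, seg_num, d_model]
#   outs = enc(x)                     # list of e_blocks + 1 tensors, one per scale
#   # outs[0]: [32, 7, 28, 256]  (the input itself)
#   # outs[1]: [32, 7, 28, 256]  (scale 0: no merging)
#   # outs[2]: [32, 7, 14, 256]  (scale 1: segments halved)
#   # outs[3]: [32, 7, 7, 256]   (scale 2: halved again)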