import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from fenglifadian.cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer


class DecoderLayer(nn.Module):
    '''
    The decoder layer of Crossformer. Each layer makes a prediction at its own scale.
    '''
    def __init__(self, seg_len, d_model, n_heads, d_ff=None, dropout=0.1, out_seg_num=10, factor=10):
        super(DecoderLayer, self).__init__()
        # Two-stage attention over the decoder's own segments
        self.self_attention = TwoStageAttentionLayer(out_seg_num, factor, d_model, n_heads,
                                                     d_ff, dropout)
        # Standard attention from the decoder segments to the corresponding encoder output
        self.cross_attention = AttentionLayer(d_model, n_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Position-wise feed-forward block applied after cross attention
        self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model),
                                  nn.GELU(),
                                  nn.Linear(d_model, d_model))
        # Projects each segment embedding to seg_len predicted time steps
        self.linear_pred = nn.Linear(d_model, seg_len)

    def forward(self, x, cross):
        '''
        x: the output of the previous decoder layer
        cross: the output of the corresponding encoder layer
        '''
        batch = x.shape[0]
        x = self.self_attention(x)  # one pass of TwoStageAttentionLayer over the decoder segments
        x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model')

        cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model')
        tmp = self.cross_attention(
            x, cross, cross,
        )  # attend from the decoder segments to the encoder output
        x = x + self.dropout(tmp)
        y = x = self.norm1(x)
        y = self.MLP1(y)
        dec_output = self.norm2(x + y)

        dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=batch)
        layer_predict = self.linear_pred(dec_output)
        layer_predict = rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len')

        return dec_output, layer_predict

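# Shape walkthrough for a single DecoderLayer (a minimal sketch; the sizes
# below are illustrative assumptions, not values from the Crossformer paper):
#
#     layer = DecoderLayer(seg_len=6, d_model=64, n_heads=4, d_ff=128,
#                          dropout=0.1, out_seg_num=4, factor=10)
#     x = torch.randn(2, 7, 4, 64)      # [batch, ts_d, out_seg_num, d_model]
#     cross = torch.randn(2, 7, 8, 64)  # [batch, ts_d, in_seg_num, d_model]
#     dec_output, layer_predict = layer(x, cross)
#     # dec_output:    [2, 7, 4, 64]  -- fed to the next decoder layer
#     # layer_predict: [2, 28, 6]     == [batch, ts_d * out_seg_num, seg_len]
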
class Decoder(nn.Module):
    '''
    The decoder of Crossformer, which makes the final prediction by summing the
    predictions made at each scale.
    '''
    def __init__(self, seg_len, d_layers, d_model, n_heads, d_ff, dropout,
                 router=False, out_seg_num=10, factor=10):
        super(Decoder, self).__init__()

        self.router = router
        self.decode_layers = nn.ModuleList()
        for i in range(d_layers):  # one DecoderLayer per scale (four in this configuration)
            self.decode_layers.append(DecoderLayer(seg_len, d_model, n_heads, d_ff, dropout,
                                                   out_seg_num, factor))

    def forward(self, x, cross):
        final_predict = None
        i = 0

        ts_d = x.shape[1]
        for layer in self.decode_layers:
            cross_enc = cross[i]  # encoder output at the matching scale
            x, layer_predict = layer(x, cross_enc)
            if final_predict is None:
                final_predict = layer_predict
            else:
                final_predict = final_predict + layer_predict
            i += 1

        final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=ts_d)

        return final_predict
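

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch under assumed hyper-parameters; the sizes below
# are made up for illustration and are not tied to any particular Crossformer
# configuration). Expected shapes:
#   x      : [batch, ts_d, out_seg_num, d_model]        decoder input
#   cross  : one encoder output per decoder layer,
#            each [batch, ts_d, in_seg_num, d_model]
#   output : [batch, out_seg_num * seg_len, ts_d]
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch, ts_d, d_model = 2, 7, 64
    seg_len, out_seg_num, d_layers = 6, 4, 3

    decoder = Decoder(seg_len, d_layers, d_model, n_heads=4, d_ff=128,
                      dropout=0.1, out_seg_num=out_seg_num, factor=10)
    x = torch.randn(batch, ts_d, out_seg_num, d_model)
    # One encoder output per decoder layer; in_seg_num may differ per scale.
    cross = [torch.randn(batch, ts_d, in_seg, d_model) for in_seg in (16, 8, 4)]

    out = decoder(x, cross)
    print(out.shape)  # torch.Size([2, 24, 7]) == [batch, out_seg_num * seg_len, ts_d]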