import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from fenglifadian.cross_models.cross_encoder import Encoder
from fenglifadian.cross_models.cross_decoder import Decoder
from fenglifadian.cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer
from fenglifadian.cross_models.cross_embed import DSW_embedding

from math import ceil

class Crossformer(nn.Module):
    def __init__(self, data_dim, in_len, out_len, seg_len, win_size = 4,
                factor=10, d_model=512, d_ff = 1024, n_heads=8, e_layers=3,
                dropout=0.0, baseline = False, device=torch.device('cuda:0')):
        super(Crossformer, self).__init__()
        """
        Example argument values used in this project (a usage sketch with these
        values is included at the bottom of this file):
            self.args.data_dim,  # 7
            self.args.in_len,    # 96
            self.args.out_len,   # 24
            self.args.seg_len,   # 6
            self.args.win_size,  # 2
            self.args.factor,    # 10
            self.args.d_model,   # 256
            self.args.d_ff,      # 512
            self.args.n_heads,   # 4
            self.args.e_layers,  # 3
            self.args.dropout,   # 0.2
            self.args.baseline,  # True
            self.device
        """
        self.data_dim = data_dim
        self.in_len = in_len
        self.out_len = out_len
        self.seg_len = seg_len
        self.merge_win = win_size

        self.baseline = baseline

        self.device = device

        # The padding operation to handle an input length that is not divisible by seg_len
        # (input length 168, output length 24; the padding needed here is 0)
        self.pad_in_len = ceil(1.0 * in_len / seg_len) * seg_len
        self.pad_out_len = ceil(1.0 * out_len / seg_len) * seg_len
        self.in_len_add = self.pad_in_len - self.in_len
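        # Worked example with the docstring values (in_len = 96, out_len = 24, seg_len = 6):
        #   pad_in_len  = ceil(96 / 6) * 6 = 96
        #   pad_out_len = ceil(24 / 6) * 6 = 24
        #   in_len_add  = 96 - 96 = 0, so forward() does not need to pad the input.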

        # Embedding
        self.enc_value_embedding = DSW_embedding(seg_len, d_model)  # [3584, 256] -> [32,7,16,256]
        self.enc_pos_embedding = nn.Parameter(torch.randn(1, data_dim, (self.pad_in_len // seg_len), d_model))  # the leading 1 means the embedding is shared across all batches; shape [1,7,28,256], drawn from a standard normal distribution
        self.pre_norm = nn.LayerNorm(d_model)  # layer normalization over the d_model (256) feature dimension, i.e. per-sample normalization

        # Encoder
        self.encoder = Encoder(e_layers, win_size, d_model, n_heads, d_ff, block_depth = 1, \
                                    dropout = dropout, in_seg_num = (self.pad_in_len // seg_len), factor = factor)

        # Decoder
        self.dec_pos_embedding = nn.Parameter(torch.randn(1, data_dim, (self.pad_out_len // seg_len), d_model))
        self.decoder = Decoder(seg_len, e_layers + 1, d_model, n_heads, d_ff, dropout, \
                                    out_seg_num = (self.pad_out_len // seg_len), factor = factor)
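        # Note: the decoder is expected to return predictions of shape [batch, pad_out_len, data_dim];
        # forward() trims them to out_len in its return statement.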

    def forward(self, x_seq):
        if (self.baseline):
            base = x_seq.mean(dim = 1, keepdim = True)  # mean of the features over the input time steps (length 168)
        else:
            base = 0
        batch_size = x_seq.shape[0]
        if (self.in_len_add != 0):  # pad only when in_len is not a multiple of seg_len; here in_len_add == 0, so no padding is needed
            x_seq = torch.cat((x_seq[:, :1, :].expand(-1, self.in_len_add, -1), x_seq), dim = 1)  # pad by replicating the first time step and prepending the copies, keeping the input dimensions consistent

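        # Encoder-side data flow (shapes follow the inline comments: batch 32, in_len 168, data_dim 7):
        #   x_seq [32, 168, 7] --DSW_embedding--> [32, 7, 28, 256]
        #   + enc_pos_embedding [1, 7, 28, 256] (broadcast over the batch) --LayerNorm--> [32, 7, 28, 256]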
        x_seq = self.enc_value_embedding(x_seq)  # [32,168,7] -> [32,7,28,256]
        x_seq += self.enc_pos_embedding  # add the positional embedding enc_pos_embedding, shape [1,7,28,256]
        x_seq = self.pre_norm(x_seq)  # layer normalization

        enc_out = self.encoder(x_seq)

        dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat = batch_size)  # the decoder has its own positional embedding, distinct from the encoder's
        predict_y = self.decoder(dec_in, enc_out)

        return base + predict_y[:, :self.out_len, :]
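

# A minimal usage sketch: builds Crossformer with the example hyperparameters listed in the
# __init__ docstring and runs a forward pass on random data. The batch size of 32 and the
# CPU device are assumptions made only for this example.
if __name__ == '__main__':
    model = Crossformer(data_dim = 7, in_len = 96, out_len = 24, seg_len = 6, win_size = 2,
                        factor = 10, d_model = 256, d_ff = 512, n_heads = 4, e_layers = 3,
                        dropout = 0.2, baseline = True, device = torch.device('cpu'))
    x = torch.randn(32, 96, 7)  # [batch_size, in_len, data_dim]
    y = model(x)                # expected output shape: [32, 24, 7] = [batch_size, out_len, data_dim]
    print(y.shape)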