ai-station-code/guangfufadian/cross_models/cross_former.py
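
"""Crossformer forecasting model (brief added summary): a DSW segment-wise value
embedding plus a learned positional embedding feed a hierarchical two-stage-attention
Encoder; a Decoder with its own positional embedding produces the length-out_len forecast.
"""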


import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from ..cross_models.cross_encoder import Encoder
from ..cross_models.cross_decoder import Decoder
from ..cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer
from ..cross_models.cross_embed import DSW_embedding
from math import ceil


class Crossformer(nn.Module):
    def __init__(self, data_dim, in_len, out_len, seg_len, win_size=4,
                 factor=10, d_model=512, d_ff=1024, n_heads=8, e_layers=3,
                 dropout=0.0, baseline=False, device=torch.device('cuda:0')):
        super(Crossformer, self).__init__()
"""
self.args.data_dim, # 7
self.args.in_len, # 96
self.args.out_len, # 24
self.args.seg_len, # 6
self.args.win_size, # 2
self.args.factor, # 10
self.args.d_model, # 256
self.args.d_ff, # 512
self.args.n_heads, # 4
self.args.e_layers, # 3
self.args.dropout, # 0.2
self.args.baseline, # True
self.device
"""
        self.data_dim = data_dim
        self.in_len = in_len
        self.out_len = out_len
        self.seg_len = seg_len
        self.merge_win = win_size
        self.baseline = baseline
        self.device = device
        # Padding to handle an input length that is not divisible by the segment length
        # (e.g. in_len = 168, out_len = 24: the time dimension is padded up to a whole number of segments).
        self.pad_in_len = ceil(1.0 * in_len / seg_len) * seg_len
        self.pad_out_len = ceil(1.0 * out_len / seg_len) * seg_len
        self.in_len_add = self.pad_in_len - self.in_len
        # Embedding
        self.enc_value_embedding = DSW_embedding(seg_len, d_model)  # [3584, 256] -> [32, 7, 16, 256]
        # Learnable positional embedding, drawn from a standard normal distribution; the leading 1 means it is
        # shared across the whole batch, shape [1, 7, 28, 256].
        self.enc_pos_embedding = nn.Parameter(torch.randn(1, data_dim, (self.pad_in_len // seg_len), d_model))
        self.pre_norm = nn.LayerNorm(d_model)  # layer normalization over the d_model (256) features, i.e. per sample
        # Encoder
        self.encoder = Encoder(e_layers, win_size, d_model, n_heads, d_ff, block_depth=1,
                               dropout=dropout, in_seg_num=(self.pad_in_len // seg_len), factor=factor)
        # Decoder
        self.dec_pos_embedding = nn.Parameter(torch.randn(1, data_dim, (self.pad_out_len // seg_len), d_model))
        self.decoder = Decoder(seg_len, e_layers + 1, d_model, n_heads, d_ff, dropout,
                               out_seg_num=(self.pad_out_len // seg_len), factor=factor)

    def forward(self, x_seq):
        if self.baseline:
            base = x_seq.mean(dim=1, keepdim=True)  # mean of the features over the input time steps (e.g. 168)
        else:
            base = 0
        batch_size = x_seq.shape[0]
        if self.in_len_add != 0:  # pad only when in_len is not a multiple of seg_len (0 here, so no padding needed)
            # Replicate the first time step and prepend it so the input length reaches pad_in_len.
            x_seq = torch.cat((x_seq[:, :1, :].expand(-1, self.in_len_add, -1), x_seq), dim=1)

        x_seq = self.enc_value_embedding(x_seq)  # [32, 168, 7] -> [32, 7, 28, 256]
        x_seq += self.enc_pos_embedding          # add the positional embedding, shape [1, 7, 28, 256]
        x_seq = self.pre_norm(x_seq)             # layer normalization
        enc_out = self.encoder(x_seq)

        # The decoder has its own positional embedding; it is not shared with the encoder.
        dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat=batch_size)
        predict_y = self.decoder(dec_in, enc_out)

        return base + predict_y[:, :self.out_len, :]
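

# ---------------------------------------------------------------------------
# Minimal usage sketch (an added illustration, not part of the original model
# code): builds the model with the example argument values listed in
# __init__'s docstring and runs a dummy forward pass on CPU. Assumes this
# module is imported/run as part of its package so the relative imports above
# resolve (e.g. `python -m guangfufadian.cross_models.cross_former`).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = Crossformer(
        data_dim=7, in_len=96, out_len=24, seg_len=6, win_size=2,
        factor=10, d_model=256, d_ff=512, n_heads=4, e_layers=3,
        dropout=0.2, baseline=True, device=torch.device('cpu'),
    )
    x = torch.randn(32, 96, 7)  # [batch, in_len, data_dim]
    y = model(x)                # forward pass; output is [batch, out_len, data_dim]
    print(y.shape)              # expected: torch.Size([32, 24, 7])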