import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np

from math import sqrt


class FullAttention(nn.Module):
    '''
    The Attention operation
    '''
    def __init__(self, scale=None, attention_dropout=0.1):
        super(FullAttention, self).__init__()
        self.scale = scale  # scaling factor for the attention scores, i.e. 1/sqrt(d_k) in the attention formula
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        # scaled dot-product attention over query length L and key length S
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))  # dropout is applied to the attention weights (active only in training mode)
        V = torch.einsum("bhls,bshd->blhd", A, values)

        return V.contiguous()  # make the output contiguous in memory

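
def _full_attention_shape_demo():
    # Illustrative sketch (not part of the original file): shows the tensor shapes
    # FullAttention expects, i.e. inputs already split into heads as
    # [batch, seq_len, n_heads, head_dim]. The sizes below are arbitrary assumptions.
    attn = FullAttention(scale=None, attention_dropout=0.1)
    q = torch.randn(2, 28, 4, 64)  # [B, L, H, E]
    k = torch.randn(2, 28, 4, 64)  # [B, S, H, E]
    v = torch.randn(2, 28, 4, 64)  # [B, S, H, D]
    out = attn(q, k, v)            # -> [2, 28, 4, 64], i.e. [B, L, H, D]
    return out.shape
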
class AttentionLayer(nn.Module):
    '''
    The Multi-head Self-Attention (MSA) Layer
    '''
    def __init__(self, d_model, n_heads, d_keys=None, d_values=None, dropout=0.1):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = FullAttention(scale=None, attention_dropout=dropout)  # scaled dot-product attention shared by all heads
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)    # Q projection
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)      # K projection
        self.value_projection = nn.Linear(d_model, d_values * n_heads)  # V projection
        self.out_projection = nn.Linear(d_values * n_heads, d_model)    # output projection
        self.n_heads = n_heads

    def forward(self, queries, keys, values):  # the raw inputs serve as Q, K and V
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # project the inputs and split them into H heads
        queries = self.query_projection(queries).view(B, L, H, -1)  # e.g. (224, 28, 256) -> (224, 28, 4, 64)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out = self.inner_attention(
            queries,
            keys,
            values,
        )

        out = out.view(B, L, -1)  # merge the heads back, e.g. (224, 28, 256)

        return self.out_projection(out)

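
def _attention_layer_shape_demo():
    # Illustrative sketch (not part of the original file): AttentionLayer takes
    # [batch, seq_len, d_model] inputs, splits them into n_heads internally and
    # returns [batch, query_len, d_model]. The sizes below are arbitrary assumptions.
    layer = AttentionLayer(d_model=256, n_heads=4, dropout=0.1)
    x = torch.randn(2, 28, 256)  # [B, L, d_model], used as Q, K and V (self-attention)
    out = layer(x, x, x)         # -> [2, 28, 256]
    return out.shape
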
class TwoStageAttentionLayer(nn.Module):
    '''
    The Two Stage Attention (TSA) Layer
    input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
    '''
    def __init__(self, seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1):
        super(TwoStageAttentionLayer, self).__init__()
        d_ff = d_ff or 4 * d_model  # default FFN hidden size: 4 * d_model when d_ff is not specified
        self.time_attention = AttentionLayer(d_model, n_heads, dropout=dropout)  # attention along the time axis, capturing temporal dependencies between segments of the same series
        self.dim_sender = AttentionLayer(d_model, n_heads, dropout=dropout)      # cross-dimension stage: the router queries all dimensions to gather ("send") their messages
        self.dim_receiver = AttentionLayer(d_model, n_heads, dropout=dropout)    # cross-dimension stage: each dimension queries the router to receive the aggregated messages
        self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))  # e.g. (28, 10, 256); a small set of learnable vectors that aggregate and redistribute messages, avoiding full D-to-D attention
        # nn.Parameter registers the tensor as a model parameter, so it is updated by the optimizer (SGD, Adam, ...)
        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff),  # position-wise FFN block
                                  nn.GELU(),
                                  nn.Linear(d_ff, d_model))
        self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff),
                                  nn.GELU(),
                                  nn.Linear(d_ff, d_model))

    def forward(self, x):
        # Cross Time Stage: directly apply MSA to each dimension
        batch = x.shape[0]
        time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')  # e.g. (32, 7, 28, 256) -> (224, 28, 256)
        time_enc = self.time_attention(
            time_in, time_in, time_in
        )
        dim_in = time_in + self.dropout(time_enc)          # residual connection
        dim_in = self.norm1(dim_in)                        # layer normalization
        dim_in = dim_in + self.dropout(self.MLP1(dim_in))  # FFN + residual connection
        dim_in = self.norm2(dim_in)                        # layer normalization

        # Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
        dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)  # e.g. (224, 28, 256) -> (896, 7, 256): each segment position is now attended over the dimension axis
        batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)  # e.g. (28, 10, 256) -> (896, 10, 256): the same learnable router is shared across the batch
        dim_buffer = self.dim_sender(batch_router, dim_send, dim_send)     # the router attends over all dimensions and aggregates their messages into a compact buffer
        dim_receive = self.dim_receiver(dim_send, dim_buffer, dim_buffer)  # each dimension attends over the buffer to receive the aggregated messages
        dim_enc = dim_send + self.dropout(dim_receive)        # residual connection between the original input and the received messages
        dim_enc = self.norm3(dim_enc)                         # layer normalization to stabilize training
        dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))  # FFN + residual connection
        dim_enc = self.norm4(dim_enc)

        final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=batch)

        return final_out
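
def _two_stage_attention_shape_demo():
    # Illustrative smoke test (not part of the original file): the TSA layer preserves the
    # [batch, ts_d, seg_num, d_model] shape. seg_num/factor/d_model/n_heads below are
    # arbitrary assumptions chosen to match the shapes mentioned in the comments above.
    layer = TwoStageAttentionLayer(seg_num=28, factor=10, d_model=256, n_heads=4)
    x = torch.randn(2, 7, 28, 256)  # [batch, ts_d, seg_num, d_model]
    out = layer(x)                  # -> [2, 7, 28, 256]
    return out.shape


if __name__ == "__main__":
    print(_full_attention_shape_demo())
    print(_attention_layer_shape_demo())
    print(_two_stage_attention_shape_demo())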