ai-station-code/guangfufadian/cross_models/attn.py

118 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from math import sqrt
class FullAttention(nn.Module):
'''
The Attention operation
'''
def __init__(self, scale=None, attention_dropout=0.1):
super(FullAttention, self).__init__()
self.scale = scale # attention计算时缩放公式中的根号下dk
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1./sqrt(E)
# qkv的计算
scores = torch.einsum("blhe,bshe->bhls", queries, keys) # 查询序列长度L 键序列S
A = self.dropout(torch.softmax(scale * scores, dim=-1)) # 这里训练的时候就drop呀
V = torch.einsum("bhls,bshd->blhd", A, values)
return V.contiguous() # 将v放在连续内存中
class AttentionLayer(nn.Module):
'''
The Multi-head Self-Attention (MSA) Layer
'''
def __init__(self, d_model, n_heads, d_keys=None, d_values=None, dropout = 0.1):
super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model//n_heads)
d_values = d_values or (d_model//n_heads)
self.inner_attention = FullAttention(scale=None, attention_dropout = dropout) # 全链接层
self.query_projection = nn.Linear(d_model, d_keys * n_heads) # q
self.key_projection = nn.Linear(d_model, d_keys * n_heads) # k
self.value_projection = nn.Linear(d_model, d_values * n_heads) # v
self.out_projection = nn.Linear(d_values * n_heads, d_model) # 输出层
self.n_heads = n_heads
def forward(self, queries, keys, values): # 原始输入作为qkv
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
# 获取qkv矩阵
queries = self.query_projection(queries).view(B, L, H, -1) # queries224,28,256 -> (224,28,256) -> (224,28,4,64)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out = self.inner_attention(
queries,
keys,
values,
)
out = out.view(B, L, -1) # 224,28,256
return self.out_projection(out)
class TwoStageAttentionLayer(nn.Module):
'''
The Two Stage Attention (TSA) Layer
input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
'''
def __init__(self, seg_num, factor, d_model, n_heads, d_ff = None, dropout=0.1):
super(TwoStageAttentionLayer, self).__init__()
d_ff = d_ff or 4*d_model # 如果变量 d_ff 的值为 None 或等价于布尔值 False例如未显式指定则将 d_ff 设置为 4 * d_model。 如果 d_ff 已经有定义(即非 None 且布尔值为 True那么 d_ff 保持原值
self.time_attention = AttentionLayer(d_model, n_heads, dropout = dropout) # 用于时间维度上的注意力计算,捕捉时间序列数据中的时间依赖性
self.dim_sender = AttentionLayer(d_model, n_heads, dropout = dropout) # 在注意力机制中dim_sender 是发送方特征的维度用于生成查询Query或键Key
self.dim_receiver = AttentionLayer(d_model, n_heads, dropout = dropout) # 在注意力机制中dim_receiver 是接收方特征的维度用于生成值Value或接收注意力权重。
self.router = nn.Parameter(torch.randn(seg_num, factor, d_model)) # (28,10,256)在复杂注意力机制中router 帮助模型动态分配注意力资源,提升对不同特征的关注能力。
# nn.Parameter 是一个特殊的张量类型,用于定义可以被优化器(如 SGD、Adam 等)更新的模型参数
self.dropout = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.norm4 = nn.LayerNorm(d_model)
self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff), #FFN模块
nn.GELU(),
nn.Linear(d_ff, d_model))
self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model))
def forward(self, x):
#Cross Time Stage: Directly apply MSA to each dimension
batch = x.shape[0]
time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model') # 224,28,256 , 其中 224 = 32* 7
time_enc = self.time_attention(
time_in, time_in, time_in
)
dim_in = time_in + self.dropout(time_enc) # 残差
dim_in = self.norm1(dim_in) # 层归一化
dim_in = dim_in + self.dropout(self.MLP1(dim_in)) # FNN + 残差
dim_in = self.norm2(dim_in) # 层归一化
#Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b = batch) # (224,28,256) -> (896,7,256) 这样的重排使得每个批次的每个分段segment可以独立处理。通过将批量和分段数量结合在一起后续的操作可以更方便地进行并行计算和消息传递。
batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat = batch) # (28,10,256) => (896,10,256) 28*32 = 896 , 通过重复路由器参数,确保每个批次都使用相同的路由器设置。这种方式可以保持模型的一致性,并减少模型参数的数量,同时允许每个批次的数据使用相同的学习策略。
dim_buffer = self.dim_sender(batch_router, dim_send, dim_send) # 通过路由器将输入数据dim_send进行处理生成一个中间的缓冲区dim_buffer
dim_receive = self.dim_receiver(dim_send, dim_buffer, dim_buffer) # 接收来自 dim_sender 的消息并结合原始输入dim_send和缓冲区dim_buffer来生成接收的输出dim_receive
dim_enc = dim_send + self.dropout(dim_receive) # 使用残差连接dim_send + self.dropout(dim_receive))来结合原始输入和接收的消息。这种方法有助于缓解深度网络中的梯度消失问题,并保持信息流动
dim_enc = self.norm3(dim_enc) # 来标准化输出,进一步稳定训练过程
dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc)) # 通过多层感知机MLP对编码后的信息进行进一步的变换和学习。再次使用残差连接将 MLP 的输出与 dim_enc 结合
dim_enc = self.norm4(dim_enc)
final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b = batch)
return final_out