# Tan_pytorch_segmentation/pytorch_segmentation/Plug-and-Play/OutlookAttention.py

import math

import torch
from torch import nn
from torch.nn import functional as F
class OutlookAttention(nn.Module):
    """Outlook attention from VOLO: each spatial location predicts the attention
    weights over its local k x k neighborhood directly from the (pooled) center
    feature, then aggregates the unfolded value features.

    Input and output are channels-last: (B, H, W, C).
    """

    def __init__(self, dim, num_heads=1, kernel_size=3, padding=1, stride=1, qkv_bias=False,
                 attn_drop=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.scale = self.head_dim ** (-0.5)

        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        # Each location predicts a (k*k) x (k*k) attention matrix per head.
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(attn_drop)

        # nn.Unfold's positional signature is (kernel_size, dilation, padding, stride),
        # so padding and stride must be passed by keyword; the original positional call
        # only worked by coincidence for padding=1, stride=1.
        self.unfold = nn.Unfold(kernel_size, padding=padding, stride=stride)  # sliding-window patch extraction
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
    def forward(self, x):
        B, H, W, C = x.shape

        # Project the input to the value features v.
        v = self.v_proj(x).permute(0, 3, 1, 2)  # B,C,H,W

        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        # Gather the k x k neighborhood of every output location.
        v = self.unfold(v).reshape(B, self.num_heads, self.head_dim,
                                   self.kernel_size * self.kernel_size,
                                   h * w).permute(0, 1, 4, 3, 2)  # B,num_head,h*w,kxk,head_dim

        # Generate the attention map from the average-pooled center features.
        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # B,h,w,C
        attn = self.attn(attn).reshape(B, h * w, self.num_heads,
                                       self.kernel_size * self.kernel_size,
                                       self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)  # B,num_head,h*w,kxk,kxk
        attn = self.scale * attn
        attn = attn.softmax(-1)
        attn = self.attn_drop(attn)

        # Weight the neighborhood features and fold the patches back to (H, W).
        out = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size,
                                                        h * w)  # B,Cxkxk,h*w
        out = F.fold(out, output_size=(H, W), kernel_size=self.kernel_size,
                     padding=self.padding, stride=self.stride)  # B,C,H,W
        out = self.proj(out.permute(0, 2, 3, 1))  # B,H,W,C
        out = self.proj_drop(out)
        return out
# Input: (B, H, W, C) -> Output: (B, H, W, C)
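
# Most PyTorch vision pipelines carry channels-first (B, C, H, W) tensors, while
# this module is channels-last. A minimal wrapper sketch that handles the
# permutes; the name OutlookAttentionNCHW is hypothetical, not part of the
# original file:
class OutlookAttentionNCHW(nn.Module):
    def __init__(self, dim, **kwargs):
        super().__init__()
        self.attn = OutlookAttention(dim, **kwargs)

    def forward(self, x):                  # x: B,C,H,W
        x = x.permute(0, 2, 3, 1)          # to channels-last B,H,W,C
        x = self.attn(x)
        return x.permute(0, 3, 1, 2)       # back to B,C,H,W
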
if __name__ == '__main__':
    block = OutlookAttention(dim=512).cuda()
    input = torch.rand(1, 64, 64, 512).cuda()
    output = block(input)
    print(input.size(), output.size())
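
    # Extra sanity check (not in the original file): with nn.Unfold's padding and
    # stride passed by keyword above, a strided configuration also folds back to
    # the full input resolution.
    block_s2 = OutlookAttention(dim=512, stride=2).cuda()
    print(block_s2(input).size())  # expected: torch.Size([1, 64, 64, 512])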