import math

import torch
from torch import nn
from torch.nn import functional as F


class OutlookAttention(nn.Module):
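    """Outlook attention, the core operator of the Outlooker block in VOLO
    (Yuan et al., 2021, "VOLO: Vision Outlooker for Visual Recognition").

    Instead of query-key dot products, the k^2 x k^2 attention weights for
    each spatial location are generated directly from the (pooled) centre
    feature by a linear layer, then applied to the unfolded k x k value
    neighbourhood.

    Expects and returns feature maps of shape (B, H, W, C).
    """
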
    def __init__(self, dim, num_heads=1, kernel_size=3, padding=1, stride=1, qkv_bias=False,
                 attn_drop=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.scale = self.head_dim ** (-0.5)

        self.v_pj = nn.Linear(dim, dim, bias=qkv_bias)
        # One k^2 x k^2 attention matrix per head and per spatial location
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(attn_drop)

        # "Manual convolution": extract k x k patches, like the im2col step of
        # a convolution. Note nn.Unfold's positional order is
        # (kernel_size, dilation, padding, stride), so padding and stride must
        # be passed by keyword to land on the right parameters.
        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

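    # Shape intuition for the unfold/fold pair used in forward(), with the
    # defaults kernel_size=3, padding=1, stride=1 and input (B, C, H, W):
    #   self.unfold: (B, C, H, W) -> (B, C * 9, H * W)  one 3x3 patch per pixel
    #   F.fold:      (B, C * 9, H * W) -> (B, C, H, W)  overlaps summed back
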
    def forward(self, x):
        B, H, W, C = x.shape

        # Project to the value features v
        v = self.v_pj(x).permute(0, 3, 1, 2)  # B,C,H,W
        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        # Unfold k x k neighbourhoods and split the channel dim across heads
        v = self.unfold(v).reshape(B, self.num_heads, self.head_dim,
                                   self.kernel_size * self.kernel_size,
                                   h * w).permute(0, 1, 4, 3, 2)  # B,num_heads,h*w,kxk,head_dim

        # Generate the attention map directly from the pooled centre features
        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # B,h,w,C
        attn = self.attn(attn).reshape(B, h * w, self.num_heads,
                                       self.kernel_size * self.kernel_size,
                                       self.kernel_size * self.kernel_size
                                       ).permute(0, 2, 1, 3, 4)  # B,num_heads,h*w,kxk,kxk
        attn = self.scale * attn
        attn = attn.softmax(-1)
        attn = self.attn_drop(attn)

        # Weight the values, then fold the overlapping patches back into a map
        out = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size,
                                                        h * w)  # B,C*k*k,h*w
        out = F.fold(out, output_size=(H, W), kernel_size=self.kernel_size,
                     padding=self.padding, stride=self.stride)  # B,C,H,W
        out = self.proj(out.permute(0, 2, 3, 1))  # B,H,W,C
        out = self.proj_drop(out)

        return out
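

# A minimal sketch of how OutlookAttention is typically wrapped into a
# pre-norm transformer-style block (residual attention + residual MLP).
# OutlookerBlock and the mlp_ratio value are illustrative assumptions here,
# loosely following the VOLO paper; they are not part of the original snippet.
class OutlookerBlock(nn.Module):
    def __init__(self, dim, num_heads=1, mlp_ratio=3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = OutlookAttention(dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):  # x: (B, H, W, C)
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x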


# Quick smoke test. Input: (B, H, W, C) -> Output: (B, H, W, C)
if __name__ == '__main__':
    # Fall back to CPU when CUDA is unavailable so the test still runs
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = OutlookAttention(dim=512).to(device)
    x = torch.rand(1, 64, 64, 512, device=device)
    output = block(x)
    print(x.size(), output.size())
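    # Expected: torch.Size([1, 64, 64, 512]) twice, since OutlookAttention
    # preserves the (B, H, W, C) shape of its input.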