import math

import torch
from torch import nn
from torch.nn import functional as F


class OutlookAttention(nn.Module):
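    """Outlook attention from VOLO (Yuan et al., 2021): every spatial
    location predicts the attention weights over its own k x k neighborhood
    directly from the (pooled) input via a linear layer, then uses them to
    aggregate the unfolded value features. Input and output are
    channels-last: (B, H, W, C).
    """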

    def __init__(self, dim, num_heads=1, kernel_size=3, padding=1, stride=1,
                 qkv_bias=False, attn_drop=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.scale = self.head_dim ** -0.5

        # value projection; the attention weights are predicted directly from
        # the (pooled) input, one k*k x k*k matrix per head and location
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(attn_drop)

        # unfold extracts the k x k patches (im2col); keyword arguments are
        # required here because nn.Unfold's second positional parameter is
        # dilation, not padding
        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

    def forward(self, x):
        B, H, W, C = x.shape

        # project the input to value features v
        v = self.v_proj(x).permute(0, 3, 1, 2)  # B,C,H,W
        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        v = self.unfold(v).reshape(
            B, self.num_heads, self.head_dim,
            self.kernel_size * self.kernel_size, h * w
        ).permute(0, 1, 4, 3, 2)  # B,num_heads,h*w,k*k,head_dim

        # predict the attention map from the average-pooled input
        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # B,h,w,C
        attn = self.attn(attn).reshape(
            B, h * w, self.num_heads,
            self.kernel_size * self.kernel_size, self.kernel_size * self.kernel_size
        ).permute(0, 2, 1, 3, 4)  # B,num_heads,h*w,k*k,k*k
        attn = self.scale * attn
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # aggregate the value features with the attention weights
        out = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
            B, C * self.kernel_size * self.kernel_size, h * w
        )  # B,C*k*k,h*w
        # fold sums the overlapping weighted patches back into a feature map
        out = F.fold(out, output_size=(H, W), kernel_size=self.kernel_size,
                     padding=self.padding, stride=self.stride)  # B,C,H,W
        out = self.proj(out.permute(0, 2, 3, 1))  # B,H,W,C
        out = self.proj_drop(out)

        return out


# input: B, H, W, C -> output: B, H, W, C
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = OutlookAttention(dim=512).to(device)
    x = torch.rand(1, 64, 64, 512, device=device)
    out = block(x)
    print(x.size(), out.size())
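    # Extra sanity check (not in the original): exercise the strided path,
    # where attention is predicted on the pooled grid; assumes H and W are
    # divisible by the stride.
    block_s2 = OutlookAttention(dim=512, stride=2).to(device)
    out_s2 = block_s2(x)
    print(out_s2.size())  # expected: torch.Size([1, 64, 64, 512])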