# Tan_pytorch_segmentation/pytorch_segmentation/Plug-and-Play/OutlookAttention.py

import math

import torch
from torch import nn
from torch.nn import functional as F
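
# OutlookAttention, as proposed in VOLO (Yuan et al., 2021, arXiv:2106.13112).
# Instead of computing query-key dot products, each spatial location predicts
# an attention map over its local k x k neighborhood with a plain linear layer;
# value patches are gathered with nn.Unfold and re-assembled with F.fold.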
class OutlookAttention(nn.Module):
    def __init__(self, dim, num_heads=1, kernel_size=3, padding=1, stride=1,
                 qkv_bias=False, attn_drop=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.scale = self.head_dim ** (-0.5)

        self.v_pj = nn.Linear(dim, dim, bias=qkv_bias)
        # Predicts one k*k x k*k attention map per head and spatial location.
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(attn_drop)

        # Extract k x k patches, like a convolution's sliding window.
        # Note: nn.Unfold's signature is (kernel_size, dilation, padding, stride),
        # so padding and stride must be passed by keyword; passing them
        # positionally (as the original did) silently sets dilation=padding.
        self.unfold = nn.Unfold(kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
    def forward(self, x):
        B, H, W, C = x.shape

        # Project to the value features v.
        v = self.v_pj(x).permute(0, 3, 1, 2)  # B, C, H, W
        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        v = self.unfold(v).reshape(B, self.num_heads, self.head_dim,
                                   self.kernel_size * self.kernel_size,
                                   h * w).permute(0, 1, 4, 3, 2)  # B, num_heads, h*w, k*k, head_dim

        # Predict the attention map from the (pooled) input.
        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # B, h, w, C
        attn = self.attn(attn).reshape(B, h * w, self.num_heads,
                                       self.kernel_size * self.kernel_size,
                                       self.kernel_size * self.kernel_size)
        attn = attn.permute(0, 2, 1, 3, 4)  # B, num_heads, h*w, k*k, k*k
        attn = self.scale * attn
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Aggregate the weighted value patches and fold them back to (H, W).
        out = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
            B, C * self.kernel_size * self.kernel_size, h * w)  # B, C*k*k, h*w
        out = F.fold(out, output_size=(H, W), kernel_size=self.kernel_size,
                     padding=self.padding, stride=self.stride)  # B, C, H, W
        out = self.proj(out.permute(0, 2, 3, 1))  # B, H, W, C
        out = self.proj_drop(out)
        return out
# Input: (B, H, W, C) -> Output: (B, H, W, C)
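
# The module is channels-last, while most segmentation backbones produce
# (B, C, H, W) feature maps. The wrapper below is a hypothetical convenience
# (not part of the original file): a minimal sketch that just handles the
# permutes, assuming you want to drop the block into a CNN-style pipeline.
class OutlookAttention2d(nn.Module):
    def __init__(self, dim, **kwargs):
        super().__init__()
        self.attn = OutlookAttention(dim, **kwargs)

    def forward(self, x):                  # x: (B, C, H, W)
        x = x.permute(0, 2, 3, 1)          # to channels-last (B, H, W, C)
        x = self.attn(x)
        return x.permute(0, 3, 1, 2)       # back to channels-first (B, C, H, W)
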
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block = OutlookAttention(dim=512).to(device)
    x = torch.rand(1, 64, 64, 512, device=device)  # (B, H, W, C)
    output = block(x)
    print(x.size(), output.size())
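
    # Hedged extra checks, not in the original file. With stride > 1 the
    # attention is computed on a pooled grid, but F.fold restores the full
    # (H, W) resolution, so the output shape should be unchanged (this relies
    # on the keyword-argument fix to nn.Unfold above).
    strided = OutlookAttention(dim=512, stride=2).to(device)
    print(strided(x).size())  # expected: torch.Size([1, 64, 64, 512])

    # Demo of the hypothetical channels-first wrapper defined above.
    wrapped = OutlookAttention2d(dim=512).to(device)
    feat = torch.rand(1, 512, 64, 64, device=device)
    print(wrapped(feat).size())  # expected: torch.Size([1, 512, 64, 64])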