169 lines
5.7 KiB
Python
169 lines
5.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class Unfold(nn.Module):
    """Gather each pixel's ``kernel_size x kernel_size`` neighbourhood into channels.

    Every channel is convolved independently with one-hot kernels (one per
    neighbourhood offset), with zero padding of ``kernel_size // 2``. For an odd
    ``kernel_size`` an input of shape ``(B, C, H, W)`` becomes
    ``(B, C * kernel_size**2, H * W)``. The channel index varies slowest and the
    row-major neighbourhood offset fastest, i.e. the same layout as
    ``F.unfold(x, kernel_size, padding=kernel_size // 2)``.

    Args:
        kernel_size: neighbourhood side length; assumed odd so the spatial size
            is preserved.
    """

    def __init__(self, kernel_size=3):
        super().__init__()

        self.kernel_size = kernel_size

        # Row k of the identity, reshaped, is a one-hot kernel selecting offset
        # (k // kernel_size, k % kernel_size).
        weights = torch.eye(kernel_size ** 2)
        weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        # Fixed, non-trainable weights; kept as a Parameter so state_dict keys are unchanged.
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        """Return the neighbourhoods of ``x`` (B, C, H, W) as (B, C * kernel_size**2, H * W)."""
        b, c, h, w = x.shape
        # Fold channels into the batch so every channel is unfolded independently.
        x = F.conv2d(x.reshape(b * c, 1, h, w), self.weights, stride=1, padding=self.kernel_size // 2)
        # Use kernel_size**2 (not a literal 9) so kernel sizes other than 3 work.
        return x.reshape(b, c * self.kernel_size ** 2, h * w)
|
|
|
|
|
|
class Fold(nn.Module):
    """Scatter-add ``kernel_size**2`` neighbourhood channels back onto one map.

    This is the transpose of the one-hot neighbourhood gather. An input of
    shape ``(N, kernel_size**2, H, W)`` becomes ``(N, 1, H, W)`` for odd
    ``kernel_size``. Each output pixel accumulates the values that neighbouring
    input positions address to it.

    Args:
        kernel_size: neighbourhood side length (assumed odd to preserve H, W).
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size

        taps = kernel_size ** 2
        # One-hot kernels: channel k targets offset (k // kernel_size, k % kernel_size).
        identity = torch.eye(taps).reshape(taps, 1, kernel_size, kernel_size)
        # Fixed, non-trainable weights stored as a Parameter named "weights".
        self.weights = nn.Parameter(identity, requires_grad=False)

    def forward(self, x):
        """Scatter ``x`` of shape (N, kernel_size**2, H, W) into (N, 1, H, W)."""
        # The unpack only enforces a 4-D input; the sizes themselves are not needed.
        _n, _taps, _h, _w = x.shape
        half = self.kernel_size // 2
        return F.conv_transpose2d(x, self.weights, stride=1, padding=half)
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Queries, keys and values come from one 1x1 convolution. Attention is
    computed across all ``H * W`` positions, and the result is projected with
    another 1x1 convolution. Input and output are ``(B, C, H, W)``.

    Args:
        dim: channel count ``C``; must be divisible by ``num_heads``.
        window_size: stored on the module; not used by ``forward``.
        num_heads: number of attention heads.
        qkv_bias: whether the qkv 1x1 convolution has a bias.
        qk_scale: explicit dot-product scale; ``None`` means ``head_dim ** -0.5``.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, window_size=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()

        # Fail early with a clear message instead of an opaque reshape error in forward().
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.window_size = window_size  # kept for API compatibility; unused in forward()

        # `is None` rather than `or`, so an explicit qk_scale of 0.0 is honoured.
        self.scale = head_dim ** -0.5 if qk_scale is None else qk_scale

        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over all H*W positions of ``x`` (B, C, H, W); returns (B, C, H, W)."""
        B, C, H, W = x.shape
        N = H * W

        # (B, 3C, H, W) -> (B, num_heads, 3 * head_dim, N), split into q, k, v,
        # each of shape (B, num_heads, head_dim, N).
        q, k, v = self.qkv(x).reshape(B, self.num_heads, C // self.num_heads * 3, N).chunk(3, dim=2)

        # attn[..., i, j] is the scaled similarity between key i and query j.
        attn = (k.transpose(-1, -2) @ q) * self.scale  # (B, num_heads, N, N)
        # Normalise over keys (dim -2), i.e. separately for every query column.
        attn = attn.softmax(dim=-2)
        attn = self.attn_drop(attn)

        # Output column j is the attention-weighted sum of value columns for query j.
        x = (v @ attn).reshape(B, C, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
|
|
|
|
|
class StokenAttention(nn.Module):
    """Super-token attention: sample super tokens, refine them, map back to pixels.

    Pixels are grouped into a grid of super tokens, one per
    ``stoken_size[0] x stoken_size[1]`` patch (after zero-padding H and W up
    to multiples of the patch size). The block then:

    1. estimates soft associations between every pixel and the 3x3
       neighbourhood of super tokens around its patch (``n_iter`` rounds);
    2. refines the super tokens with multi-head self-attention;
    3. maps the refined tokens back to pixels with the same associations.

    Input and output are ``(B, C, H, W)``, with ``C == dim``.

    Args:
        dim: channel count ``C``.
        stoken_size: super-token patch size ``(h, w)``. A bare int ``s`` means
            ``(s, s)``. If neither side exceeds 1, the block reduces to plain
            attention over all pixels.
        n_iter: number of association rounds; must be >= 1 on the super-token path.
        num_heads, qkv_bias, qk_scale, attn_drop, proj_drop: forwarded to ``Attention``.
    """

    def __init__(self, dim, stoken_size, n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):
        super().__init__()

        self.n_iter = n_iter
        # Accept a bare int as a square patch size; sequences are stored unchanged.
        if isinstance(stoken_size, int):
            stoken_size = (stoken_size, stoken_size)
        self.stoken_size = stoken_size

        # Scale for the pixel / super-token dot products in the association step.
        self.scale = dim ** - 0.5

        # 3x3 neighbourhood gather / scatter; the literal 9 below is 3 * 3.
        self.unfold = Unfold(3)
        self.fold = Fold(3)

        self.stoken_refine = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                       attn_drop=attn_drop, proj_drop=proj_drop)

    def _aggregate_stokens(self, pixel_features, affinity_matrix, affinity_matrix_sum, hh, ww):
        """Recompute super tokens as affinity-weighted means of their pixels.

        Args:
            pixel_features: (B, hh*ww, h*w, C) pixels grouped by patch.
            affinity_matrix: (B, hh*ww, h*w, 9) pixel-to-neighbour-token weights.
            affinity_matrix_sum: (B, 1, hh, ww) total weight each super token received.
            hh, ww: super-token grid size.

        Returns:
            (B, C, hh, ww) super-token features.
        """
        B, _, _, C = pixel_features.shape
        stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
        # Scatter every patch's contributions onto the 9 neighbouring grid cells.
        stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(
            B, C, hh, ww)
        # affinity_matrix_sum is computed under no_grad; detach() just makes it
        # explicit that no gradient flows through the normaliser.
        return stoken_features / (affinity_matrix_sum.detach() + 1e-12)  # (B, C, hh, ww)

    def stoken_forward(self, x):
        """Super-token path of ``forward``.

        x: (B, C, H, W). Returns a tensor of the same shape.

        Raises:
            ValueError: if ``self.n_iter < 1`` (no association would be computed).
        """
        if self.n_iter < 1:
            raise ValueError(f"n_iter must be >= 1 for super-token attention, got {self.n_iter}")

        B, C, H0, W0 = x.shape
        h, w = self.stoken_size

        # Zero-pad right/bottom so H and W become multiples of the patch size.
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        _, _, H, W = x.shape
        hh, ww = H // h, W // w  # super-token grid size

        # Initial super tokens: one averaged feature per h x w patch.
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # (B, C, hh, ww)

        # Group pixels by patch.
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)

        # Associations are estimated without gradient tracking.
        with torch.no_grad():
            for idx in range(self.n_iter):
                # Each patch sees its 3x3 neighbourhood of super tokens (zeros beyond the grid).
                stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
                stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
                affinity_matrix = pixel_features @ stoken_features * self.scale  # (B, hh*ww, h*w, 9)
                affinity_matrix = affinity_matrix.softmax(-1)  # (B, hh*ww, h*w, 9)

                # Total weight received by each super token, scattered back onto the grid.
                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                affinity_matrix_sum = self.fold(affinity_matrix_sum)  # (B, 1, hh, ww)

                if idx < self.n_iter - 1:
                    stoken_features = self._aggregate_stokens(
                        pixel_features, affinity_matrix, affinity_matrix_sum, hh, ww)

        # Final aggregation runs outside no_grad so gradients reach x through pixel_features.
        stoken_features = self._aggregate_stokens(pixel_features, affinity_matrix, affinity_matrix_sum, hh, ww)

        stoken_features = self.stoken_refine(stoken_features)

        # Map back: each pixel mixes its 9 neighbouring refined super tokens by affinity.
        stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)  # (B, hh*ww, C, 9)
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # (B, hh*ww, C, h*w)
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

        # Remove the padding added above.
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]

        return pixel_features

    def direct_forward(self, x):
        """Plain self-attention over all pixels (used when the patch size is 1x1)."""
        return self.stoken_refine(x)

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, C, H, W); returns the same shape."""
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        return self.direct_forward(x)
|
|
|
|
|
|
# Input: (N, C, H, W); output: (N, C, H, W).
if __name__ == '__main__':
    # Use the GPU when available so the demo also runs on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(3, 64, 64, 64, device=device)
    se = StokenAttention(64, stoken_size=[8, 8]).to(device)
    output = se(x)
    print(output.shape)
|