# Source: Tan_pytorch_segmentation/pytorch_segmentation/Plug-and-Play/(cvpr2023)STVit.py
# Original listing: 169 lines, 5.7 KiB, Python; last commit 2025-05-19 20:48:24 +08:00.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Unfold(nn.Module):
    """Gather the k x k neighbourhood of every pixel with a fixed one-hot convolution.

    For odd ``kernel_size`` this matches
    ``F.unfold(x, kernel_size, padding=kernel_size // 2)``: every input channel is
    convolved independently with ``kernel_size ** 2`` one-hot kernels, one per
    neighbourhood offset (row-major order).

    Args:
        kernel_size: Side length of the square neighbourhood. Padding is
            ``kernel_size // 2``, so only odd sizes preserve the spatial size
            (even sizes fail in the final reshape).

    Shape:
        Input: ``(B, C, H, W)``.
        Output: ``(B, C * kernel_size ** 2, H * W)``.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # Row i of the identity becomes a k x k kernel that is 1 only at offset i.
        weights = torch.eye(kernel_size ** 2)
        weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        # Frozen parameter: follows the module across devices, never receives gradients.
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        b, c, h, w = x.shape
        # Fold channels into the batch so each channel is unfolded independently.
        x = F.conv2d(x.reshape(b * c, 1, h, w), self.weights, stride=1, padding=self.kernel_size // 2)
        # Previously hard-coded as c * 9, which only worked for kernel_size == 3.
        return x.reshape(b, c * self.kernel_size ** 2, h * w)
class Fold(nn.Module):
    """Scatter-add per-offset neighbourhood values back onto the pixels they target.

    This is the transposed version of the one-hot convolution used by ``Unfold``
    (same ``kernel_size ** 2`` one-hot kernels, applied with ``conv_transpose2d``).
    Channel ``i`` of the input holds, for every pixel, the value sent to its
    neighbour at offset ``i`` (row-major within the k x k window). Each output
    pixel sums everything sent to it.

    Args:
        kernel_size: Side length of the square neighbourhood. Padding is
            ``kernel_size // 2``, so only odd sizes preserve the spatial size.

    Shape:
        Input: ``(N, kernel_size ** 2, H, W)``.
        Output: ``(N, 1, H, W)`` for odd ``kernel_size``.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # Row i of the identity becomes a k x k kernel that is 1 only at offset i.
        weights = torch.eye(kernel_size ** 2)
        weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        # Frozen parameter: follows the module across devices, never receives gradients.
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        # conv_transpose2d weight layout is (in=k*k, out=1, k, k): the k*k offset
        # channels are summed into a single output channel.
        return F.conv_transpose2d(x, self.weights, stride=1, padding=self.kernel_size // 2)
class Attention(nn.Module):
    """Multi-head self-attention across all spatial positions of a feature map.

    One 1x1 convolution produces queries, keys and values. Every one of the
    ``N = H * W`` positions attends to every other position, and a second 1x1
    convolution projects the result.

    Args:
        dim: Channel count of input and output. Presumably divisible by
            ``num_heads``; otherwise the reshape in ``forward`` fails.
        window_size: Stored on the module; not used by ``forward``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the qkv 1x1 convolution has a bias.
        qk_scale: Logit scale. Any falsy value falls back to ``head_dim ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.

    Shape:
        Input and output: ``(B, dim, H, W)``.
    """

    def __init__(self, dim, window_size=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        per_head = dim // num_heads
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = qk_scale if qk_scale else per_head ** -0.5
        self.qkv = nn.Conv2d(dim, 3 * dim, kernel_size=1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = height * width
        per_head = channels // self.num_heads
        # Each head owns a contiguous slab of 3 * per_head channels, split into q, k, v.
        query, key, value = (
            self.qkv(x)
            .reshape(batch, self.num_heads, 3 * per_head, tokens)
            .chunk(3, dim=2)
        )  # each: (batch, heads, per_head, tokens)
        # logits[..., m, n] = key_m . query_n; normalised over the keys (dim -2).
        logits = torch.matmul(key.transpose(-1, -2), query) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-2))  # (batch, heads, tokens, tokens)
        # Output token n = sum over m of value_m * weights[m, n].
        mixed = torch.matmul(value, weights).reshape(batch, channels, height, width)
        return self.proj_drop(self.proj(mixed))
class StokenAttention(nn.Module):
    """Super-token attention block (STViT, CVPR 2023).

    Super-token path (``stoken_size`` larger than 1x1):

    - The input is split into ``h x w`` patches. Each patch starts with one
      super token, the patch mean.
    - Every pixel is softly assigned to the 3x3 neighbourhood of super tokens
      around its own patch.
    - The super tokens are refined with ``Attention`` and scattered back to the
      pixels through the same soft assignment.

    Direct path (``stoken_size`` of 1x1): the module is plain ``Attention``
    over all pixels.

    Args:
        dim: Number of channels ``C`` of input and output.
        stoken_size: Patch size ``(h, w)`` of one super token (indexable pair
            of ints). A single int ``s`` is accepted as shorthand for ``(s, s)``.
        n_iter: Soft-clustering iterations of the super-token path; must be
            at least 1 when that path is used.
        num_heads, qkv_bias, qk_scale, attn_drop, proj_drop: Forwarded to the
            ``Attention`` refinement block.

    Shape:
        Input and output: ``(B, C, H, W)``. ``H`` and ``W`` need not be
        multiples of the patch size: the input is zero-padded on the
        right/bottom and the output is cropped back.
    """

    def __init__(self, dim, stoken_size, n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        if isinstance(stoken_size, int):
            # Previously an int crashed on stoken_size[0]; treat it as a square patch.
            stoken_size = (stoken_size, stoken_size)
        self.n_iter = n_iter
        self.stoken_size = stoken_size
        self.scale = dim ** - 0.5
        self.unfold = Unfold(3)
        self.fold = Fold(3)
        self.stoken_refine = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                       attn_drop=attn_drop, proj_drop=proj_drop)

    def stoken_forward(self, x):
        """Cluster pixels into super tokens, refine them, and map them back.

        Args:
            x: Feature map of shape ``(B, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``n_iter`` < 1. No affinity matrix would ever be
                built, which previously surfaced as an UnboundLocalError.
        """
        if self.n_iter < 1:
            raise ValueError(f"n_iter must be >= 1 for the super-token path, got {self.n_iter}")
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size

        # Zero-pad right/bottom so H and W become multiples of the patch size.
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
        _, _, H, W = x.shape
        hh, ww = H // h, W // w  # super-token grid size

        # Initial super tokens: mean of each h x w patch (padded zeros included).
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # (B, C, hh, ww)
        # Group pixels by patch.
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)

        # Soft clustering. The assignment is computed without gradients.
        with torch.no_grad():
            for idx in range(self.n_iter):
                # The 9 super tokens in the 3x3 neighbourhood of each patch.
                stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
                stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
                # Pixel-to-neighbouring-super-token affinity, normalised over the 9 candidates.
                affinity_matrix = pixel_features @ stoken_features * self.scale  # (B, hh*ww, h*w, 9)
                affinity_matrix = affinity_matrix.softmax(-1)  # (B, hh*ww, h*w, 9)
                # Total affinity each super token receives: sum over pixels, then scatter
                # each neighbour offset onto the super token it refers to.
                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                affinity_matrix_sum = self.fold(affinity_matrix_sum)  # (B, 1, hh, ww)
                if idx < self.n_iter - 1:
                    # Intermediate update: super token = affinity-weighted mean of its pixels.
                    stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
                    stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(
                        B, C, hh, ww)
                    stoken_features = stoken_features / (affinity_matrix_sum + 1e-12)  # (B, C, hh, ww)

        # Final super tokens are built outside no_grad, so gradients reach
        # pixel_features. The affinity itself is treated as a constant.
        stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
        stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
        stoken_features = stoken_features / (affinity_matrix_sum.detach() + 1e-12)  # (B, C, hh, ww)

        stoken_features = self.stoken_refine(stoken_features)

        # Map back: each pixel becomes the affinity-weighted combination of its
        # 9 neighbouring refined super tokens.
        stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)  # (B, hh*ww, C, 9)
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # (B, hh*ww, C, h*w)
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

        # Drop the padding added above.
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]
        return pixel_features

    def direct_forward(self, x):
        """Apply the refinement attention directly to every pixel of ``x`` (B, C, H, W)."""
        return self.stoken_refine(x)

    def forward(self, x):
        """Use the super-token path unless the patch size is 1x1."""
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        return self.direct_forward(x)
# Demo: input (N, C, H, W) -> output (N, C, H, W).
if __name__ == '__main__':
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(3, 64, 64, 64, device=device)
    se = StokenAttention(64, stoken_size=[8, 8]).to(device)
    output = se(x)
    print(output.shape)