import torch
import torch.nn as nn


class Attention(nn.Module):
    """Multi-head self-attention with spatial reduction (sr_ratio) on the keys/values
    and a pixel-shuffle branch that restores the reduced value map as an identity path."""

    def __init__(self, dim, num_heads=8, sr_ratio=1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.dim = dim

        # 1x1 convolutions play the role of the usual linear q/kv projections.
        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=True)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Depth-wise strided convolution that downsamples the feature map
            # before computing keys/values, reducing the attention cost.
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio,
                                padding=sr_ratio // 2, groups=dim)
            self.sr_norm = nn.LayerNorm(dim, eps=1e-6)

        # Depth-wise conv + PixelShuffle that upsamples the value map back to the
        # input resolution for the identity branch (a no-op upsampling when sr_ratio=1).
        self.up = nn.Sequential(
            nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1,
                      padding=1, groups=dim),
            nn.PixelShuffle(upscale_factor=sr_ratio)
        )
        self.up_norm = nn.LayerNorm(dim, eps=1e-6)

        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W

        # Queries: (B, num_heads, N, head_dim)
        q = self.q(x).reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2)

        if self.sr_ratio > 1:
            # Downsample, normalize over channels, and restore the NCHW layout.
            x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
            x = self.sr_norm(x)
            x = x.permute(0, 2, 1).reshape(B, C, H // self.sr_ratio, W // self.sr_ratio)

        # Keys/values from the (possibly reduced) map: each (B, num_heads, N', head_dim)
        kv = self.kv(x).reshape(B, 2, self.num_heads, C // self.num_heads, -1).permute(1, 0, 2, 4, 3)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Back to channel-first layout: (B, num_heads, N, head_dim) -> (B, C, H, W)
        x = (attn @ v).transpose(-2, -1).reshape(B, C, H, W)

        # Identity branch: reshape the values into a feature map, upsample to the
        # input resolution with PixelShuffle, and normalize over channels.
        identity = v.transpose(-1, -2).reshape(B, C, H // self.sr_ratio, W // self.sr_ratio)
        identity = self.up(identity)  # (B, C, H, W)
        identity = self.up_norm(identity.flatten(2).transpose(1, 2))
        identity = identity.transpose(1, 2).reshape(B, C, H, W)

        x = self.proj(x + identity)
        return x


# Usage example
model = Attention(dim=256, num_heads=16, sr_ratio=2)
input_tensor = torch.randn(1, 256, 64, 64)  # input of shape (B, C, H, W)
output = model(input_tensor)
print(output.shape)  # torch.Size([1, 256, 64, 64])
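
A minimal sanity-check sketch: the module should preserve the input shape for any sr_ratio that evenly divides H and W; the dimensions below are arbitrary examples, not values from the original snippet.

# Sanity check: output shape should match the input for several reduction ratios.
for sr in (1, 2, 4):
    attn = Attention(dim=64, num_heads=8, sr_ratio=sr)
    feat = torch.randn(2, 64, 32, 32)  # H and W must be divisible by sr
    out = attn(feat)
    assert out.shape == feat.shape, (sr, out.shape)
    print(f"sr_ratio={sr}: {tuple(out.shape)}")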