65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over ``(B, C, H, W)`` feature maps.

    Queries, keys and values are produced by 1x1 convolutions. When
    ``sr_ratio > 1`` the keys/values are computed from a spatially reduced
    copy of the input (depthwise strided convolution followed by LayerNorm
    over channels), so the attention matrix has shape ``(H*W, H*W / sr_ratio**2)``.

    The value map is additionally upsampled back to ``(H, W)`` with a
    depthwise 3x3 convolution + ``PixelShuffle``, normalised over channels
    with ``up_norm`` and added to the attention output before the final
    1x1 projection.

    Args:
        dim: Number of input/output channels. Must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        sr_ratio: Spatial reduction factor applied to keys/values
            (``1`` disables the reduction).

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads`` or
            ``sr_ratio`` is smaller than 1.
    """

    def __init__(self, dim: int, num_heads: int = 8, sr_ratio: int = 1) -> None:
        super().__init__()
        if sr_ratio < 1:
            raise ValueError(f"sr_ratio must be >= 1, got {sr_ratio}")
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.dim = dim

        # Per-pixel linear projections expressed as 1x1 convolutions.
        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=True)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Depthwise strided conv that shrinks H and W by ``sr_ratio``.
            self.sr = nn.Conv2d(
                dim,
                dim,
                kernel_size=sr_ratio + 1,
                stride=sr_ratio,
                padding=sr_ratio // 2,
                groups=dim,
            )
            self.sr_norm = nn.LayerNorm(dim, eps=1e-6)

        # Depthwise conv expands channels by sr_ratio**2; PixelShuffle then
        # trades those channels for an sr_ratio-times larger spatial grid.
        self.up = nn.Sequential(
            nn.Conv2d(
                dim,
                sr_ratio * sr_ratio * dim,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=dim,
            ),
            nn.PixelShuffle(upscale_factor=sr_ratio),
        )
        self.up_norm = nn.LayerNorm(dim, eps=1e-6)

        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention to ``x``.

        Args:
            x: Tensor of shape ``(B, C, H, W)``; ``H`` and ``W`` must be
                divisible by ``sr_ratio``.

        Returns:
            Tensor of shape ``(B, C, H, W)``.

        Raises:
            ValueError: If ``H`` or ``W`` is not divisible by ``sr_ratio``.
        """
        B, C, H, W = x.shape
        if H % self.sr_ratio or W % self.sr_ratio:
            raise ValueError(
                f"spatial size ({H}, {W}) must be divisible by "
                f"sr_ratio ({self.sr_ratio})"
            )
        N = H * W
        head_dim = C // self.num_heads
        kv_h, kv_w = H // self.sr_ratio, W // self.sr_ratio

        # (B, C, H, W) -> (B, heads, N, head_dim)
        q = self.q(x).reshape(B, self.num_heads, head_dim, N).transpose(-2, -1)

        if self.sr_ratio > 1:
            # LayerNorm normalises the last axis, so go channels-last for it
            # and return to (B, C, H', W') for the convolutional kv layer.
            x = self.sr(x).flatten(2).transpose(1, 2)
            x = self.sr_norm(x)
            x = x.transpose(1, 2).reshape(B, C, kv_h, kv_w)
        # With sr_ratio == 1 the input stays (B, C, H, W): ``self.kv`` is a
        # Conv2d and needs a 4-D channels-first tensor.

        # (B, 2C, H', W') -> (2, B, heads, N', head_dim)
        kv = self.kv(x).reshape(B, 2, self.num_heads, head_dim, -1)
        kv = kv.permute(1, 0, 2, 4, 3)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # (B, heads, N, head_dim) -> (B, heads, head_dim, N) -> (B, C, H, W).
        # The token axis must be last before reshaping, otherwise spatial
        # positions and channels get interleaved.
        out = (attn @ v).transpose(-2, -1).reshape(B, C, H, W)

        # Upsample the value map back to the query resolution.
        identity = v.transpose(-2, -1).reshape(B, C, kv_h, kv_w)
        identity = self.up(identity)
        # Normalise over channels (channels-last for LayerNorm), then
        # restore the (B, C, H, W) layout.
        identity = self.up_norm(identity.flatten(2).transpose(1, 2))
        identity = identity.transpose(1, 2).reshape(B, C, H, W)

        return self.proj(out + identity)
|
|
|
|
|
|
# Usage example: build the module, run one forward pass on random data,
# and print the resulting shape.
model = Attention(dim=256, num_heads=16, sr_ratio=2)
# Random input; the layout is assumed to be (B, C, H, W).
input_tensor = torch.randn((1, 256, 64, 64))
output = model(input_tensor)
print(output.size())
|