import torch
import torch.nn as nn


class MLP(nn.Module):
    """Two-layer feed-forward block used inside each transformer layer."""

    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.act = nn.GELU()  # use the GELU activation function
        self.fc2 = nn.Linear(output_dim, input_dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class Attention(nn.Module):
    """Multi-head self-attention over a sequence of patch tokens."""

    def __init__(self, dim, heads):
        super(Attention, self).__init__()
        self.heads = heads
        self.dim = dim
        self.scale = (dim // heads) ** -0.5  # scale scores by 1/sqrt(head_dim)

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(0.1)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.1)

    def forward(self, x):
        B, N, C = x.shape
        # Project to queries, keys and values: (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back into the channel dimension and project
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))


class ViTEncoder(nn.Module):
    """ViT-style encoder: patch embedding followed by `depth` blocks of
    self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256):
        super(ViTEncoder, self).__init__()
        self.patch_size = patch_size
        self.dim = dim
        # Non-overlapping patch embedding for single-channel images
        self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)

        self.attention_layers = nn.ModuleList([
            nn.Sequential(
                Attention(dim, heads),
                MLP(dim, mlp_dim)
            ) for _ in range(depth)
        ])

    def forward(self, x):
        x = self.patch_embedding(x)       # shape becomes (batch_size, dim, num_patches_h, num_patches_w)
        x = x.flatten(2).transpose(1, 2)  # shape becomes (batch_size, num_patches, dim)

        for attention_layer in self.attention_layers:
            x = attention_layer[0](x) + x  # self-attention + residual
            x = attention_layer[1](x) + x  # MLP + residual

        return x


class ConvDecoder(nn.Module):
    """Reshapes the patch tokens back into a feature map and upsamples it to
    the original image resolution with transposed convolutions."""

    def __init__(self, dim=128, patch_size=8, img_size=96):
        super(ConvDecoder, self).__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.img_size = img_size
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(dim, 128, kernel_size=patch_size, stride=patch_size),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 1, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        # (batch_size, num_patches, dim) -> (batch_size, dim, num_patches_h, num_patches_w)
        x = x.transpose(1, 2).view(-1, self.dim, self.img_size // self.patch_size,
                                   self.img_size // self.patch_size)
        x = self.decoder(x)
        return x


# Sanity check: ConvDecoder expects a (batch, num_patches, dim) token sequence, not a
# raw image; with the defaults (img_size=96, patch_size=8) that is 12 * 12 = 144 tokens.
model = ConvDecoder()
x = torch.randn(1, 144, 128)
output = model(x)
print(output.shape)  # torch.Size([1, 1, 96, 96])
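
# End-to-end check, assuming ViTEncoder and ConvDecoder are intended to be chained
# as a patch-based autoencoder (their defaults agree on img_size=96, patch_size=8,
# dim=128): a single-channel 96x96 image should come back at its original resolution.
encoder = ViTEncoder(img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256)
decoder = ConvDecoder(dim=128, patch_size=8, img_size=96)
img = torch.randn(2, 1, 96, 96)   # (batch, channels, height, width)
tokens = encoder(img)             # (2, 144, 128): 12 x 12 patches, embedding dim 128
recon = decoder(tokens)           # (2, 1, 96, 96)
print(tokens.shape, recon.shape)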