import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.act = nn.GELU()  # GELU activation
        self.fc2 = nn.Linear(output_dim, input_dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class Attention(nn.Module):
    def __init__(self, dim, heads):
        super(Attention, self).__init__()
        self.heads = heads
        self.dim = dim
        # Scale by the per-head dimension, as in standard scaled dot-product attention
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(0.1)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.1)

    def forward(self, x):
        B, N, C = x.shape
        # (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))


class ViTEncoder(nn.Module):
    def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256):
        super(ViTEncoder, self).__init__()
        self.patch_size = patch_size
        self.dim = dim
        # Patch embedding: non-overlapping patch_size x patch_size convolution
        self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)
        self.attention_layers = nn.ModuleList([
            nn.Sequential(
                Attention(dim, heads),
                MLP(dim, mlp_dim)
            )
            for _ in range(depth)
        ])

    def forward(self, x):
        x = self.patch_embedding(x)        # shape: (batch_size, dim, num_patches_h, num_patches_w)
        x = x.flatten(2).transpose(1, 2)   # shape: (batch_size, num_patches, dim)
        for attention_layer in self.attention_layers:
            x = attention_layer[0](x) + x  # self-attention with residual connection
            x = attention_layer[1](x) + x  # MLP with residual connection
        return x


class ConvDecoder(nn.Module):
    def __init__(self, dim=128, patch_size=8, img_size=96):
        super(ConvDecoder, self).__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.img_size = img_size
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(dim, 128, kernel_size=patch_size, stride=patch_size),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 1, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        # (batch_size, num_patches, dim) -> (batch_size, dim, num_patches_h, num_patches_w)
        # reshape (not view) because the tensor is non-contiguous after transpose
        x = x.transpose(1, 2).reshape(-1, self.dim, self.img_size // self.patch_size, self.img_size // self.patch_size)
        x = self.decoder(x)
        return x


class MAEModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(MAEModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, mask):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded * mask


# The image size passed to the encoder and decoder must match the input resolution.
encoder = ViTEncoder(img_size=256, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256)
decoder = ConvDecoder(dim=128, patch_size=8, img_size=256)
model = MAEModel(encoder, decoder)

x = torch.randn(1, 1, 256, 256)
mask = torch.ones(1, 1, 256, 256)  # binary pixel mask; all ones here, just to check shapes
output = model(x, mask)
print(output.shape)  # torch.Size([1, 1, 256, 256])
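
# The forward pass above expects a binary mask, but the listing never constructs one.
# Below is a minimal training sketch building on the `model` defined above: it assumes
# a random patch-wise pixel mask and a per-pixel MSE loss on the masked regions.
# The helper `make_patch_mask`, the mask ratio, and the optimizer settings are
# illustrative assumptions, not part of the original code.
import torch.nn.functional as F


def make_patch_mask(batch_size, img_size=256, patch_size=8, mask_ratio=0.75, device="cpu"):
    """Hypothetical helper: mark a random subset of non-overlapping
    patch_size x patch_size patches as masked (1 = masked, 0 = visible)."""
    n = img_size // patch_size
    num_patches = n * n
    num_masked = int(mask_ratio * num_patches)
    mask = torch.zeros(batch_size, num_patches, device=device)
    for b in range(batch_size):
        idx = torch.randperm(num_patches, device=device)[:num_masked]
        mask[b, idx] = 1.0
    # Expand each patch flag to its patch_size x patch_size pixel block
    mask = mask.view(batch_size, 1, n, n)
    mask = mask.repeat_interleave(patch_size, dim=2).repeat_interleave(patch_size, dim=3)
    return mask  # (batch_size, 1, img_size, img_size)


# One illustrative optimization step (optimizer choice is an assumption).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
images = torch.randn(4, 1, 256, 256)
mask = make_patch_mask(4, img_size=256, patch_size=8, mask_ratio=0.75)
recon = model(images * (1 - mask), mask)  # encode visible pixels, predict masked ones
loss = F.mse_loss(recon, images * mask)   # reconstruction loss on masked regions only
loss.backward()
optimizer.step()
optimizer.zero_grad()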