Tan_pytorch_segmentation/pytorch_segmentation/MAE反演预测/vit.py

import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.act = nn.GELU()  # use the GELU activation
        # fc2 projects back to input_dim, so the block preserves the token
        # dimension and can be used with a residual connection
        self.fc2 = nn.Linear(output_dim, input_dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class Attention(nn.Module):
    def __init__(self, dim, heads):
        super(Attention, self).__init__()
        self.heads = heads
        self.dim = dim
        # scale by the per-head dimension (head_dim ** -0.5), the standard
        # attention scaling; dim ** -0.5 would shrink the logits too much
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(0.1)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.1)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
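

# A quick shape sanity check for the attention block (an illustrative sketch;
# the batch size and token count below are arbitrary):
#   attn = Attention(dim=128, heads=4)
#   tokens = torch.randn(2, 144, 128)  # (B, N, C), e.g. N = 12 * 12 patches
#   attn(tokens).shape                 # -> torch.Size([2, 144, 128])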


class ViTEncoder(nn.Module):
    def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256):
        super(ViTEncoder, self).__init__()
        self.patch_size = patch_size
        self.dim = dim
        self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)
        # ModuleList of [attention, MLP] pairs; nn.Sequential would be
        # misleading here since the two sub-modules are called separately,
        # each with its own residual connection
        self.attention_layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads),
                MLP(dim, mlp_dim),
            ]) for _ in range(depth)
        ])

    def forward(self, x):
        x = self.patch_embedding(x)       # shape becomes (batch_size, dim, num_patches_h, num_patches_w)
        x = x.flatten(2).transpose(1, 2)  # shape becomes (batch_size, num_patches, dim)
        for attn, mlp in self.attention_layers:
            x = attn(x) + x  # self-attention with a residual connection
            x = mlp(x) + x   # MLP with a residual connection
        return x
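

# Note: unlike a standard ViT, this encoder uses no positional embedding and no
# LayerNorm around the residual branches. A minimal sketch of adding a learnable
# positional embedding (the attribute name `pos_embed` is hypothetical):
#   in __init__:  self.pos_embed = nn.Parameter(
#                     torch.zeros(1, (img_size // patch_size) ** 2, dim))
#   in forward:   x = x + self.pos_embed  # right after flatten/transpose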


class ConvDecoder(nn.Module):
    def __init__(self, dim=128, patch_size=8, img_size=96):
        super(ConvDecoder, self).__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.img_size = img_size
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(dim, 128, kernel_size=patch_size, stride=patch_size),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        # (B, N, dim) -> (B, dim, H/patch, W/patch); use reshape rather than
        # view because the tensor is non-contiguous after transpose
        h = w = self.img_size // self.patch_size
        x = x.transpose(1, 2).reshape(-1, self.dim, h, w)
        return self.decoder(x)


class MAEModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(MAEModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, mask):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded * mask  # zero out everything outside the regions the mask selects
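

# Helper for building the pixel-space mask the model expects. This is a minimal
# sketch, not part of the original file: it draws a random patch-level mask (as
# is typical for MAE-style training) and upsamples it to pixel resolution. The
# name `random_patch_mask` and the mask_ratio default are assumptions.
def random_patch_mask(batch_size, img_size=96, patch_size=8, mask_ratio=0.75):
    n = img_size // patch_size
    # 1.0 where a patch is selected, 0.0 elsewhere
    keep = (torch.rand(batch_size, 1, n, n) > mask_ratio).float()
    # repeat each patch decision across its patch_size x patch_size pixels
    return keep.repeat_interleave(patch_size, dim=2).repeat_interleave(patch_size, dim=3)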


model = MAEModel(ViTEncoder(), ConvDecoder())
x = torch.randn(1, 1, 96, 96)    # must match the default img_size=96; the decoder reshape is hard-coded to it
mask = torch.ones(1, 1, 96, 96)  # all-ones mask: reconstruct the full image
output = model(x, mask)
print(output.shape)  # torch.Size([1, 1, 96, 96])
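
# A hedged sketch of one MAE-style training step on the random example above:
# the squared error is nonzero only inside the mask, since both the prediction
# and the target are zeroed outside it. The optimizer and loss choices are
# illustrative assumptions, not taken from the original file.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
mask = random_patch_mask(batch_size=1)
pred = model(x, mask)                   # decoder output, zeroed outside the mask
loss = ((pred - x * mask) ** 2).mean()  # reconstruction error on the masked regions
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())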