import torch
import torch.nn as nn
import torch.nn.functional as F


class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention."""
    def __init__(self, in_channels, reduced_dim):
        super(SEBlock, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.se(x)


class Conv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        super(Conv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
        )


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 norm_layer=nn.BatchNorm2d, bias=False):
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels),
            nn.ReLU()
        )


class SeparableBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableBNReLU, self).__init__(
            # Depthwise convolution: only mixes spatial information
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                      dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            # Normalize the input channels
            norm_layer(in_channels),
            # Pointwise convolution: projects to the output channel count
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.ReLU6()
        )


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # If the input and output shapes differ, downsample the identity branch
        self.downsample = downsample
        if in_channels != out_channels or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)
        self.drop = nn.Dropout(drop, inplace=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttentionBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(2, 0, 1)  # (B, C, H, W) -> (HW, B, C)
        attn_output, _ = self.attention(x, x, x)
        attn_output = self.norm(attn_output)
        attn_output = self.dropout(attn_output)
        # reshape handles the non-contiguous tensor produced by permute
        attn_output = attn_output.permute(1, 2, 0).reshape(B, C, H, W)
        return attn_output


class SpatialAttentionBlock(nn.Module):
    def __init__(self):
        super(SpatialAttentionBlock, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)

    def forward(self, x):                               # (B, C, H, W)
        avg_out = torch.mean(x, dim=1, keepdim=True)    # (B, 1, H, W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B, 1, H, W)
        out = torch.cat([avg_out, max_out], dim=1)      # (B, 2, H, W)
        out = torch.sigmoid(self.conv(out))             # (B, 1, H, W)
        return x * out                                  # (B, C, H, W)


class DecoderAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(DecoderAttentionBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)
        self.spatial_attention = SpatialAttentionBlock()

    def forward(self, x):
        # Channel attention
        b, c, h, w = x.size()
        avg_pool = F.adaptive_avg_pool2d(x, 1)
        max_pool = F.adaptive_max_pool2d(x, 1)
        avg_out = self.conv1(avg_pool)
        max_out = self.conv1(max_pool)
        out = avg_out + max_out
        out = torch.sigmoid(self.conv2(out))

        # Followed by spatial attention
        out = x * out
        out = self.spatial_attention(out)
        return out


class MaskedAutoencoder(nn.Module):
    def __init__(self):
        super(MaskedAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            Conv(1, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            SEBlock(32, 32),
            ConvBNReLU(32, 64, kernel_size=3, stride=2),
            ResidualBlock(64, 64),
            SeparableBNReLU(64, 128, kernel_size=3, stride=2),
            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),
            SEBlock(128, 128)
        )
        # Note: this MLP is constructed but never applied in forward()
        self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            DecoderAttentionBlock(32),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            DecoderAttentionBlock(16),
            nn.ReLU(),
            # output_padding=1 so the output matches the 256x256 input resolution
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        print("Encoded size:", encoded.size())
        decoded = self.decoder(encoded)
        print("Decoded size:", decoded.size())
        return decoded


model = MaskedAutoencoder()
x = torch.randn(1, 1, 256, 256)
output = model(x)
print(output.shape)
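

# --- Illustrative usage (not part of the original snippet) ---
# A minimal sketch of how this autoencoder could be trained for masked
# reconstruction: random square patches of the input are zeroed out, the
# model reconstructs the full image, and an MSE loss compares the output
# to the unmasked target. The masking scheme, patch size, optimizer and
# learning rate below are assumptions for illustration only.

def random_patch_mask(images, patch_size=32, mask_ratio=0.5):
    """Zero out a random subset of square patches (assumed masking scheme)."""
    B, C, H, W = images.shape
    n_h, n_w = H // patch_size, W // patch_size
    # One Bernoulli keep/drop decision per patch, per sample
    keep = (torch.rand(B, 1, n_h, n_w, device=images.device) > mask_ratio).float()
    keep = F.interpolate(keep, scale_factor=patch_size, mode="nearest")  # (B, 1, H, W)
    return images * keep, keep


optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed hyperparameters
criterion = nn.MSELoss()

images = torch.randn(2, 1, 256, 256)          # stand-in batch
masked_images, _ = random_patch_mask(images)  # corrupt the input
recon = model(masked_images)                  # reconstruct the full image
loss = criterion(recon, images)               # compare against the unmasked target

optimizer.zero_grad()
loss.backward()
optimizer.step()
print("reconstruction loss:", loss.item())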