MAE_ATMO/torch_MAE_1d_20_patch_mask....

214 KiB
Raw Blame History

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pandas as pd
In [4]:
np.random.seed(0)
torch.random.manual_seed(0)
Out[4]:
<torch._C.Generator at 0x7fc99a573830>
In [5]:
# 计算图像数据中的最大像素值
max_pixel_value = 107.49169921875
In [6]:
class GrayScaleDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        data = np.load(file_path)[:,:,0] / max_pixel_value
        return torch.tensor(data, dtype=torch.float32).unsqueeze(0)
In [7]:
class NO2Dataset(Dataset):
    
    def __init__(self, image_dir, mask_dir):
        
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # 仅加载 .npy 文件
        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # 仅加载 .jpg 文件
        
    def __len__(self):
        
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_idx = np.random.choice(self.mask_filenames)
        mask_path = os.path.join(self.mask_dir, mask_idx)

        # 加载图像数据 (.npy 文件)
        image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value  # 形状为 (96, 96, 1)

        # 加载掩码数据 (.jpg 文件)
        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)

        # 将掩码数据中非0值设为10值保持不变
        mask = np.where(mask != 0, 1.0, 0.0)

        # 保持掩码数据形状为 (96, 96, 1)
        mask = mask[:, :, np.newaxis]  # 将形状调整为 (96, 96, 1)

        # 应用掩码
        masked_image = image.copy()
        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()  # 遮盖NO2数据

        # cGAN的输入和目标
        X = masked_image[:, :, :1]  # 形状为 (96, 96, 8)
        y = image[:, :, 0:1]  # 目标输出为NO2数据形状为 (96, 96, 1)

        # 转换形状为 (channels, height, width)
        X = np.transpose(X, (2, 0, 1))  # 转换为 (1, 96, 96)
        y = np.transpose(y, (2, 0, 1))  # 转换为 (1, 96, 96)
        mask = np.transpose(mask, (2, 0, 1))  # 转换为 (1, 96, 96)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)

# 实例化数据集和数据加载器
image_dir = './out_mat/96/train/'
mask_dir = './out_mat/96/mask/20/'

print(f"checkpoint before Generator is OK")
checkpoint before Generator is OK
In [8]:
class PatchMasking:
    def __init__(self, patch_size, mask_ratio):
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

    def __call__(self, x):
        batch_size, C, H, W = x.shape
        num_patches = (H // self.patch_size) * (W // self.patch_size)
        num_masked = int(num_patches * self.mask_ratio)
        
        # 为每个样本生成独立的mask
        masks = []
        for _ in range(batch_size):
            mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)
            mask[:num_masked] = 1
            mask = mask[torch.randperm(num_patches)]
            mask = mask.view(H // self.patch_size, W // self.patch_size)
            mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)
            masks.append(mask)
        
        # 将所有mask堆叠成一个批量张量
        masks = torch.stack(masks, dim=0)
        masks = torch.unsqueeze(masks, dim=1)
        
        # 应用mask到输入x上
        masked_x = x * (1- masks.float())
        return masked_x, masks
In [9]:
train_dir = './out_mat/96/train/'
train_dataset = GrayScaleDataset(train_dir)
val_dir = './out_mat/96/valid/'
val_dataset = GrayScaleDataset(val_dir)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

test_set = NO2Dataset('./out_mat/96/test/', mask_dir)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)
In [10]:
# 可视化特定特征的函数
def visualize_feature(input_feature,masked_feature, output_feature, title):
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')
    plt.title(title + " Input")
    plt.subplot(1, 3, 2)
    plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')
    plt.title(title + " Masked")
    plt.subplot(1, 3, 3)
    plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')
    plt.title(title + " Recovery")
    plt.show()
In [11]:
class Conv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        super(Conv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
        )
In [12]:
class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,
                 bias=False):
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels),
            nn.ReLU()
        )
In [13]:
class SeparableBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
        super(SeparableBNReLU, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),
            # 分离卷积,仅调整空间信息
            norm_layer(in_channels),  # 对输入通道进行归一化
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),  # 这里进行升维操作
            nn.ReLU6()
        )
In [14]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 如果输入和输出通道不一致,进行降采样操作
        self.downsample = downsample
        if in_channels != out_channels or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out
In [15]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)

        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)
        self.drop = nn.Dropout(drop, inplace=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
In [16]:
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttentionBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(2, 0, 1)  # (B, C, H, W) -> (HW, B, C)

        # Apply multihead attention
        attn_output, _ = self.attention(x, x, x)

        # Apply normalization and dropout
        attn_output = self.norm(attn_output)
        attn_output = self.dropout(attn_output)

        # Reshape back to (B, C, H, W)
        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)

        return attn_output
In [17]:
class SpatialAttentionBlock(nn.Module):
    def __init__(self):
        super(SpatialAttentionBlock, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)

    def forward(self, x): #(B, 64, H, W)
        avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)
        out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)
        out = torch.sigmoid(self.conv(out))#(B, 1, H, W)
        return x * out #(B, C, H, W)
In [18]:
class DecoderAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(DecoderAttentionBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)
        self.spatial_attention = SpatialAttentionBlock()

    def forward(self, x):
        # 通道注意力
        b, c, h, w = x.size()
        avg_pool = F.adaptive_avg_pool2d(x, 1)
        max_pool = F.adaptive_max_pool2d(x, 1)

        avg_out = self.conv1(avg_pool)
        max_out = self.conv1(max_pool)

        out = avg_out + max_out
        out = torch.sigmoid(self.conv2(out))

        # 添加空间注意力
        out = x * out
        out = self.spatial_attention(out)
        return out
In [19]:
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduced_dim):
        super(SEBlock, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # 全局平均池化
            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),
            nn.Sigmoid()  # 使用Sigmoid是因为我们要对通道进行权重归一化
        )

    def forward(self, x):
        return x * self.se(x)
In [20]:
def masked_mse_loss(preds, target, mask):
    loss = (preds - target) ** 2
    loss = loss.mean(dim=-1)  # 对每个像素点求平均
    loss = (loss * mask).sum() / mask.sum()  # 只计算被mask的像素点的损失
    return loss
In [25]:
# 定义Masked Autoencoder模型
class MaskedAutoencoder(nn.Module):
    def __init__(self):
        super(MaskedAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            Conv(1, 32, kernel_size=3, stride=2),
            
            nn.ReLU(),
            
            SEBlock(32,32),
            
            ConvBNReLU(32, 64, kernel_size=3, stride=2),
            
            ResidualBlock(64,64),
            
            SeparableBNReLU(64, 128, kernel_size=3, stride=2),
            
            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),
            
            SEBlock(128, 128),
            
            Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),            
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 修改为 output_padding=1
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# 实例化模型、损失函数和优化器
model = MaskedAutoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
In [22]:
def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):
    model.to(device)
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)
            output = model(masked_data)
            loss = masked_mse_loss(output, data, mask)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)
                output = model(masked_data)
                loss = masked_mse_loss(output, data, mask)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
In [23]:
# 数据准备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
In [27]:
train_model(model, train_loader, val_loader, epochs=130, criterion=criterion, optimizer=optimizer, device=device)
Epoch 1, Train Loss: 0.0185, Val Loss: 0.0199
Epoch 2, Train Loss: 0.0178, Val Loss: 0.0187
Epoch 3, Train Loss: 0.0174, Val Loss: 0.0217
Epoch 4, Train Loss: 0.0172, Val Loss: 0.0227
Epoch 5, Train Loss: 0.0167, Val Loss: 0.0180
Epoch 6, Train Loss: 0.0166, Val Loss: 0.0225
Epoch 7, Train Loss: 0.0163, Val Loss: 0.0183
Epoch 8, Train Loss: 0.0162, Val Loss: 0.0220
Epoch 9, Train Loss: 0.0161, Val Loss: 0.0181
Epoch 10, Train Loss: 0.0159, Val Loss: 0.0196
Epoch 11, Train Loss: 0.0159, Val Loss: 0.0210
Epoch 12, Train Loss: 0.0155, Val Loss: 0.0198
Epoch 13, Train Loss: 0.0154, Val Loss: 0.0212
Epoch 14, Train Loss: 0.0153, Val Loss: 0.0207
Epoch 15, Train Loss: 0.0153, Val Loss: 0.0216
Epoch 16, Train Loss: 0.0152, Val Loss: 0.0222
Epoch 17, Train Loss: 0.0152, Val Loss: 0.0225
Epoch 18, Train Loss: 0.0150, Val Loss: 0.0183
Epoch 19, Train Loss: 0.0151, Val Loss: 0.0242
Epoch 20, Train Loss: 0.0148, Val Loss: 0.0203
Epoch 21, Train Loss: 0.0148, Val Loss: 0.0211
Epoch 22, Train Loss: 0.0148, Val Loss: 0.0200
Epoch 23, Train Loss: 0.0146, Val Loss: 0.0191
Epoch 24, Train Loss: 0.0145, Val Loss: 0.0215
Epoch 25, Train Loss: 0.0145, Val Loss: 0.0196
Epoch 26, Train Loss: 0.0146, Val Loss: 0.0215
Epoch 27, Train Loss: 0.0144, Val Loss: 0.0195
Epoch 28, Train Loss: 0.0144, Val Loss: 0.0196
Epoch 29, Train Loss: 0.0143, Val Loss: 0.0182
Epoch 30, Train Loss: 0.0143, Val Loss: 0.0213
Epoch 31, Train Loss: 0.0142, Val Loss: 0.0178
Epoch 32, Train Loss: 0.0139, Val Loss: 0.0215
Epoch 33, Train Loss: 0.0135, Val Loss: 0.0171
Epoch 34, Train Loss: 0.0131, Val Loss: 0.0187
Epoch 35, Train Loss: 0.0128, Val Loss: 0.0171
Epoch 36, Train Loss: 0.0128, Val Loss: 0.0159
Epoch 37, Train Loss: 0.0127, Val Loss: 0.0170
Epoch 38, Train Loss: 0.0125, Val Loss: 0.0182
Epoch 39, Train Loss: 0.0124, Val Loss: 0.0155
Epoch 40, Train Loss: 0.0123, Val Loss: 0.0169
Epoch 41, Train Loss: 0.0122, Val Loss: 0.0160
Epoch 42, Train Loss: 0.0123, Val Loss: 0.0164
Epoch 43, Train Loss: 0.0120, Val Loss: 0.0154
Epoch 44, Train Loss: 0.0121, Val Loss: 0.0159
Epoch 45, Train Loss: 0.0119, Val Loss: 0.0152
Epoch 46, Train Loss: 0.0118, Val Loss: 0.0151
Epoch 47, Train Loss: 0.0119, Val Loss: 0.0135
Epoch 48, Train Loss: 0.0121, Val Loss: 0.0135
Epoch 49, Train Loss: 0.0118, Val Loss: 0.0162
Epoch 50, Train Loss: 0.0117, Val Loss: 0.0195
Epoch 51, Train Loss: 0.0116, Val Loss: 0.0160
Epoch 52, Train Loss: 0.0116, Val Loss: 0.0167
Epoch 53, Train Loss: 0.0116, Val Loss: 0.0149
Epoch 54, Train Loss: 0.0114, Val Loss: 0.0143
Epoch 55, Train Loss: 0.0115, Val Loss: 0.0136
Epoch 56, Train Loss: 0.0115, Val Loss: 0.0144
Epoch 57, Train Loss: 0.0115, Val Loss: 0.0158
Epoch 58, Train Loss: 0.0113, Val Loss: 0.0147
Epoch 59, Train Loss: 0.0112, Val Loss: 0.0142
Epoch 60, Train Loss: 0.0113, Val Loss: 0.0159
Epoch 61, Train Loss: 0.0112, Val Loss: 0.0153
Epoch 62, Train Loss: 0.0113, Val Loss: 0.0140
Epoch 63, Train Loss: 0.0112, Val Loss: 0.0156
Epoch 64, Train Loss: 0.0111, Val Loss: 0.0149
Epoch 65, Train Loss: 0.0112, Val Loss: 0.0154
Epoch 66, Train Loss: 0.0112, Val Loss: 0.0158
Epoch 67, Train Loss: 0.0111, Val Loss: 0.0136
Epoch 68, Train Loss: 0.0110, Val Loss: 0.0139
Epoch 69, Train Loss: 0.0110, Val Loss: 0.0142
Epoch 70, Train Loss: 0.0112, Val Loss: 0.0152
Epoch 71, Train Loss: 0.0109, Val Loss: 0.0151
Epoch 72, Train Loss: 0.0110, Val Loss: 0.0162
Epoch 73, Train Loss: 0.0110, Val Loss: 0.0162
Epoch 74, Train Loss: 0.0109, Val Loss: 0.0176
Epoch 75, Train Loss: 0.0109, Val Loss: 0.0143
Epoch 76, Train Loss: 0.0109, Val Loss: 0.0147
Epoch 77, Train Loss: 0.0108, Val Loss: 0.0141
Epoch 78, Train Loss: 0.0109, Val Loss: 0.0145
Epoch 79, Train Loss: 0.0108, Val Loss: 0.0140
Epoch 80, Train Loss: 0.0109, Val Loss: 0.0135
Epoch 81, Train Loss: 0.0108, Val Loss: 0.0145
Epoch 82, Train Loss: 0.0108, Val Loss: 0.0126
Epoch 83, Train Loss: 0.0108, Val Loss: 0.0145
Epoch 84, Train Loss: 0.0107, Val Loss: 0.0135
Epoch 85, Train Loss: 0.0108, Val Loss: 0.0140
Epoch 86, Train Loss: 0.0107, Val Loss: 0.0143
Epoch 87, Train Loss: 0.0107, Val Loss: 0.0146
Epoch 88, Train Loss: 0.0107, Val Loss: 0.0136
Epoch 111, Train Loss: 0.0094, Val Loss: 0.0120
Epoch 112, Train Loss: 0.0094, Val Loss: 0.0114
Epoch 113, Train Loss: 0.0095, Val Loss: 0.0128
Epoch 114, Train Loss: 0.0093, Val Loss: 0.0125
Epoch 115, Train Loss: 0.0093, Val Loss: 0.0124
Epoch 116, Train Loss: 0.0093, Val Loss: 0.0114
Epoch 117, Train Loss: 0.0093, Val Loss: 0.0127
Epoch 118, Train Loss: 0.0093, Val Loss: 0.0122
Epoch 119, Train Loss: 0.0093, Val Loss: 0.0116
Epoch 120, Train Loss: 0.0092, Val Loss: 0.0114
Epoch 121, Train Loss: 0.0092, Val Loss: 0.0130
Epoch 122, Train Loss: 0.0092, Val Loss: 0.0114
Epoch 123, Train Loss: 0.0093, Val Loss: 0.0113
Epoch 124, Train Loss: 0.0092, Val Loss: 0.0120
Epoch 125, Train Loss: 0.0091, Val Loss: 0.0110
Epoch 126, Train Loss: 0.0091, Val Loss: 0.0128
Epoch 127, Train Loss: 0.0091, Val Loss: 0.0129
Epoch 128, Train Loss: 0.0092, Val Loss: 0.0126
Epoch 129, Train Loss: 0.0092, Val Loss: 0.0113
Epoch 130, Train Loss: 0.0091, Val Loss: 0.0109
In [30]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error
In [31]:
def cal_ioa(y_true, y_pred):
    # 计算平均值
    mean_observed = np.mean(y_true)
    mean_predicted = np.mean(y_pred)

    # 计算IoA
    numerator = np.sum((y_true - y_pred) ** 2)
    denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)
    IoA = 1 - (numerator / denominator)

    return IoA
In [32]:
eva_list_frame = list()
device = 'cpu'
model = model.to(device)
best_mape = 1
best_img = None
best_mask = None
best_recov = None
with torch.no_grad():
    for batch_idx, (X, y, mask) in enumerate(test_loader):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域
        reconstructed = model(X)
        rev_data = y * max_pixel_value
        rev_recon = reconstructed * max_pixel_value
        # todo: 这里需要只评估修补出来的模块
        for i, sample in enumerate(rev_data):
            used_mask = mask_rev[i]
            data_label = sample[0] * used_mask
            recon_no2 = rev_recon[i][0] * used_mask
            data_label = data_label[used_mask==1]
            recon_no2 = recon_no2[used_mask==1]
            mae = mean_absolute_error(data_label, recon_no2)
            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))
            mape = mean_absolute_percentage_error(data_label, recon_no2)
            r2 = r2_score(data_label, recon_no2)
            ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
            r = np.corrcoef(data_label, recon_no2)[0, 1]
            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])
            if mape < best_mape:
                best_recov = rev_recon[i][0].numpy()
                best_mask = used_mask.numpy()
                best_img = sample[0].numpy()
                best_mape = mape
In [33]:
pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()
Out[33]:
mae rmse mape r2 ioa r
count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000
mean 5.297680 6.225729 0.489679 -1.978159 -0.362509 0.352984
std 3.930302 4.176386 0.191670 2.447883 1.074637 0.201559
min 0.996953 1.279405 0.202344 -28.276637 -9.562830 -0.500861
25% 2.103293 2.741658 0.353414 -2.891019 -0.796581 0.225314
50% 3.190869 4.148710 0.457116 -1.093823 0.044020 0.365110
75% 8.378542 9.440538 0.586501 -0.406974 0.355992 0.498017
max 21.329165 23.047779 2.242282 0.592645 0.829324 0.839954
In [37]:
eva_list = list()
device = 'cpu'
model = model.to(device)
with torch.no_grad():
    for batch_idx, (X, y, mask) in enumerate(test_loader):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域
        reconstructed = model(X)
        rev_data = y * max_pixel_value
        rev_recon = reconstructed * max_pixel_value
        # todo: 这里需要只评估修补出来的模块
        data_label = torch.squeeze(rev_data, dim=1) * mask_rev
        data_label = data_label[mask_rev==1]
        recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev
        recon_no2 = recon_no2[mask_rev==1]
        mae = mean_absolute_error(data_label, recon_no2)
        rmse = np.sqrt(mean_squared_error(data_label, recon_no2))
        mape = mean_absolute_percentage_error(data_label, recon_no2)
        r2 = r2_score(data_label, recon_no2)
        ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
        eva_list.append([mae, rmse, mape, r2, ioa])
In [42]:
torch.save(model, './models/MAE/final_patch_20.pt')
In [37]:
# 可视化特定特征的函数
def visualize_rst(input_feature,masked_feature, recov_region, output_feature, title):
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 4, 1)
    plt.imshow(input_feature, cmap='RdYlGn_r')
    plt.gca().axis('off')  # 获取当前坐标轴并关闭
    plt.subplot(1, 4, 2)
    plt.imshow(masked_feature, cmap='gray')
    plt.gca().axis('off')  # 获取当前坐标轴并关闭
    plt.subplot(1, 4, 3)
    plt.imshow(recov_region, cmap='RdYlGn_r')
    plt.gca().axis('off')  # 获取当前坐标轴并关闭
    plt.subplot(1, 4, 4)
    plt.imshow(output_feature, cmap='RdYlGn_r')
    plt.gca().axis('off')  # 获取当前坐标轴并关闭
    # plt.savefig('./figures/result/20_samples.png')
    plt.show()
In [38]:
best_mask_cp = np.where(best_mask == 0, np.nan, best_mask)
In [40]:
visualize_rst(best_img, best_mask, best_recov*best_mask_cp, best_recov, '')
No description has been provided for this image
In [ ]: