MAE_ATMO/torch_MAE_1d_final_mixed.ipynb

77 KiB
Raw Blame History

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
In [2]:
max_pixel_value = 107.49169921875
print(f"Maximum pixel value in the dataset: {max_pixel_value}")
Maximum pixel value in the dataset: 107.49169921875
In [3]:
class NO2Dataset(Dataset):
    
    def __init__(self, image_dir, mask_dir):
        
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # 仅加载 .npy 文件
        mask_rates = [10, 20, 30, 40]
        self.mask_filenames = list()
        for rate in mask_rates:
            local_masks = [f"{f'{mask_dir}{rate}/{f}'}" for f in os.listdir(f'{mask_dir}{rate}') if f.endswith('.jpg')]
            self.mask_filenames.extend(local_masks)
        
    def __len__(self):
        
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_idx = np.random.choice(self.mask_filenames)
        mask_path = mask_idx
        select_rate = mask_idx.split('/')[4]

        # 加载图像数据 (.npy 文件)
        image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value  # 形状为 (96, 96, 1)

        # 加载掩码数据 (.jpg 文件)
        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)

        # 将掩码数据中非0值设为10值保持不变
        mask = np.where(mask != 0, 1.0, 0.0)

        # 保持掩码数据形状为 (96, 96, 1)
        mask = mask[:, :, np.newaxis]  # 将形状调整为 (96, 96, 1)

        # 应用掩码
        masked_image = image.copy()
        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()  # 遮盖NO2数据

        # cGAN的输入和目标
        X = masked_image[:, :, :1]  # 形状为 (96, 96, 8)
        y = image[:, :, 0:1]  # 目标输出为NO2数据形状为 (96, 96, 1)

        # 转换形状为 (channels, height, width)
        X = np.transpose(X, (2, 0, 1))  # 转换为 (1, 96, 96)
        y = np.transpose(y, (2, 0, 1))  # 转换为 (1, 96, 96)
        mask = np.transpose(mask, (2, 0, 1))  # 转换为 (1, 96, 96)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32), select_rate

# 实例化数据集和数据加载器
image_dir = './out_mat/96/train/'
mask_dir = './out_mat/96/mask/'

print(f"checkpoint before Generator is OK")
checkpoint before Generator is OK
In [4]:
dataset = NO2Dataset(image_dir, mask_dir)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)
val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)
test_set = NO2Dataset('./out_mat/96/test/', mask_dir)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)
In [5]:
# 可视化特定特征的函数
def visualize_feature(input_feature,masked_feature, output_feature, title):
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')
    plt.title(title + " Input")
    plt.subplot(1, 3, 2)
    plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')
    plt.title(title + " Masked")
    plt.subplot(1, 3, 3)
    plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')
    plt.title(title + " Recovery")
    plt.show()
In [6]:
class Conv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        super(Conv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
        )
In [7]:
class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,
                 bias=False):
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels),
            nn.ReLU()
        )
In [8]:
class SeparableBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
        super(SeparableBNReLU, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),
            # 分离卷积,仅调整空间信息
            norm_layer(in_channels),  # 对输入通道进行归一化
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),  # 这里进行升维操作
            nn.ReLU6()
        )
In [9]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 如果输入和输出通道不一致,进行降采样操作
        self.downsample = downsample
        if in_channels != out_channels or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out
In [10]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)

        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)
        self.drop = nn.Dropout(drop, inplace=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
In [11]:
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttentionBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(2, 0, 1)  # (B, C, H, W) -> (HW, B, C)

        # Apply multihead attention
        attn_output, _ = self.attention(x, x, x)

        # Apply normalization and dropout
        attn_output = self.norm(attn_output)
        attn_output = self.dropout(attn_output)

        # Reshape back to (B, C, H, W)
        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)

        return attn_output
In [12]:
class SpatialAttentionBlock(nn.Module):
    def __init__(self):
        super(SpatialAttentionBlock, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)

    def forward(self, x): #(B, 64, H, W)
        avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)
        out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)
        out = torch.sigmoid(self.conv(out))#(B, 1, H, W)
        return x * out #(B, C, H, W)
In [13]:
class DecoderAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(DecoderAttentionBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)
        self.spatial_attention = SpatialAttentionBlock()

    def forward(self, x):
        # 通道注意力
        b, c, h, w = x.size()
        avg_pool = F.adaptive_avg_pool2d(x, 1)
        max_pool = F.adaptive_max_pool2d(x, 1)

        avg_out = self.conv1(avg_pool)
        max_out = self.conv1(max_pool)

        out = avg_out + max_out
        out = torch.sigmoid(self.conv2(out))

        # 添加空间注意力
        out = x * out
        out = self.spatial_attention(out)
        return out
In [14]:
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduced_dim):
        super(SEBlock, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # 全局平均池化
            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),
            nn.Sigmoid()  # 使用Sigmoid是因为我们要对通道进行权重归一化
        )

    def forward(self, x):
        return x * self.se(x)
In [15]:
# 定义Masked Autoencoder模型
class MaskedAutoencoder(nn.Module):
    def __init__(self):
        super(MaskedAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            Conv(1, 32, kernel_size=3, stride=2),
            
            nn.ReLU(),
            
            SEBlock(32,32),
            
            ConvBNReLU(32, 64, kernel_size=3, stride=2),
            
            ResidualBlock(64,64),
            
            SeparableBNReLU(64, 128, kernel_size=3, stride=2),
            
            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),
            
            SEBlock(128, 128)
        )
        self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            
            DecoderAttentionBlock(32),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            
            DecoderAttentionBlock(16),
            nn.ReLU(),
            
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 修改为 output_padding=1
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# 实例化模型、损失函数和优化器
model = MaskedAutoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
In [16]:
def masked_mse_loss(preds, target, mask):
    loss = (preds - target) ** 2
    loss = loss.mean(dim=-1)  # 对每个像素点求平均
    loss = (loss * mask).sum() / mask.sum()  # 只计算被mask的像素点的损失
    return loss
In [17]:
# 训练函数
def train_epoch(model, device, data_loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    miss_counts = list()
    for batch_idx, (X, y, mask, miss_rate) in enumerate(data_loader):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        miss_counts.append(miss_rate)
        optimizer.zero_grad()
        reconstructed = model(X)
        loss = masked_mse_loss(reconstructed, y, mask)
        # loss = criterion(reconstructed, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / (batch_idx + 1), miss_counts
In [18]:
# 评估函数
def evaluate(model, device, data_loader, criterion):
    model.eval()
    running_loss = 0.0
    miss_counts = list()
    with torch.no_grad():
        for batch_idx, (X, y, mask, miss_rate) in enumerate(data_loader):
            X, y, mask = X.to(device), y.to(device), mask.to(device)
            miss_counts.append(miss_rate)
            reconstructed = model(X)
            if batch_idx == 8:
                rand_ind = np.random.randint(0, len(y))
                # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')
            loss = masked_mse_loss(reconstructed, y, mask)
            running_loss += loss.item()
    return running_loss / (batch_idx + 1), miss_counts
In [19]:
# 数据准备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
In [20]:
model = model.to(device)

num_epochs = 150
train_losses = list()
val_losses = list()
for epoch in range(num_epochs):
    train_loss, train_counts = train_epoch(model, device, dataloader, criterion, optimizer)
    train_losses.append(train_loss)
    val_loss, val_counts = evaluate(model, device, val_loader, criterion)
    val_losses.append(val_loss)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')

# 测试模型
test_loss = evaluate(model, device, test_loader, criterion)
print(f'Test Loss: {test_loss[0]}')
/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,
Epoch 1, Train Loss: 3.759178739843186, Val Loss: 0.1379857260122228
Epoch 2, Train Loss: 0.09902132054764118, Val Loss: 0.066096671370428
Epoch 3, Train Loss: 0.060244543255088434, Val Loss: 0.05034376319442222
Epoch 4, Train Loss: 0.04942069527956002, Val Loss: 0.04460687851950304
Epoch 5, Train Loss: 0.04382758207940029, Val Loss: 0.0369152329417307
Epoch 6, Train Loss: 0.03961431584432365, Val Loss: 0.033898868973353015
Epoch 7, Train Loss: 0.03653587933861468, Val Loss: 0.03190647060079361
Epoch 8, Train Loss: 0.03421460006956421, Val Loss: 0.030460054360663714
Epoch 9, Train Loss: 0.03215051376434604, Val Loss: 0.03062500929765737
Epoch 10, Train Loss: 0.031739671104119724, Val Loss: 0.029085035394154378
Epoch 11, Train Loss: 0.030470874753188004, Val Loss: 0.03185694292187691
Epoch 12, Train Loss: 0.029636846623566162, Val Loss: 0.029310374951629498
Epoch 13, Train Loss: 0.028289151542851228, Val Loss: 0.02720484949314772
Epoch 14, Train Loss: 0.027910822102327666, Val Loss: 0.028894296833383504
Epoch 15, Train Loss: 0.027092363841332602, Val Loss: 0.02946079163742599
Epoch 16, Train Loss: 0.025776214282692334, Val Loss: 0.024672900368251018
Epoch 17, Train Loss: 0.025803192848402063, Val Loss: 0.02488229790730263
Epoch 18, Train Loss: 0.025352436108915716, Val Loss: 0.02426056825180552
Epoch 19, Train Loss: 0.024724755284675, Val Loss: 0.023613420885000656
Epoch 20, Train Loss: 0.02373662724663196, Val Loss: 0.023868454147630662
Epoch 21, Train Loss: 0.023606173005668026, Val Loss: 0.022293920976234907
Epoch 22, Train Loss: 0.02291965261814697, Val Loss: 0.0231649547036904
Epoch 23, Train Loss: 0.022957429811180208, Val Loss: 0.022116250789432385
Epoch 24, Train Loss: 0.022525311819763416, Val Loss: 0.02422845282994989
Epoch 25, Train Loss: 0.02231395777101009, Val Loss: 0.02212312592388089
Epoch 26, Train Loss: 0.02209535693420035, Val Loss: 0.02158943160589951
Epoch 27, Train Loss: 0.021671999831857722, Val Loss: 0.022256974825885758
Epoch 28, Train Loss: 0.021378441671417517, Val Loss: 0.021293755787522045
Epoch 29, Train Loss: 0.021532584222381194, Val Loss: 0.021740848698945187
Epoch 30, Train Loss: 0.02089789963625906, Val Loss: 0.022172707369300857
Epoch 31, Train Loss: 0.020911543732553578, Val Loss: 0.020904658445671423
Epoch 32, Train Loss: 0.020589363574090472, Val Loss: 0.021264061137144245
Epoch 33, Train Loss: 0.02011841800037112, Val Loss: 0.022388043521500346
Epoch 34, Train Loss: 0.020350060138281025, Val Loss: 0.020872680664952122
Epoch 35, Train Loss: 0.019910728570038193, Val Loss: 0.02008631487668895
Epoch 36, Train Loss: 0.01966284622291201, Val Loss: 0.02018301992385245
Epoch 37, Train Loss: 0.019478668659283785, Val Loss: 0.020117887351383913
Epoch 38, Train Loss: 0.019168558606262983, Val Loss: 0.020217864148652377
Epoch 39, Train Loss: 0.018900538525102956, Val Loss: 0.019784750694881625
Epoch 40, Train Loss: 0.019068713380139695, Val Loss: 0.020406662806201337
Epoch 41, Train Loss: 0.01922704772488994, Val Loss: 0.019463480088804195
Epoch 42, Train Loss: 0.018683298484257392, Val Loss: 0.019570431866641366
Epoch 43, Train Loss: 0.018411033715535863, Val Loss: 0.019696261789371717
Epoch 44, Train Loss: 0.018502752826901142, Val Loss: 0.0193116083574384
Epoch 45, Train Loss: 0.01851825592772028, Val Loss: 0.021103291230192826
Epoch 46, Train Loss: 0.01816361720125641, Val Loss: 0.020114433075954664
Epoch 47, Train Loss: 0.018051497934555464, Val Loss: 0.020221358179045256
Epoch 48, Train Loss: 0.01811225383885597, Val Loss: 0.01961083782475386
Epoch 49, Train Loss: 0.017867776890548224, Val Loss: 0.018948225665893128
Epoch 50, Train Loss: 0.01761771424152135, Val Loss: 0.01865902607009482
Epoch 51, Train Loss: 0.01793021524467608, Val Loss: 0.018359918592136298
Epoch 52, Train Loss: 0.017610817650805393, Val Loss: 0.018650228838756014
Epoch 53, Train Loss: 0.017737194443451305, Val Loss: 0.018363466583637158
Epoch 54, Train Loss: 0.017543190524302886, Val Loss: 0.019013355055184505
Epoch 55, Train Loss: 0.01778105637236859, Val Loss: 0.018212769875553116
Epoch 56, Train Loss: 0.017451271454861576, Val Loss: 0.018818481644587732
Epoch 57, Train Loss: 0.017273589150989026, Val Loss: 0.01801557773585195
Epoch 58, Train Loss: 0.01728663447816549, Val Loss: 0.01771288837737112
Epoch 59, Train Loss: 0.017209396768878237, Val Loss: 0.018658861782012592
Epoch 60, Train Loss: 0.017015971490694434, Val Loss: 0.01875163140748419
Epoch 61, Train Loss: 0.01697286305744112, Val Loss: 0.01831459281827087
Epoch 62, Train Loss: 0.01689975440466518, Val Loss: 0.018071504671182206
Epoch 63, Train Loss: 0.016585711293974133, Val Loss: 0.01783462390025605
Epoch 64, Train Loss: 0.016933080276839756, Val Loss: 0.018715852857636873
Epoch 65, Train Loss: 0.016899143777894633, Val Loss: 0.019256604974394412
Epoch 66, Train Loss: 0.016631374423031173, Val Loss: 0.018876284666693034
Epoch 67, Train Loss: 0.016569798094839855, Val Loss: 0.018378769520169765
Epoch 68, Train Loss: 0.016539438030544366, Val Loss: 0.018459608500350767
Epoch 69, Train Loss: 0.01645555520323261, Val Loss: 0.01851357322241833
Epoch 70, Train Loss: 0.01667448620726332, Val Loss: 0.017527391814362647
Epoch 71, Train Loss: 0.01630861950708491, Val Loss: 0.01862382395331984
Epoch 72, Train Loss: 0.016292595119621053, Val Loss: 0.01898773131308271
Epoch 73, Train Loss: 0.016312802497867904, Val Loss: 0.017515668033886312
Epoch 74, Train Loss: 0.01634560936714331, Val Loss: 0.017603496631690814
Epoch 75, Train Loss: 0.016150180214757556, Val Loss: 0.0177685193606277
Epoch 76, Train Loss: 0.016183897565479912, Val Loss: 0.01790037954142734
Epoch 77, Train Loss: 0.016441928089092794, Val Loss: 0.0177356356671497
Epoch 78, Train Loss: 0.016029272553773875, Val Loss: 0.01720855048676925
Epoch 79, Train Loss: 0.015830894611312443, Val Loss: 0.017439508657735674
Epoch 80, Train Loss: 0.015893817865891318, Val Loss: 0.017185933985260884
Epoch 81, Train Loss: 0.01587246311160081, Val Loss: 0.017182132229208946
Epoch 82, Train Loss: 0.015938340017848322, Val Loss: 0.01732705053942862
Epoch 83, Train Loss: 0.015770130625894767, Val Loss: 0.01730423010607709
Epoch 84, Train Loss: 0.015774958316931886, Val Loss: 0.01693567380642713
Epoch 85, Train Loss: 0.015681640634928166, Val Loss: 0.01731172299929964
Epoch 86, Train Loss: 0.015522310860080725, Val Loss: 0.01708351758155805
Epoch 87, Train Loss: 0.015825702162664473, Val Loss: 0.01767030195680572
Epoch 88, Train Loss: 0.015465608916053789, Val Loss: 0.0169600204689734
Epoch 89, Train Loss: 0.015413585239263812, Val Loss: 0.016799337550330518
Epoch 90, Train Loss: 0.015661140533975153, Val Loss: 0.017084516890680614
Epoch 91, Train Loss: 0.015471032805045684, Val Loss: 0.017242409135979502
Epoch 92, Train Loss: 0.015306838647725337, Val Loss: 0.016721693103882804
Epoch 93, Train Loss: 0.01516885641721661, Val Loss: 0.01838143560479381
Epoch 94, Train Loss: 0.015182504183100314, Val Loss: 0.017020777451680666
Epoch 95, Train Loss: 0.01524644939264541, Val Loss: 0.01649292297105291
Epoch 96, Train Loss: 0.015118425159434382, Val Loss: 0.017190173087613798
Epoch 97, Train Loss: 0.015101557916128322, Val Loss: 0.016093250461367527
Epoch 98, Train Loss: 0.01503138992775, Val Loss: 0.016338717831826922
Epoch 99, Train Loss: 0.015078757967550361, Val Loss: 0.016478037350435754
Epoch 100, Train Loss: 0.014985626251503611, Val Loss: 0.01633207424919107
Epoch 101, Train Loss: 0.014759322786570023, Val Loss: 0.01683194490511026
Epoch 102, Train Loss: 0.014856852341496774, Val Loss: 0.016027600129148854
Epoch 103, Train Loss: 0.014765939864655289, Val Loss: 0.016350745793376396
Epoch 104, Train Loss: 0.01478316887330852, Val Loss: 0.016033862258738547
Epoch 105, Train Loss: 0.014725807853684755, Val Loss: 0.015603851276769568
Epoch 106, Train Loss: 0.014806732724746021, Val Loss: 0.015736672651967896
Epoch 107, Train Loss: 0.014543344516253642, Val Loss: 0.015925641963953404
Epoch 108, Train Loss: 0.014782626121683696, Val Loss: 0.016552887453850525
Epoch 109, Train Loss: 0.014329457426060472, Val Loss: 0.01566976616020078
Epoch 110, Train Loss: 0.014614671502155408, Val Loss: 0.016271342245389276
Epoch 111, Train Loss: 0.014544662480291567, Val Loss: 0.01549402935736215
Epoch 112, Train Loss: 0.01446673739478705, Val Loss: 0.015960639662373422
Epoch 113, Train Loss: 0.014492520645849015, Val Loss: 0.015249295007270663
Epoch 114, Train Loss: 0.014440985597028402, Val Loss: 0.01671606713711326
Epoch 115, Train Loss: 0.014369557464593336, Val Loss: 0.016106587264742424
Epoch 116, Train Loss: 0.01432103816972395, Val Loss: 0.015263923374352171
Epoch 117, Train Loss: 0.014226941607945987, Val Loss: 0.015028324297893404
Epoch 118, Train Loss: 0.01423997960485625, Val Loss: 0.014743029529145404
Epoch 119, Train Loss: 0.014351020645100677, Val Loss: 0.01581134552608675
Epoch 120, Train Loss: 0.014202667741131696, Val Loss: 0.015378265266320598
Epoch 121, Train Loss: 0.013911791727321142, Val Loss: 0.01487369868737548
Epoch 122, Train Loss: 0.013906272411186017, Val Loss: 0.01551159023682573
Epoch 123, Train Loss: 0.013943794016329723, Val Loss: 0.015357211718697156
Epoch 124, Train Loss: 0.01389588224233694, Val Loss: 0.015303193772239472
Epoch 125, Train Loss: 0.014016644986854359, Val Loss: 0.014799274629287755
Epoch 126, Train Loss: 0.013944415422379258, Val Loss: 0.014797273328277603
Epoch 127, Train Loss: 0.013957360926480812, Val Loss: 0.014890457517397938
Epoch 128, Train Loss: 0.013801010133939211, Val Loss: 0.015028401750570802
Epoch 129, Train Loss: 0.013806760874821952, Val Loss: 0.016021162049094245
Epoch 130, Train Loss: 0.014049455859925616, Val Loss: 0.015217644565585834
Epoch 131, Train Loss: 0.013769885206497029, Val Loss: 0.015085379940582745
Epoch 132, Train Loss: 0.013684874973103903, Val Loss: 0.014550712029102133
Epoch 133, Train Loss: 0.013696547392666625, Val Loss: 0.014757407259251645
Epoch 134, Train Loss: 0.01369966242827796, Val Loss: 0.014638274657859732
Epoch 135, Train Loss: 0.013533816318602511, Val Loss: 0.014734907506673193
Epoch 136, Train Loss: 0.013603145677738926, Val Loss: 0.014580759831440093
Epoch 137, Train Loss: 0.013541612814238482, Val Loss: 0.01570955854354065
Epoch 138, Train Loss: 0.013723757467789656, Val Loss: 0.016205344780056335
Epoch 139, Train Loss: 0.013546007516031916, Val Loss: 0.0152104031572591
Epoch 140, Train Loss: 0.013532601969771123, Val Loss: 0.015342667142846692
Epoch 141, Train Loss: 0.013450533512569786, Val Loss: 0.014644546336980898
Epoch 142, Train Loss: 0.013607010434706959, Val Loss: 0.014687455078559135
Epoch 143, Train Loss: 0.013542775672962934, Val Loss: 0.014521264234807953
Epoch 144, Train Loss: 0.013417973114026078, Val Loss: 0.014601941859877822
Epoch 145, Train Loss: 0.013331704906691489, Val Loss: 0.01485029947179467
Epoch 146, Train Loss: 0.013418046530318316, Val Loss: 0.014630124362102195
Epoch 147, Train Loss: 0.013351045589020663, Val Loss: 0.01494142015589707
Epoch 148, Train Loss: 0.013260266191045348, Val Loss: 0.015414885175761893
Epoch 149, Train Loss: 0.013240087648149598, Val Loss: 0.014419331771335494
Epoch 150, Train Loss: 0.01334052808297593, Val Loss: 0.01435606328965123
Test Loss: 0.008245683658557634
In [21]:
tr_ind = list(range(len(train_losses)))
val_ind = list(range(len(val_losses)))
plt.plot(train_losses[1:], label='train_loss')
plt.plot(val_losses[1:], label='val_loss')
plt.legend(loc='best')
Out[21]:
<matplotlib.legend.Legend at 0x7fb5a9a95fa0>
No description has been provided for this image
In [22]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error
In [23]:
def cal_ioa(y_true, y_pred):
    # 计算平均值
    mean_observed = np.mean(y_true)
    mean_predicted = np.mean(y_pred)

    # 计算IoA
    numerator = np.sum((y_true - y_pred) ** 2)
    denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)
    IoA = 1 - (numerator / denominator)

    return IoA
In [28]:
eva_list_frame = list()
device = 'cpu'
model = model.to(device)
best_mape = 1
best_img = None
best_mask = None
best_recov = None
test_miss_counts = list()
with torch.no_grad():
    for batch_idx, (X, y, mask, r) in enumerate(test_loader):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域
        test_miss_counts.append(r)
        reconstructed = model(X)
        rev_data = y * max_pixel_value
        rev_recon = reconstructed * max_pixel_value
        # todo: 这里需要只评估修补出来的模块
        for i, sample in enumerate(rev_data):
            used_mask = mask_rev[i]
            data_label = sample[0] * used_mask
            recon_no2 = rev_recon[i][0] * used_mask
            data_label = data_label[used_mask==1]
            recon_no2 = recon_no2[used_mask==1]
            mae = mean_absolute_error(data_label, recon_no2)
            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))
            mape = mean_absolute_percentage_error(data_label, recon_no2)
            r2 = r2_score(data_label, recon_no2)
            ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
            r = np.corrcoef(data_label, recon_no2)[0, 1]
            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])
            if mape < best_mape:
                best_recov = rev_recon[i][0].numpy()
                best_mask = used_mask.numpy()
                best_img = sample[0].numpy()
                best_mape = mape
In [25]:
import pandas as pd
In [30]:
pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()
Out[30]:
mae rmse mape r2 ioa r
count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000
mean 1.264791 1.798069 0.161384 0.680643 0.889222 0.836726
std 0.601222 0.894735 0.092427 0.227477 0.104377 0.122876
min 0.377890 0.487859 0.045982 -2.265916 -0.146766 0.002855
25% 0.831340 1.149141 0.110199 0.579173 0.859047 0.785617
50% 1.126114 1.609603 0.142398 0.736236 0.922370 0.869874
75% 1.541714 2.221009 0.185216 0.840757 0.955571 0.922865
max 4.765854 8.694316 1.285374 0.988738 0.997125 0.994878
In [31]:
train_counts_int = [int(y) for x in train_counts for y in x]
val_counts_int = [int(y) for x in val_counts for y in x]
test_counts_int = [int(y) for x in test_miss_counts for y in x]
In [32]:
len(train_counts_int)
Out[32]:
26749
In [33]:
from collections import Counter
In [34]:
counts_train = Counter(train_counts_int)
counts_valid = Counter(val_counts_int)
counts_test = Counter(test_counts_int)
In [35]:
counts_df_train = pd.DataFrame.from_dict(dict(counts_train), orient='index').sort_index()
counts_df_test = pd.DataFrame.from_dict(dict(counts_test), orient='index').sort_index()
counts_df_valid = pd.DataFrame.from_dict(dict(counts_valid), orient='index').sort_index()
In [36]:
rst = pd.concat([counts_df_train, counts_df_valid, counts_df_test], axis=1)
In [37]:
rst.columns = ['train', 'validation', 'test']
In [56]:
rst.to_csv('./mix_eva.csv', index=False, encoding='utf-8-sig')
In [57]:
rst
Out[57]:
train validation test
10 9624 1500 1743
20 6534 1117 1150
30 5380 840 956
40 5211 818 890
In [53]:
 
In [ ]:
plt.figure(figsize=(16, 9))
rst.plot.bar()
plt.xlabel('Missing Rate(%)', fontsize=16, fontproperties='Times New Roman')
plt.ylabel('Sample Counts', fontsize=16, fontproperties='Times New Roman')
plt.xticks(rotation=45, fontproperties='Times New Roman')
plt.tight_layout()
plt.legend(loc='best', fontsize=16)
plt.savefig('./miss_counts.png')
In [ ]: