MAE_ATMO/torch_MAE_1d_ViT-Copy1.ipynb

44 KiB
Raw Permalink Blame History

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
import pandas as pd
import os
from PIL import Image

MAX_VALUE = 107.49169921875
In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
Out[2]:
device(type='cuda')
In [3]:
class GrayScaleDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        data = np.load(file_path)[:,:,0] / MAX_VALUE
        return torch.tensor(data, dtype=torch.float32).unsqueeze(0)
 
In [4]:
class NO2Dataset(Dataset):
    
    def __init__(self, image_dir, mask_dir):
        
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # 仅加载 .npy 文件
        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # 仅加载 .jpg 文件
        
    def __len__(self):
        
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_idx = idx % len(self.mask_filenames)
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])

        # 加载图像数据 (.npy 文件)
        image = np.load(image_path).astype(np.float32)[:,:,:1] / MAX_VALUE  # 形状为 (96, 96, 1)

        # 加载掩码数据 (.jpg 文件)
        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)

        # 将掩码数据中非0值设为10值保持不变
        mask = np.where(mask != 0, 1.0, 0.0)

        # 保持掩码数据形状为 (96, 96, 1)
        mask = mask[:, :, np.newaxis]  # 将形状调整为 (96, 96, 1)

        # 应用掩码
        masked_image = image.copy()
        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()  # 遮盖NO2数据

        # cGAN的输入和目标
        X = masked_image[:, :, :1]  # 形状为 (96, 96, 8)
        y = image[:, :, 0:1]  # 目标输出为NO2数据形状为 (96, 96, 1)

        # 转换形状为 (channels, height, width)
        X = np.transpose(X, (2, 0, 1))  # 转换为 (1, 96, 96)
        y = np.transpose(y, (2, 0, 1))  # 转换为 (1, 96, 96)
        mask = np.transpose(mask, (2, 0, 1))  # 转换为 (1, 96, 96)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)
In [5]:
class PatchMasking:
    def __init__(self, patch_size, mask_ratio):
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

    def __call__(self, x):
        batch_size, C, H, W = x.shape
        num_patches = (H // self.patch_size) * (W // self.patch_size)
        num_masked = int(num_patches * self.mask_ratio)
        
        # 为每个样本生成独立的mask
        masks = []
        for _ in range(batch_size):
            mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)
            mask[:num_masked] = 1
            mask = mask[torch.randperm(num_patches)]
            mask = mask.view(H // self.patch_size, W // self.patch_size)
            mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)
            masks.append(mask)
        
        # 将所有mask堆叠成一个批量张量
        masks = torch.stack(masks, dim=0)
        masks = torch.unsqueeze(masks, dim=1)
        
        # 应用mask到输入x上
        masked_x = x * (1 - masks.float())
        return masked_x, masks
In [6]:
class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.act = nn.GELU()  # 使用 GELU 激活函数
        self.fc2 = nn.Linear(output_dim, input_dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class Attention(nn.Module):
    def __init__(self, dim, heads):
        super(Attention, self).__init__()
        self.heads = heads
        self.dim = dim
        self.scale = dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(0.1)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.1)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.ln = nn.LayerNorm(dim, eps=eps)

    def forward(self, x):
        return self.ln(x)

class Dropout(nn.Module):
    def __init__(self, p=0.1):
        super(Dropout, self).__init__()
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        return self.dropout(x)

class ViTEncoder(nn.Module):
    def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256, dropout=0.1):
        super(ViTEncoder, self).__init__()
        self.patch_size = patch_size
        self.dim = dim
        self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)

        self.attention_layers = nn.ModuleList([
            nn.ModuleList([
                LayerNorm(dim),  # Layer Normalization
                Attention(dim, heads),
                Dropout(dropout),  # Dropout
                LayerNorm(dim),  # Layer Normalization
                MLP(dim, mlp_dim),
                Dropout(dropout)  # Dropout
            ]) for _ in range(depth)
        ])

    def forward(self, x):
        x = self.patch_embedding(x)  # 形状变为 (batch_size, dim, num_patches_h, num_patches_w)
        x = x.flatten(2).transpose(1, 2)  # 形状变为 (batch_size, num_patches, dim)

        for norm1, attn, drop1, norm2, mlp, drop2 in self.attention_layers:
            x = x + drop1(attn(norm1(x)))  # 残差连接
            x = x + drop2(mlp(norm2(x)))    # 残差连接
        return x


class ConvDecoder(nn.Module):
    def __init__(self, dim=128, patch_size=8, img_size=96):
        super(ConvDecoder, self).__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.img_size = img_size
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.dim, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            
            nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
                        
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 修改为 output_padding=1
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.transpose(1, 2).view(-1, self.dim, self.img_size // self.patch_size, self.img_size // self.patch_size)
        x = self.decoder(x)
        return x

class MAEModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(MAEModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
In [7]:
def masked_mse_loss(preds, target, mask):
    loss = (preds - target) ** 2
    loss = loss.mean(dim=-1)  # 对每个像素点求平均
    loss = (loss * mask).sum() / mask.sum()  # 只计算被mask的像素点的损失
    return loss
In [8]:
def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):
    best_model = model
    best_loss = 100
    model.to(device)
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)
            output = model(masked_data)
            loss = masked_mse_loss(output, data, mask)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)
                output = model(masked_data)
                loss = masked_mse_loss(output, data, mask)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        if val_loss < best_loss:
            best_loss = val_loss
            best_model = model

        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
In [9]:
train_dir = './out_mat/96/train/'
train_dataset = GrayScaleDataset(train_dir)

val_dir = './out_mat/96/valid/'
val_dataset = GrayScaleDataset(val_dir)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
In [10]:
import matplotlib.pyplot as plt
In [43]:
encoder = ViTEncoder()
decoder = ConvDecoder()
model = MAEModel(encoder, decoder)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
In [ ]:
train_model(model, train_loader, val_loader, epochs=50, criterion=criterion, optimizer=optimizer, device=device)
Epoch 1, Train Loss: 0.9251, Val Loss: 0.0869
Epoch 2, Train Loss: 0.0734, Val Loss: 0.0506
Epoch 3, Train Loss: 0.0494, Val Loss: 0.0489
Epoch 4, Train Loss: 0.0432, Val Loss: 0.0462
Epoch 5, Train Loss: 0.0390, Val Loss: 0.0400
Epoch 6, Train Loss: 0.0351, Val Loss: 0.0356
In [18]:
test_set = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')
test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)
In [19]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error
In [20]:
def cal_ioa(y_true, y_pred):
    # 计算平均值
    mean_observed = np.mean(y_true)
    mean_predicted = np.mean(y_pred)

    # 计算IoA
    numerator = np.sum((y_true - y_pred) ** 2)
    denominator = 2 * np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)
    IoA = 1 - (numerator / denominator)

    return IoA
In [24]:
test_set2 = GrayScaleDataset('./out_mat/96/test/')
test_loader2 = DataLoader(test_set2, batch_size=64, shuffle=False, num_workers=4)
In [34]:
rev_data.shape
Out[34]:
torch.Size([64, 96, 96])
In [35]:
eva_list = list()
device = 'cpu'
model = model.to(device)
with torch.no_grad():
    for data in test_loader2:
        data = data.to(device)
        masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)
        output = model(masked_data)
        rev_data = data * MAX_VALUE
        rev_recon = output * MAX_VALUE
        data_label = rev_data * mask
        data_label = data_label[mask==1]
        recon_no2 = rev_recon * mask
        recon_no2 = recon_no2[mask==1]
        y_true = rev_data.flatten()
        y_pred = rev_recon.flatten()
        mae = mean_absolute_error(y_true, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        mape = mean_absolute_percentage_error(y_true, y_pred)
        r2 = r2_score(y_true, y_pred)
        ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
        eva_list.append([mae, rmse, mape, r2, ioa])
In [36]:
pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()
Out[36]:
mae rmse mape r2 ioa
count 75.000000 75.000000 75.000000 75.000000 75.000000
mean 1.208013 1.600644 0.142720 0.941983 0.981683
std 0.056235 0.081791 0.003435 0.004449 0.002309
min 1.091517 1.446389 0.134849 0.911833 0.965708
25% 1.170305 1.555051 0.140519 0.940425 0.981100
50% 1.204728 1.593261 0.142981 0.942651 0.982003
75% 1.242762 1.646311 0.145185 0.944118 0.982809
max 1.420721 2.037903 0.150566 0.949663 0.984610
In [38]:
eva_list_frame = list()
device = 'cpu'
model = model.to(device)
with torch.no_grad():
    for data in test_loader2:
        data = data.to(device)
        masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)
        output = model(masked_data)
        rev_data = data * MAX_VALUE
        rev_recon = output * MAX_VALUE
        # todo: 这里需要只评估修补出来的模块
        for i, sample in enumerate(rev_data):
            used_mask = mask[i]
            data_label = sample[0] * used_mask
            recon_no2 = rev_recon[i][0] * used_mask
            data_label = data_label[used_mask==1]
            recon_no2 = recon_no2[used_mask==1]
            mae = mean_absolute_error(data_label, recon_no2)
            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))
            mape = mean_absolute_percentage_error(data_label, recon_no2)
            r2 = r2_score(data_label, recon_no2)
            ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
            r = np.corrcoef(data_label, recon_no2)[0, 1]
            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])
In [39]:
eva_frame_df = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape')
eva_frame_df.describe()
Out[39]:
mae rmse mape r2 ioa r
count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000
mean 1.714315 2.350189 0.215974 0.609470 0.943560 0.823401
std 0.697344 0.940345 0.077893 0.131496 0.022261 0.069394
min 0.636049 0.821723 0.099999 0.003194 0.802237 0.405363
25% 1.121617 1.576669 0.170974 0.533081 0.931653 0.783616
50% 1.459720 2.132316 0.199419 0.623769 0.946952 0.831403
75% 2.334761 3.119393 0.234517 0.698517 0.958943 0.872422
max 4.406258 8.470109 1.242636 0.895199 0.986901 0.965110
In [ ]:
eva_list = list()
device = 'cpu'
model = model.to(device)
with torch.no_grad():
    for batch_idx, (X, y, mask) in enumerate(test_loader2):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域
        reconstructed = model(X)
        rev_data = torch.squeeze(y * MAX_VALUE, dim=1)
        rev_recon = torch.squeeze(reconstructed * MAX_VALUE, dim=1)
        # todo: 这里需要只评估修补出来的模块
        data_label = rev_data * mask_rev
        data_label = data_label[mask_rev==1]
        recon_no2 = rev_recon * mask_rev
        recon_no2 = recon_no2[mask_rev==1]
        y_true = rev_data.flatten()
        y_pred = rev_recon.flatten()
        mae = mean_absolute_error(y_true, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        mape = mean_absolute_percentage_error(y_true, y_pred)
        r2 = r2_score(y_true, y_pred)
        ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
        eva_list.append([mae, rmse, mape, r2, ioa])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[25], line 5
      3 model = model.to(device)
      4 with torch.no_grad():
----> 5     for batch_idx, (X, y, mask) in enumerate(test_loader2):
      6         X, y, mask = X.to(device), y.to(device), mask.to(device)
      7         mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域

ValueError: too many values to unpack (expected 3)
In [23]:
pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()
Out[23]:
mae rmse mape r2 ioa
count 149.000000 149.000000 149.000000 149.000000 149.000000
mean 2.235662 4.042349 0.238494 0.626060 0.572341
std 0.192709 0.357475 0.007405 0.042890 0.042652
min 1.786567 3.167143 0.224796 0.522157 0.460707
25% 2.084117 3.779276 0.232974 0.597774 0.547144
50% 2.226062 4.075465 0.237429 0.627588 0.570579
75% 2.361411 4.284523 0.243866 0.656226 0.601233
max 2.751377 4.917407 0.258230 0.740943 0.666083
In [23]:
eva_list_frame = list()
device = 'cpu'
model = model.to(device)
best_mape = 1
best_img = None
best_mask = None
best_recov = None
with torch.no_grad():
    for batch_idx, (X, y, mask) in enumerate(test_loader):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域
        reconstructed = model(X)
        rev_data = y * MAX_VALUE
        rev_recon = reconstructed * MAX_VALUE
        # todo: 这里需要只评估修补出来的模块
        for i, sample in enumerate(rev_data):
            used_mask = mask_rev[i]
            data_label = sample[0] * used_mask
            recon_no2 = rev_recon[i][0] * used_mask
            data_label = data_label[used_mask==1]
            recon_no2 = recon_no2[used_mask==1]
            mae = mean_absolute_error(data_label, recon_no2)
            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))
            mape = mean_absolute_percentage_error(data_label, recon_no2)
            r2 = r2_score(data_label, recon_no2)
            ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
            r = np.corrcoef(data_label, recon_no2)[0, 1]
            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])
In [ ]:
eva_frame_df = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape')
In [28]:
eva_frame_df.describe()
Out[28]:
mae rmse mape r2 ioa r
count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000
mean 5.920017 6.864245 0.603656 -2.743017 0.228580 0.225978
std 3.534648 3.845034 0.224679 2.049753 0.370622 0.227965
min 1.477380 1.849392 0.271934 -22.827546 -1.899284 -0.626938
25% 2.975700 3.600521 0.502338 -3.631702 0.042875 0.088760
50% 4.169098 5.055890 0.558942 -2.233530 0.309592 0.253954
75% 8.616798 9.809069 0.632651 -1.287602 0.509937 0.389390
max 18.840775 20.371025 3.689853 0.024294 0.835339 0.782481
In [31]:
# 可视化特定特征的函数
def visualize_rst(input_feature,masked_feature, recov_region, output_feature, title):
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 4, 1)
    plt.imshow(input_feature, cmap='RdYlGn_r')
    plt.gca().axis('off')  # 获取当前坐标轴并关闭
    plt.subplot(1, 4, 2)
    plt.imshow(masked_feature, cmap='gray')
    plt.gca().axis('off')  # 获取当前坐标轴并关闭
    plt.subplot(1, 4, 3)
    plt.imshow(recov_region, cmap='RdYlGn_r')
    plt.gca().axis('off')  # 获取当前坐标轴并关闭
    plt.subplot(1, 4, 4)
    plt.imshow(output_feature, cmap='RdYlGn_r')
    plt.gca().axis('off')  # 获取当前坐标轴并关闭
    plt.savefig('./figures/result/vitmae_20_samples.png', bbox_inches='tight')
In [92]:
torch.save(model, './models/MAE/vit.pt')
In [41]:
find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])
find_ex
Out[41]:
{'1114', '1952', '2568', '3523', '602'}
In [42]:
for j in find_ex:
    ori = np.load(f'./test_img/{j}-real.npy')[0]
    mask = np.load(f'./test_img/{j}-mask.npy')
    mask_rev = 1 - mask
    img_in = ori * mask_rev / MAX_VALUE
    img_out = model(torch.tensor(img_in.reshape(1, 1, 96, 96), dtype=torch.float32)).detach().cpu().numpy()[0][0] * MAX_VALUE
    out = ori * mask_rev + img_out * mask
    plt.imshow(out, cmap='RdYlGn_r')
    plt.gca().axis('off')
    plt.savefig(f'./test_img/out_fig/{j}-mae_vit_out.png', bbox_inches='tight')
    plt.clf()
<Figure size 640x480 with 0 Axes>
In [ ]: