MAE_ATMO/build_gan-1d.ipynb

23 KiB
Raw Permalink Blame History

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
os.environ["CUDA_VISIBLE_DEVICE"] = "0" 


# 设置CUDA设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cuda
In [2]:
# Dataset-wide maximum NO2 value, used to scale pixels into [0, 1].
# NOTE(review): hard-coded from a previous computation — recompute and
# update if the underlying dataset changes.
max_pixel_value = 107.49169921875

print(f"Maximum pixel value in the dataset: {max_pixel_value}")
Maximum pixel value in the dataset: 107.49169921875
In [8]:
class NO2Dataset(Dataset):
    """Pairs each NO2 image (.npy) with a randomly drawn binary mask (.jpg).

    Each item is a (X, y, mask) triple of float32 tensors shaped (1, 96, 96):
    X is the masked input image, y the unmasked target, and mask holds 1 at
    kept (visible) pixels and 0 at pixels that were zeroed out.
    """
    
    def __init__(self, image_dir, mask_dir):
        # image_dir: directory of .npy NO2 arrays; mask_dir: directory of .jpg masks.
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # load .npy files only
        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # load .jpg files only
        
    def __len__(self):
        # Dataset length follows the images; masks are sampled with replacement.
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        """Return (masked image, target image, mask), each (1, 96, 96) float32."""
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        # A mask file is chosen at random on EVERY access, so samples are not
        # reproducible across epochs/runs unless numpy's global RNG is seeded.
        mask_idx = np.random.choice(self.mask_filenames)
        mask_path = os.path.join(self.mask_dir, mask_idx)

        # Load the image (.npy), keep only the first channel, and scale to
        # [0, 1] using the dataset-wide maximum -> shape (96, 96, 1).
        image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value  # shape (96, 96, 1)

        # Load the mask (.jpg) as a single-channel grayscale array.
        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)

        # Binarize: any non-zero value becomes 1.0, zeros stay 0.0.
        mask = np.where(mask != 0, 1.0, 0.0)

        # Add a channel axis so the mask is (96, 96, 1).
        mask = mask[:, :, np.newaxis]  # reshape to (96, 96, 1)

        # Apply the mask: zero out NO2 values where mask == 0.
        masked_image = image.copy()
        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()  # hide NO2 data at masked pixels

        # cGAN input and target.
        X = masked_image[:, :, :1]  # shape (96, 96, 1) — previous comment wrongly said (96, 96, 8)
        y = image[:, :, 0:1]  # target NO2 data, shape (96, 96, 1)

        # Transpose to (channels, height, width).
        X = np.transpose(X, (2, 0, 1))  # -> (1, 96, 96)
        y = np.transpose(y, (2, 0, 1))  # -> (1, 96, 96)
        mask = np.transpose(mask, (2, 0, 1))  # -> (1, 96, 96)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)

# Instantiate datasets and data loaders.
# Data roots for 96x96 tiles; mask dir "20" presumably holds masks with a
# 20% masking ratio — TODO confirm against how the masks were generated.
train_dir = './out_mat/96/train/'
valid_dir = './out_mat/96/valid/'
test_dir = './out_mat/96/test/'
mask_dir = './out_mat/96/mask/20/'

print(f"checkpoint before Generator is OK")

# Shuffle only the training set; 8 worker processes per loader.
dataset = NO2Dataset(train_dir, mask_dir)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

validset = NO2Dataset(valid_dir, mask_dir)
val_loader = DataLoader(validset, batch_size=64, shuffle=False, num_workers=8)

testset = NO2Dataset(test_dir, mask_dir)
test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=8)
checkpoint before Generator is OK
In [16]:
# 生成器模型
class Generator(nn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x, mask):
        x_encoded = self.encoder(x)
        x_decoded = self.decoder(x_encoded)

#         x_decoded = (x_decoded + 1) / 2

#         x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :]
        return x_decoded

# 判别器模型
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

# Move both models to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Optimizers and loss: Adam with lr=2e-4, beta1=0.5 (common GAN settings),
# binary cross-entropy for the adversarial objective.
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss().to(device)

# Confirm the models live on the expected device.
print(f"Generator is on: {next(generator.parameters()).device}")
print(f"Discriminator is on: {next(discriminator.parameters()).device}")
Generator is on: cpu
Discriminator is on: cpu
In [17]:
def masked_mse_loss(preds, target, mask):
    """Mean squared error restricted to the masked-out region (mask == 0).

    Args:
        preds:  (B, 1, H, W) generator output.
        target: (B, 1, H, W) ground-truth image.
        mask:   (B, 1, H, W) binary mask; 1 = observed pixel, 0 = pixel to
                reconstruct — only mask==0 pixels contribute to the loss.

    Returns:
        Scalar tensor: sum of squared errors over mask==0 pixels divided by
        the number of such pixels (NaN if the mask contains no zeros).
    """
    # Fix: the previous version averaged over the last axis first
    # (loss.mean(dim=-1) -> shape (B, 1, H)) and then multiplied by the 4-D
    # (1 - mask), which broadcast to (B, B, H, W) and mixed squared errors
    # across unrelated batch samples. Compute a true per-pixel masked MSE.
    sq_err = (preds - target) ** 2
    hole = 1 - mask  # 1 where the pixel was masked out
    return (sq_err * hole).sum() / hole.sum()
In [18]:
generator.load_state_dict(torch.load('./models/GAN/generator-1d.pth'))
Out[18]:
<All keys matched successfully>
In [ ]:
# Training loop: alternate a generator step and a discriminator step per batch.
epochs = 100
for epoch in range(epochs):
    for i, (X, y, mask) in enumerate(train_loader):
        # Move the batch to the training device.
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        # print(f"X is on: {X.device}, y is on: {y.device}, mask is on: {mask.device}, i = {i}") #checkpoint
        
        # Patch-level real/fake label maps; 12x12 matches the discriminator's
        # output for 96x96 inputs (three stride-2 convs: 96 -> 48 -> 24 -> 12).
        valid = torch.ones((X.size(0), 1, 12, 12)).to(device)
        fake = torch.zeros((X.size(0), 1, 12, 12)).to(device)

        # --- Generator step: fool the discriminator AND reconstruct y.
        # Adversarial term plus a heavily weighted (x100) masked reconstruction
        # term (pix2pix-style weighting — presumably; verify against intent).
        optimizer_G.zero_grad()
        generated_images = generator(X, mask)
        g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * masked_mse_loss(
            generated_images, y, mask)
        g_loss.backward()
        optimizer_G.step()

        # --- Discriminator step: real pairs (y, X) vs fake pairs (G(X), X).
        # detach() stops discriminator gradients flowing back into the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)
        fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)
        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()
        optimizer_D.step()

    # Reports the losses of the LAST batch of the epoch only.
    print(f"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

# Persist both trained networks.
torch.save(generator.state_dict(), './models/GAN/generator-1d.pth')
torch.save(discriminator.state_dict(), './models/GAN/discriminator-1d.pth')
In [10]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error
In [11]:
def cal_ioa(y_true, y_pred):
    """Willmott's Index of Agreement (IoA) between observations and predictions.

    IoA = 1 - sum((o - p)^2) / sum((|o - o_mean| + |p - p_mean|)^2)

    A value of 1.0 indicates perfect agreement.
    """
    obs_mean = np.mean(y_true)
    pred_mean = np.mean(y_pred)

    # Squared-error numerator over the "potential error" denominator.
    squared_error = np.sum((y_true - y_pred) ** 2)
    potential_error = np.sum((np.abs(y_true - obs_mean) + np.abs(y_pred - pred_mean)) ** 2)

    return 1 - squared_error / potential_error
In [19]:
# Evaluate on the test set, on CPU.
eva_list = list()
device = 'cpu'  # NOTE(review): rebinds the global `device`; any later cell
# expecting CUDA will silently run on CPU after this one.
generator = generator.to(device)
with torch.no_grad():
    for batch_idx, (X, y, mask) in enumerate(test_loader):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # invert the mask: 1 marks the region to be repaired
        reconstructed = generator(X, mask)
        # Undo the [0, 1] normalization back to physical NO2 units.
        rev_data = torch.squeeze(y * max_pixel_value, dim=1)
        rev_recon = torch.squeeze(reconstructed * max_pixel_value, dim=1)
        # TODO (translated): metrics should be computed on the repaired region
        # only. Currently MAE/RMSE/MAPE/R2 below use the FULL image, while IoA
        # uses only the repaired (mask_rev==1) pixels — inconsistent.
        data_label = rev_data * mask_rev
        data_label = data_label[mask_rev==1]
        recon_no2 = rev_recon * mask_rev
        recon_no2 = recon_no2[mask_rev==1]
        y_true = rev_data.flatten()
        y_pred = rev_recon.flatten()
        mae = mean_absolute_error(y_true, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        mape = mean_absolute_percentage_error(y_true, y_pred)
        r2 = r2_score(y_true, y_pred)
        ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
        # One row of per-batch metrics.
        eva_list.append([mae, rmse, mape, r2, ioa])
In [14]:
import pandas as pd
In [24]:
pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()
Out[24]:
mae rmse mape r2 ioa
count 75.000000 75.000000 75.000000 75.000000 75.000000
mean 1.685609 2.824579 0.223852 0.807483 0.894409
std 0.520285 0.613299 0.066827 0.107566 0.024969
min 1.108756 2.040964 0.143461 0.336193 0.812887
25% 1.338143 2.462648 0.176170 0.780906 0.883027
50% 1.509821 2.608227 0.206274 0.850417 0.900165
75% 1.963103 3.067560 0.257667 0.866705 0.910917
max 3.729434 5.363288 0.461465 0.912240 0.935183
In [106]:
rst = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape', ascending=True)
In [140]:
rst.head(8)
Out[140]:
mae rmse mape r2 ioa r
3523 1.834664 2.579296 0.103911 0.855056 0.960841 0.931334
3544 1.500194 1.962885 0.106816 0.849688 0.960970 0.924187
1952 1.786639 2.290560 0.109383 0.704122 0.928829 0.869446
602 2.222957 2.934734 0.112751 0.735178 0.933639 0.877028
3531 2.093165 2.726698 0.115755 0.760530 0.937662 0.889606
1114 1.951748 2.591448 0.116578 0.696970 0.914501 0.843026
1979 2.083001 2.686231 0.116762 0.597512 0.886877 0.791842
2568 2.630587 3.636890 0.117044 0.491952 0.893928 0.833221
In [141]:
# Collect the sample ids that have exported .npy files under ./test_img/
# (filenames look like "<id>-<kind>.npy"; the id is the part before "-").
find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])
find_ex
Out[141]:
{'1114', '1952', '2568', '3523', '602'}
In [159]:
def _save_panel(arr, cmap, out_path):
    """Render one 2-D array without axes and save it as a tightly-cropped PNG."""
    plt.imshow(arr, cmap=cmap)
    plt.gca().axis('off')
    plt.savefig(out_path, bbox_inches='tight')
    plt.clf()  # clear so the same figure is reused for the next panel

# For every selected sample, export four panels: ground truth, mask,
# masked original, and the GAN output recomposed with observed pixels.
# (Previously the imshow/axis/savefig/clf sequence was copy-pasted 4x.)
for j in find_ex:
    ori = np.load(f'./test_img/{j}-real.npy')[0]
    pred = np.load(f'./test_img/{j}-gan-recom.npy')[0]
    mask = np.load(f'./test_img/{j}-mask.npy')

    _save_panel(ori, 'RdYlGn_r', f'./test_img/out_fig/{j}-truth.png')
    _save_panel(mask, 'gray', f'./test_img/out_fig/{j}-mask.png')

    # Pixels where mask == 1 become NaN (rendered blank); the rest keep the
    # original values. NOTE(review): verify the mask convention — this shows
    # `ori` on the complement of the region filled by `pred` below.
    mask_cp = np.where((1 - mask) == 0, np.nan, (1 - mask))
    _save_panel(ori * mask_cp, 'RdYlGn_r', f'./test_img/out_fig/{j}-masked_ori.png')

    # Keep original values where mask == 1, use predictions where mask == 0.
    out = ori * mask + pred * (1 - mask)
    _save_panel(out, 'RdYlGn_r', f'./test_img/out_fig/{j}-gan_out.png')
<Figure size 640x480 with 0 Axes>
In [ ]: