MAE_ATMO/build_gan.ipynb

122 KiB
Raw Permalink Blame History

In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
os.environ["CUDA_VISIBLE_DEVICE"] = "0" 


# 设置CUDA设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cuda
In [2]:
max_pixel_value = 107.49169921875

print(f"Maximum pixel value in the dataset: {max_pixel_value}")
Maximum pixel value in the dataset: 107.49169921875
In [3]:
class NO2Dataset(Dataset):
    
    def __init__(self, image_dir, mask_dir):
        
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # 仅加载 .npy 文件
        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # 仅加载 .jpg 文件
        
    def __len__(self):
        
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_idx = idx % len(self.mask_filenames)
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])

        # 加载图像数据 (.npy 文件)
        image = np.load(image_path).astype(np.float32) / max_pixel_value  # 形状为 (96, 96, 8)

        # 加载掩码数据 (.jpg 文件)
        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)

        # 将掩码数据中非0值设为10值保持不变
        mask = np.where(mask != 0, 1.0, 0.0)

        # 保持掩码数据形状为 (96, 96, 1)
        mask = mask[:, :, np.newaxis]  # 将形状调整为 (96, 96, 1)

        # 应用掩码
        masked_image = image.copy()
        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()  # 遮盖NO2数据

        # cGAN的输入和目标
        X = np.concatenate([masked_image[:, :, :1], image[:, :, 1:]], axis=-1)  # 形状为 (96, 96, 8)
        y = image[:, :, 0:1]  # 目标输出为NO2数据形状为 (96, 96, 1)

        # 转换形状为 (channels, height, width)
        X = np.transpose(X, (2, 0, 1))  # 转换为 (8, 96, 96)
        y = np.transpose(y, (2, 0, 1))  # 转换为 (1, 96, 96)
        mask = np.transpose(mask, (2, 0, 1))  # 转换为 (1, 96, 96)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)

# 实例化数据集和数据加载器
train_dir = './out_mat/96/train/'
valid_dir = './out_mat/96/valid/'
test_dir = './out_mat/96/test/'
mask_dir = './out_mat/96/mask/20/'

print(f"checkpoint before Generator is OK")

dataset = NO2Dataset(train_dir, mask_dir)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

validset = NO2Dataset(valid_dir, mask_dir)
val_loader = DataLoader(validset, batch_size=64, shuffle=False, num_workers=8)

testset = NO2Dataset(test_dir, mask_dir)
test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=8)
checkpoint before Generator is OK
In [4]:
# 生成器模型
class Generator(nn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(8, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        x_encoded = self.encoder(x)
        x_decoded = self.decoder(x_encoded)

#         x_decoded = (x_decoded + 1) / 2

#         x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :]
        return x_output

# 判别器模型
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

# 将模型加载到GPU
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 定义优化器和损失函数
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss().to(device)
pixelwise_loss = nn.MSELoss().to(device)

# 确认模型是否在GPU上
print(f"Generator is on: {next(generator.parameters()).device}")
print(f"Discriminator is on: {next(discriminator.parameters()).device}")
Generator is on: cuda:0
Discriminator is on: cuda:0
In [ ]:
# 开始训练
epochs = 150
for epoch in range(epochs):
    for i, (X, y, mask) in enumerate(train_loader):
        # 将数据移到 GPU 上
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        # print(f"X is on: {X.device}, y is on: {y.device}, mask is on: {mask.device}, i = {i}") #checkpoint
        
        valid = torch.ones((X.size(0), 1, 12, 12)).to(device)
        fake = torch.zeros((X.size(0), 1, 12, 12)).to(device)

        # 生成器生成图像
        optimizer_G.zero_grad()
        generated_images = generator(X, mask)
        g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * pixelwise_loss(
            generated_images, y)
        g_loss.backward()
        optimizer_G.step()

        # 判别器训练
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)
        fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)
        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()
        optimizer_D.step()

    print(f"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

# 保存训练好的模型
torch.save(generator.state_dict(), './models/GAN/generator.pth')
torch.save(discriminator.state_dict(), './models/GAN/discriminator.pth')
In [9]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error
In [10]:
def cal_ioa(y_true, y_pred):
    # 计算平均值
    mean_observed = np.mean(y_true)
    mean_predicted = np.mean(y_pred)

    # 计算IoA
    numerator = np.sum((y_true - y_pred) ** 2)
    denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)
    IoA = 1 - (numerator / denominator)

    return IoA
In [11]:
eva_list = list()
device = 'cpu'
generator = generator.to(device)
with torch.no_grad():
    for batch_idx, (X, y, mask) in enumerate(test_loader):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域
        reconstructed = generator(X, mask)
        rev_data = torch.squeeze(y * max_pixel_value, dim=1)
        rev_recon = torch.squeeze(reconstructed * max_pixel_value, dim=1)
        # todo: 这里需要只评估修补出来的模块
        data_label = rev_data * mask_rev
        data_label = data_label[mask_rev==1]
        recon_no2 = rev_recon * mask_rev
        recon_no2 = recon_no2[mask_rev==1]
        y_true = rev_data.flatten()
        y_pred = rev_recon.flatten()
        mae = mean_absolute_error(y_true, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        mape = mean_absolute_percentage_error(y_true, y_pred)
        r2 = r2_score(y_true, y_pred)
        ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
        eva_list.append([mae, rmse, mape, r2, ioa])
In [22]:
pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()
Out[22]:
mae rmse mape r2 ioa
count 75.000000 75.000000 75.000000 75.000000 75.000000
mean 0.399366 1.246761 0.047188 0.963991 0.939587
std 0.071295 0.220616 0.005035 0.018081 0.026807
min 0.348072 1.073966 0.040716 0.813181 0.719442
25% 0.372762 1.154684 0.043751 0.963354 0.938713
50% 0.388768 1.207949 0.045860 0.966356 0.943430
75% 0.402351 1.274836 0.050051 0.968919 0.947026
max 0.959251 2.966476 0.066256 0.972840 0.956998
In [23]:
eva_list_frame = list()
best_mape = 1
best_img = None
best_mask = None
best_recov = None
with torch.no_grad():
    for batch_idx, (X, y, mask) in enumerate(test_loader):
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域
        reconstructed = generator(X, mask)
        rev_data = y * max_pixel_value
        rev_recon = reconstructed * max_pixel_value
        # todo: 这里需要只评估修补出来的模块
        for i, sample in enumerate(rev_data):
            used_mask = mask_rev[i]
            data_label = sample[0] * used_mask
            recon_no2 = rev_recon[i][0] * used_mask
            data_label = data_label[used_mask==1]
            recon_no2 = recon_no2[used_mask==1]
            mae = mean_absolute_error(data_label, recon_no2)
            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))
            mape = mean_absolute_percentage_error(data_label, recon_no2)
            r2 = r2_score(data_label, recon_no2)
            ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())
            r = np.corrcoef(data_label, recon_no2)[0, 1]
            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])
            if mape < best_mape:
                best_recov = rev_recon[i][0].numpy()
                best_mask = used_mask.numpy()
                best_img = sample[0].numpy()
                best_mape = mape
In [24]:
pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()
Out[24]:
mae rmse mape r2 ioa r
count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000
mean 1.977856 2.505394 0.236766 0.444234 0.826185 0.795505
std 0.966137 1.156947 0.075139 0.309037 0.112186 0.109227
min 0.588599 0.782554 0.106112 -5.779783 -2.754070 0.284676
25% 1.195541 1.551567 0.187231 0.300401 0.781712 0.735376
50% 1.606092 2.094027 0.220013 0.506733 0.849549 0.822590
75% 2.658243 3.338708 0.266574 0.658528 0.899010 0.876813
max 9.428754 9.982598 0.903847 0.889351 0.969285 0.960868
In [26]:
real_test = NO2Dataset('./out_mat/96/test/', mask_dir)
real_loader = DataLoader(real_test, batch_size=1, shuffle=True, num_workers=4)
In [28]:
for batch_idx, (X, y, mask) in enumerate(real_loader):
    print(X.shape, y.shape, mask.shape)
    np.save(f'./test_img/{batch_idx}-img.npy', X[0])
    np.save(f'./test_img/{batch_idx}-mask.npy', mask[0])
    np.save(f'./test_img/{batch_idx}-real.npy', y[0])
    if batch_idx >=4:
        break
torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])
torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])
torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])
torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])
torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])
In [72]:
test_imgs = [x for x in os.listdir('./test_img/') if 'img' in x]
test_imgs.sort()
test_masks = [x for x in os.listdir('./test_img/') if 'mask' in x]
test_masks.sort()
for img_npy, mask_npy in zip(test_imgs, test_masks):
    img = np.load(f'./test_img/{img_npy}')
    img_in = torch.tensor(np.expand_dims(img, 0), dtype=torch.float32)
    mask = np.load(f'./test_img/{mask_npy}')
    mask_in = torch.tensor(np.expand_dims(mask, 0), dtype=torch.float32)
    out = generator(img_in, mask_in).detach().cpu().numpy() * max_pixel_value
In [73]:
import matplotlib.pyplot as plt
In [74]:
plt.imshow(out[0][0], cmap='RdYlGn_r')
Out[74]:
<matplotlib.image.AxesImage at 0x7f65196d8940>
No description has been provided for this image
In [75]:
test_real = [x for x in os.listdir('./test_img/') if 'real' in x]
test_real.sort()
In [77]:
y_real = np.load(f'./test_img/{test_real[4]}')*max_pixel_value
In [78]:
plt.imshow(y_real[0], cmap='RdYlGn_r')
Out[78]:
<matplotlib.image.AxesImage at 0x7f65196b6370>
No description has been provided for this image
In [79]:
(y_real[0] * mask[0] + out[0][0] * (1-mask[0])).shape
Out[79]:
(96, 96)
In [80]:
d = y_real[0] * mask[0] + out[0][0] * (1-mask[0])
In [81]:
plt.imshow(out[0][0] * (1-mask[0]), cmap='RdYlGn_r')
Out[81]:
<matplotlib.image.AxesImage at 0x7f65196220a0>
No description has been provided for this image
In [ ]: