MAE_ATMO/torch_GAN_1d_baseline.ipynb


In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
os.environ["CUDA_VISIBLE_DEVICE"] = "0" 
In [3]:
# Select the CUDA device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cuda
In [12]:
# Helper: scan the training set for the maximum pixel value (used for normalization)
def find_max_pixel_value(image_dir):
    max_pixel_value = 0.0
    for filename in os.listdir(image_dir):
        if filename.endswith('.npy'):
            image_path = os.path.join(image_dir, filename)
            image = np.load(image_path).astype(np.float32)
            max_pixel_value = max(max_pixel_value, image[:, :, 0].max())
    return max_pixel_value

# Maximum pixel value in the image data, precomputed once with
# find_max_pixel_value(image_dir) and hardcoded here to skip the full scan
image_dir = './out_mat/96/train/'
max_pixel_value = 107.49169921875

print(f"Maximum pixel value in the dataset: {max_pixel_value}")
Maximum pixel value in the dataset: 107.49169921875
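If the training data changes, the cached constant should be recomputed with the helper defined above; a one-line sketch (slow, since it scans every `.npy` file):

max_pixel_value = find_max_pixel_value(image_dir)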
In [5]:
np.random.seed(42)
torch.random.manual_seed(42)
Out[5]:
<torch._C.Generator at 0x7f7a40059f90>
In [6]:
class NO2Dataset(Dataset):
    
    def __init__(self, image_dir, mask_dir):
        
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # image tensors (.npy only)
        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # binary masks (.jpg only)
        
    def __len__(self):
        
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_name = np.random.choice(self.mask_filenames)  # pair each image with a random mask
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Load the image data (.npy); keep only the NO2 channel and normalize,
        # giving shape (96, 96, 1)
        image = np.load(image_path).astype(np.float32)[:, :, :1] / max_pixel_value

        # Load the mask (.jpg) as a grayscale array
        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)

        # Binarize the mask: non-zero pixels become 1 (observed), zeros stay 0 (missing)
        mask = np.where(mask != 0, 1.0, 0.0)

        # Give the mask shape (96, 96, 1)
        mask = mask[:, :, np.newaxis]

        # Apply the mask: zero out the missing NO2 pixels
        masked_image = image.copy()
        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()

        # cGAN input and target; the image has a single channel in this 1d
        # baseline, so the concatenation below is a no-op kept for parity with
        # the multi-channel variant
        X = np.concatenate([masked_image[:, :, :1], image[:, :, 1:]], axis=-1)  # shape (96, 96, 1)
        y = image[:, :, 0:1]  # target: the unmasked NO2 channel, shape (96, 96, 1)

        # Convert to (channels, height, width)
        X = np.transpose(X, (2, 0, 1))  # (1, 96, 96)
        y = np.transpose(y, (2, 0, 1))  # (1, 96, 96)
        mask = np.transpose(mask, (2, 0, 1))  # (1, 96, 96)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)

# Instantiate the dataset and data loader
image_dir = './out_mat/96/train/'
mask_dir = './out_mat/96/mask/20/'

print(f"checkpoint before Generator is OK")
checkpoint before Generator is OK
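As a quick sanity check (a sketch, assuming the directories above exist and are populated), one sample can be pulled directly from the dataset:

# Sketch: verify one sample has the expected (C, H, W) layout
X0, y0, m0 = NO2Dataset(image_dir, mask_dir)[0]
print(X0.shape, y0.shape, m0.shape)  # expected: three tensors of shape torch.Size([1, 96, 96])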
In [7]:
dataset = NO2Dataset(image_dir, mask_dir)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

# Generator: a small encoder-decoder that inpaints the masked NO2 channel
class Generator(nn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x, mask):
        x_encoded = self.encoder(x)
        x_decoded = self.decoder(x_encoded)

        # Map the Tanh output from [-1, 1] back to [0, 1]
        x_decoded = (x_decoded + 1) / 2

        # Composite: keep observed pixels (mask == 1), fill the holes
        # (mask == 0) with generated values
        x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :]
        return x_output

# Discriminator: a patch-level classifier over (candidate, condition) channel pairs
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

# Move the models to the GPU
generator = Generator().to(device)
discriminator = Discriminator().to(device)
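The three stride-2 convolutions reduce 96x96 inputs to a 12x12 grid of patch scores, which the `valid`/`fake` label tensors in the training loop below must match. A quick shape check (a sketch, assuming the cells above have run):

# Sketch: confirm the patch-discriminator output size (96 -> 48 -> 24 -> 12)
with torch.no_grad():
    print(discriminator(torch.randn(1, 2, 96, 96, device=device)).shape)  # torch.Size([1, 1, 12, 12])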
In [8]:
# Optimizers and losses (pix2pix-style: adversarial BCE plus weighted pixelwise MSE)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss().to(device)
pixelwise_loss = nn.MSELoss().to(device)

# Confirm the models are on the GPU
print(f"Generator is on: {next(generator.parameters()).device}")
print(f"Discriminator is on: {next(discriminator.parameters()).device}")
Generator is on: cuda:0
Discriminator is on: cuda:0
In [9]:
# Reload previously saved generator weights (CPU copy for evaluation)
gen = torch.load('./models/GAN/generator.pth', map_location='cpu')
generator.load_state_dict(gen)
generator = generator.to('cpu')
In [10]:
# Reload previously saved discriminator weights
dis = torch.load('./models/GAN/discriminator.pth', map_location='cpu')
discriminator.load_state_dict(dis)
Out[10]:
<All keys matched successfully>
In [10]:
# Training loop
epochs = 300
for epoch in range(epochs):
    for i, (X, y, mask) in enumerate(dataloader):
# Move the batch to the GPU
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        
        # Patch-level labels matching the discriminator's 12x12 output grid
        valid = torch.ones((X.size(0), 1, 12, 12)).to(device)
        fake = torch.zeros((X.size(0), 1, 12, 12)).to(device)

        # Generator step: adversarial loss plus weighted pixel loss (lambda = 100, as in pix2pix)
        optimizer_G.zero_grad()
        generated_images = generator(X, mask)
        g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * pixelwise_loss(
            generated_images, y)
        g_loss.backward()
        optimizer_G.step()

        # Discriminator step: real pairs vs. detached fake pairs
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)
        fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)
        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()
        optimizer_D.step()

    print(f"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

# Save the trained weights
torch.save(generator.state_dict(), './models/GAN/generator.pth')
torch.save(discriminator.state_dict(), './models/GAN/discriminator.pth')
/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,
Epoch [0/300] [D loss: 0.5392551422119141] [G loss: 0.9313964247703552]
Epoch [1/300] [D loss: 0.45492202043533325] [G loss: 1.5053317546844482]
Epoch [2/300] [D loss: 0.3420121669769287] [G loss: 1.4923288822174072]
Epoch [3/300] [D loss: 0.2960708737373352] [G loss: 1.955796480178833]
Epoch [4/300] [D loss: 0.40790891647338867] [G loss: 2.071624279022217]
Epoch [5/300] [D loss: 0.2747359275817871] [G loss: 1.917580485343933]
Epoch [6/300] [D loss: 0.47008591890335083] [G loss: 1.5003858804702759]
Epoch [7/300] [D loss: 0.19478999078273773] [G loss: 3.949864149093628]
Epoch [8/300] [D loss: 0.5340784192085266] [G loss: 0.8913870453834534]
Epoch [9/300] [D loss: 0.3194230794906616] [G loss: 1.861933946609497]
Epoch [10/300] [D loss: 0.22022968530654907] [G loss: 1.6654534339904785]
Epoch [11/300] [D loss: 0.3743482828140259] [G loss: 1.626413106918335]
Epoch [12/300] [D loss: 0.13774043321609497] [G loss: 3.187469482421875]
Epoch [13/300] [D loss: 0.4275822043418884] [G loss: 2.2269718647003174]
Epoch [14/300] [D loss: 0.22367843985557556] [G loss: 2.707667827606201]
Epoch [15/300] [D loss: 0.25350409746170044] [G loss: 1.6780041456222534]
Epoch [16/300] [D loss: 0.23647311329841614] [G loss: 1.6349072456359863]
Epoch [17/300] [D loss: 0.5604373812675476] [G loss: 2.025310754776001]
Epoch [18/300] [D loss: 0.4707084596157074] [G loss: 2.938746452331543]
Epoch [19/300] [D loss: 0.2135343998670578] [G loss: 1.438072919845581]
Epoch [20/300] [D loss: 0.08090662956237793] [G loss: 2.77827787399292]
Epoch [21/300] [D loss: 0.2995302677154541] [G loss: 2.239123821258545]
Epoch [22/300] [D loss: 0.45718979835510254] [G loss: 3.404195547103882]
Epoch [23/300] [D loss: 0.13950982689857483] [G loss: 2.7335846424102783]
Epoch [24/300] [D loss: 0.1814766675233841] [G loss: 2.7603776454925537]
Epoch [25/300] [D loss: 0.0551619790494442] [G loss: 4.066004276275635]
Epoch [26/300] [D loss: 0.1498052179813385] [G loss: 2.733922243118286]
Epoch [27/300] [D loss: 0.11236032098531723] [G loss: 3.215831756591797]
Epoch [28/300] [D loss: 0.4945942461490631] [G loss: 1.9661915302276611]
Epoch [29/300] [D loss: 0.0760776624083519] [G loss: 2.7212204933166504]
Epoch [30/300] [D loss: 0.19911707937717438] [G loss: 1.5356297492980957]
Epoch [31/300] [D loss: 0.0900304988026619] [G loss: 2.84740948677063]
Epoch [32/300] [D loss: 0.05910511314868927] [G loss: 4.071162223815918]
Epoch [33/300] [D loss: 0.0711229220032692] [G loss: 2.716848373413086]
Epoch [34/300] [D loss: 0.39897823333740234] [G loss: 1.5674937963485718]
Epoch [35/300] [D loss: 0.05552360787987709] [G loss: 3.754516124725342]
Epoch [36/300] [D loss: 0.5413599014282227] [G loss: 4.129798889160156]
Epoch [37/300] [D loss: 0.17664434015750885] [G loss: 4.226540565490723]
Epoch [38/300] [D loss: 0.15215317904949188] [G loss: 2.6023881435394287]
Epoch [39/300] [D loss: 0.07798739522695541] [G loss: 4.075980186462402]
Epoch [40/300] [D loss: 0.03936776518821716] [G loss: 4.37988805770874]
Epoch [41/300] [D loss: 0.2012120634317398] [G loss: 2.1987271308898926]
Epoch [42/300] [D loss: 0.05274203419685364] [G loss: 3.8458642959594727]
Epoch [43/300] [D loss: 0.13967157900333405] [G loss: 3.438344955444336]
Epoch [44/300] [D loss: 0.05800560116767883] [G loss: 2.941135883331299]
Epoch [45/300] [D loss: 0.14671097695827484] [G loss: 3.388277292251587]
Epoch [46/300] [D loss: 0.06439051032066345] [G loss: 2.9789438247680664]
Epoch [47/300] [D loss: 0.11101078987121582] [G loss: 2.6266937255859375]
Epoch [48/300] [D loss: 0.028554894030094147] [G loss: 4.042592525482178]
Epoch [49/300] [D loss: 0.3364626169204712] [G loss: 3.419842004776001]
Epoch [50/300] [D loss: 0.2501979470252991] [G loss: 3.319307804107666]
Epoch [51/300] [D loss: 0.2962917387485504] [G loss: 5.088353157043457]
Epoch [52/300] [D loss: 0.07700179517269135] [G loss: 3.231515884399414]
Epoch [53/300] [D loss: 0.4093267321586609] [G loss: 1.918235182762146]
Epoch [54/300] [D loss: 0.12105419486761093] [G loss: 2.3409922122955322]
Epoch [55/300] [D loss: 0.057456158101558685] [G loss: 4.047771453857422]
Epoch [56/300] [D loss: 0.250449538230896] [G loss: 1.9442336559295654]
Epoch [57/300] [D loss: 0.08125491440296173] [G loss: 2.7323458194732666]
Epoch [58/300] [D loss: 0.06671395897865295] [G loss: 3.081458330154419]
Epoch [59/300] [D loss: 0.06982511281967163] [G loss: 3.95278000831604]
Epoch [60/300] [D loss: 0.08973922580480576] [G loss: 3.9550158977508545]
Epoch [61/300] [D loss: 0.29226893186569214] [G loss: 2.2824535369873047]
Epoch [62/300] [D loss: 0.06800767779350281] [G loss: 4.67025089263916]
Epoch [63/300] [D loss: 0.017987174913287163] [G loss: 4.119121551513672]
Epoch [64/300] [D loss: 0.1278763711452484] [G loss: 4.481695652008057]
Epoch [65/300] [D loss: 0.12277506291866302] [G loss: 2.0188961029052734]
Epoch [66/300] [D loss: 0.10042040050029755] [G loss: 4.019499778747559]
Epoch [67/300] [D loss: 0.15092261135578156] [G loss: 3.0588033199310303]
Epoch [68/300] [D loss: 0.157196044921875] [G loss: 4.579256534576416]
Epoch [69/300] [D loss: 0.0256386436522007] [G loss: 4.309335708618164]
Epoch [70/300] [D loss: 0.011956267058849335] [G loss: 4.763312816619873]
Epoch [71/300] [D loss: 0.08460590243339539] [G loss: 5.456184387207031]
Epoch [72/300] [D loss: 0.07495025545358658] [G loss: 3.5078511238098145]
Epoch [73/300] [D loss: 0.13037167489528656] [G loss: 3.164292812347412]
Epoch [74/300] [D loss: 0.0830327719449997] [G loss: 5.159647464752197]
Epoch [75/300] [D loss: 0.4353921115398407] [G loss: 5.0652875900268555]
Epoch [76/300] [D loss: 0.02432486228644848] [G loss: 3.7066524028778076]
Epoch [77/300] [D loss: 0.2809848189353943] [G loss: 1.1604290008544922]
Epoch [78/300] [D loss: 0.7653636932373047] [G loss: 2.5745716094970703]
Epoch [79/300] [D loss: 0.041840165853500366] [G loss: 4.082228660583496]
Epoch [80/300] [D loss: 0.03992146998643875] [G loss: 4.9236321449279785]
Epoch [81/300] [D loss: 0.1003192886710167] [G loss: 2.683060646057129]
Epoch [82/300] [D loss: 0.1460535228252411] [G loss: 4.597597122192383]
Epoch [83/300] [D loss: 0.1408858597278595] [G loss: 1.8829160928726196]
Epoch [84/300] [D loss: 0.048089221119880676] [G loss: 3.1438090801239014]
Epoch [85/300] [D loss: 0.041934601962566376] [G loss: 3.298645257949829]
Epoch [86/300] [D loss: 0.1363355964422226] [G loss: 2.6124517917633057]
Epoch [87/300] [D loss: 0.03299988433718681] [G loss: 3.3402161598205566]
Epoch [88/300] [D loss: 0.22786922752857208] [G loss: 3.9778051376342773]
Epoch [89/300] [D loss: 0.021804900839924812] [G loss: 4.595890045166016]
Epoch [90/300] [D loss: 0.022495444864034653] [G loss: 4.2465901374816895]
Epoch [91/300] [D loss: 0.02908019907772541] [G loss: 6.379057884216309]
Epoch [92/300] [D loss: 0.6523040533065796] [G loss: 0.6009750962257385]
Epoch [93/300] [D loss: 0.007557982578873634] [G loss: 5.837783336639404]
Epoch [94/300] [D loss: 0.020063551142811775] [G loss: 4.044745445251465]
Epoch [95/300] [D loss: 0.003706925082951784] [G loss: 8.243224143981934]
Epoch [96/300] [D loss: 0.021942533552646637] [G loss: 4.662309169769287]
Epoch [97/300] [D loss: 0.005410192534327507] [G loss: 5.5743536949157715]
Epoch [98/300] [D loss: 0.07137680053710938] [G loss: 3.261455535888672]
Epoch [99/300] [D loss: 0.11327817291021347] [G loss: 3.817570686340332]
Epoch [100/300] [D loss: 0.04488084092736244] [G loss: 4.458094596862793]
Epoch [101/300] [D loss: 0.05757671222090721] [G loss: 3.695896625518799]
Epoch [102/300] [D loss: 0.04083157703280449] [G loss: 3.704172134399414]
Epoch [103/300] [D loss: 0.02816752716898918] [G loss: 4.322700023651123]
Epoch [104/300] [D loss: 0.026689285412430763] [G loss: 4.115890979766846]
Epoch [105/300] [D loss: 0.03571446239948273] [G loss: 4.080765724182129]
Epoch [106/300] [D loss: 0.020453810691833496] [G loss: 5.457651615142822]
Epoch [107/300] [D loss: 0.03774755448102951] [G loss: 5.34019136428833]
Epoch [108/300] [D loss: 0.0933525487780571] [G loss: 5.5797905921936035]
Epoch [109/300] [D loss: 0.024301748722791672] [G loss: 4.042290210723877]
Epoch [110/300] [D loss: 0.9034162759780884] [G loss: 5.52556848526001]
Epoch [111/300] [D loss: 0.0911281406879425] [G loss: 6.487083911895752]
Epoch [112/300] [D loss: 0.13892149925231934] [G loss: 3.0797510147094727]
Epoch [113/300] [D loss: 0.09627098590135574] [G loss: 3.104957103729248]
Epoch [114/300] [D loss: 0.007696065586060286] [G loss: 6.618851184844971]
Epoch [115/300] [D loss: 0.06528083980083466] [G loss: 3.4506514072418213]
Epoch [116/300] [D loss: 0.03879600390791893] [G loss: 3.3708789348602295]
Epoch [117/300] [D loss: 0.03395622968673706] [G loss: 6.2684736251831055]
Epoch [118/300] [D loss: 0.010569067671895027] [G loss: 5.944631099700928]
Epoch [119/300] [D loss: 0.024817001074552536] [G loss: 6.614266872406006]
Epoch [120/300] [D loss: 0.013173197396099567] [G loss: 6.226423263549805]
Epoch [121/300] [D loss: 0.06546411663293839] [G loss: 3.0585291385650635]
Epoch [122/300] [D loss: 0.01085597462952137] [G loss: 6.437295913696289]
Epoch [123/300] [D loss: 0.03522876650094986] [G loss: 4.0734052658081055]
Epoch [124/300] [D loss: 0.06875205039978027] [G loss: 4.0921711921691895]
Epoch [125/300] [D loss: 0.006707158405333757] [G loss: 5.244316577911377]
Epoch [126/300] [D loss: 0.03866109997034073] [G loss: 3.368199110031128]
Epoch [127/300] [D loss: 0.041117191314697266] [G loss: 4.484440326690674]
Epoch [128/300] [D loss: 0.0829429179430008] [G loss: 4.554262638092041]
Epoch [129/300] [D loss: 0.03219084441661835] [G loss: 5.4280924797058105]
Epoch [130/300] [D loss: 0.11037464439868927] [G loss: 5.89276647567749]
Epoch [131/300] [D loss: 0.029911085963249207] [G loss: 4.116299629211426]
Epoch [132/300] [D loss: 0.14276768267154694] [G loss: 2.059661626815796]
Epoch [133/300] [D loss: 0.06751281768083572] [G loss: 4.1591362953186035]
Epoch [134/300] [D loss: 0.06710615009069443] [G loss: 3.1725471019744873]
Epoch [135/300] [D loss: 0.015449777245521545] [G loss: 5.900448799133301]
Epoch [136/300] [D loss: 0.0017297605518251657] [G loss: 6.8876633644104]
Epoch [137/300] [D loss: 0.10661254078149796] [G loss: 3.035740613937378]
Epoch [138/300] [D loss: 0.04841696843504906] [G loss: 3.2598555088043213]
Epoch [139/300] [D loss: 0.13029193878173828] [G loss: 3.732114791870117]
Epoch [140/300] [D loss: 0.01422959566116333] [G loss: 4.98042106628418]
Epoch [141/300] [D loss: 0.15487617254257202] [G loss: 5.367415428161621]
Epoch [142/300] [D loss: 0.07540086656808853] [G loss: 4.3357768058776855]
Epoch [143/300] [D loss: 0.014456328004598618] [G loss: 4.569247245788574]
Epoch [144/300] [D loss: 0.012367785908281803] [G loss: 5.9672956466674805]
Epoch [145/300] [D loss: 0.05262265354394913] [G loss: 5.160377502441406]
Epoch [146/300] [D loss: 0.08042960613965988] [G loss: 3.7927441596984863]
Epoch [147/300] [D loss: 0.19245359301567078] [G loss: 3.8005473613739014]
Epoch [148/300] [D loss: 0.052174512296915054] [G loss: 5.132132053375244]
Epoch [149/300] [D loss: 0.4083835482597351] [G loss: 3.095195770263672]
Epoch [150/300] [D loss: 0.007787104230374098] [G loss: 7.455079078674316]
Epoch [151/300] [D loss: 0.011952079832553864] [G loss: 5.102141857147217]
Epoch [152/300] [D loss: 0.1612093597650528] [G loss: 3.7608675956726074]
Epoch [153/300] [D loss: 0.03018610179424286] [G loss: 3.8288230895996094]
Epoch [154/300] [D loss: 0.06719933450222015] [G loss: 4.006799697875977]
Epoch [155/300] [D loss: 0.0286514051258564] [G loss: 4.619848728179932]
Epoch [156/300] [D loss: 0.024552451446652412] [G loss: 4.437436580657959]
Epoch [157/300] [D loss: 0.011825334280729294] [G loss: 4.815029144287109]
Epoch [158/300] [D loss: 0.061660464853048325] [G loss: 7.883100509643555]
Epoch [159/300] [D loss: 0.041454415768384933] [G loss: 6.650402545928955]
Epoch [160/300] [D loss: 0.39040958881378174] [G loss: 8.09695053100586]
Epoch [161/300] [D loss: 0.0026854330208152533] [G loss: 8.107271194458008]
Epoch [162/300] [D loss: 0.16259369254112244] [G loss: 5.87791109085083]
Epoch [163/300] [D loss: 0.03663758188486099] [G loss: 4.121287822723389]
Epoch [164/300] [D loss: 0.009695476852357388] [G loss: 8.566814422607422]
Epoch [165/300] [D loss: 0.010842864401638508] [G loss: 7.692420959472656]
Epoch [166/300] [D loss: 0.010091769509017467] [G loss: 5.9158101081848145]
Epoch [167/300] [D loss: 0.005709683522582054] [G loss: 5.492888450622559]
Epoch [168/300] [D loss: 0.16688843071460724] [G loss: 3.3484747409820557]
Epoch [169/300] [D loss: 0.007227647118270397] [G loss: 6.33713960647583]
Epoch [170/300] [D loss: 0.007962928153574467] [G loss: 7.612416744232178]
Epoch [171/300] [D loss: 0.012646579183638096] [G loss: 4.420655250549316]
Epoch [172/300] [D loss: 0.01767764426767826] [G loss: 4.4174957275390625]
Epoch [173/300] [D loss: 0.006378074176609516] [G loss: 7.643772125244141]
Epoch [174/300] [D loss: 0.009910110384225845] [G loss: 5.333507061004639]
Epoch [175/300] [D loss: 0.004518002271652222] [G loss: 6.36816930770874]
Epoch [176/300] [D loss: 0.08845338225364685] [G loss: 4.761691570281982]
Epoch [177/300] [D loss: 0.038503680378198624] [G loss: 3.653679370880127]
Epoch [178/300] [D loss: 0.0021649880800396204] [G loss: 6.513932704925537]
Epoch [179/300] [D loss: 0.0054839057847857475] [G loss: 5.804437637329102]
Epoch [180/300] [D loss: 0.005088070873171091] [G loss: 5.903375148773193]
Epoch [181/300] [D loss: 0.024380924180150032] [G loss: 6.934257984161377]
Epoch [182/300] [D loss: 0.003647219855338335] [G loss: 9.193355560302734]
Epoch [183/300] [D loss: 0.8360736966133118] [G loss: 8.123100280761719]
Epoch [184/300] [D loss: 0.014819988049566746] [G loss: 4.3469648361206055]
Epoch [185/300] [D loss: 0.009622478857636452] [G loss: 5.201544761657715]
Epoch [186/300] [D loss: 0.023895107209682465] [G loss: 3.903581380844116]
Epoch [187/300] [D loss: 0.013679596595466137] [G loss: 8.605210304260254]
Epoch [188/300] [D loss: 0.0036324947141110897] [G loss: 6.411885738372803]
Epoch [189/300] [D loss: 0.006745172664523125] [G loss: 5.29392147064209]
Epoch [190/300] [D loss: 0.0007813140982761979] [G loss: 8.193427085876465]
Epoch [191/300] [D loss: 0.021813858300447464] [G loss: 4.648034572601318]
Epoch [192/300] [D loss: 0.025777161121368408] [G loss: 4.67152738571167]
Epoch [193/300] [D loss: 0.06395631283521652] [G loss: 7.985042095184326]
Epoch [194/300] [D loss: 0.034654516726732254] [G loss: 3.360792398452759]
Epoch [195/300] [D loss: 0.26737672090530396] [G loss: 6.765297889709473]
Epoch [196/300] [D loss: 0.010468905791640282] [G loss: 5.34564208984375]
Epoch [197/300] [D loss: 0.014369252137839794] [G loss: 5.097072124481201]
Epoch [198/300] [D loss: 0.003273996990174055] [G loss: 6.472024440765381]
Epoch [199/300] [D loss: 0.005874062888324261] [G loss: 8.4591646194458]
Epoch [200/300] [D loss: 0.005507076624780893] [G loss: 5.7223286628723145]
Epoch [201/300] [D loss: 0.16853176057338715] [G loss: 1.9387050867080688]
Epoch [202/300] [D loss: 0.0023364669177681208] [G loss: 8.370942115783691]
Epoch [203/300] [D loss: 0.003936069551855326] [G loss: 7.522141933441162]
Epoch [204/300] [D loss: 0.01826675795018673] [G loss: 4.6409101486206055]
Epoch [205/300] [D loss: 0.018070252612233162] [G loss: 6.2785234451293945]
Epoch [206/300] [D loss: 0.06540463864803314] [G loss: 4.250749111175537]
Epoch [207/300] [D loss: 0.005754987709224224] [G loss: 5.474653720855713]
Epoch [208/300] [D loss: 0.0024513285607099533] [G loss: 6.821662425994873]
Epoch [209/300] [D loss: 0.005051593761891127] [G loss: 8.622801780700684]
Epoch [210/300] [D loss: 0.2648685872554779] [G loss: 1.4338374137878418]
Epoch [211/300] [D loss: 0.06582126766443253] [G loss: 4.042891502380371]
Epoch [212/300] [D loss: 0.033716216683387756] [G loss: 3.6866607666015625]
Epoch [213/300] [D loss: 0.008300993591547012] [G loss: 5.592546463012695]
Epoch [214/300] [D loss: 0.10640338063240051] [G loss: 3.440943479537964]
Epoch [215/300] [D loss: 0.018705546855926514] [G loss: 8.040839195251465]
Epoch [216/300] [D loss: 0.32254651188850403] [G loss: 1.023318886756897]
Epoch [217/300] [D loss: 0.006875279359519482] [G loss: 5.205789566040039]
Epoch [218/300] [D loss: 0.01632297970354557] [G loss: 6.327811241149902]
Epoch [219/300] [D loss: 0.020900549367070198] [G loss: 6.634525299072266]
Epoch [220/300] [D loss: 0.011139878071844578] [G loss: 7.300896644592285]
Epoch [221/300] [D loss: 0.01837160252034664] [G loss: 5.964895248413086]
Epoch [222/300] [D loss: 0.016974858939647675] [G loss: 4.413552284240723]
Epoch [223/300] [D loss: 0.3439306914806366] [G loss: 5.5219573974609375]
Epoch [224/300] [D loss: 0.047548823058605194] [G loss: 6.586645603179932]
Epoch [225/300] [D loss: 0.03183538839221001] [G loss: 4.398618221282959]
Epoch [226/300] [D loss: 0.0033374489285051823] [G loss: 7.412342071533203]
Epoch [227/300] [D loss: 0.018537862226366997] [G loss: 5.484577655792236]
Epoch [228/300] [D loss: 0.03582551330327988] [G loss: 3.6857614517211914]
Epoch [229/300] [D loss: 0.11226078867912292] [G loss: 2.819861888885498]
Epoch [230/300] [D loss: 0.002012553857639432] [G loss: 7.154722690582275]
Epoch [231/300] [D loss: 0.00868014432489872] [G loss: 8.001018524169922]
Epoch [232/300] [D loss: 0.0419110469520092] [G loss: 6.980061054229736]
Epoch [233/300] [D loss: 0.006477241404354572] [G loss: 5.782578945159912]
Epoch [234/300] [D loss: 0.0016205032588914037] [G loss: 10.428010940551758]
Epoch [235/300] [D loss: 0.02312217839062214] [G loss: 4.159178733825684]
Epoch [236/300] [D loss: 0.36001917719841003] [G loss: 2.4811325073242188]
Epoch [237/300] [D loss: 0.005733223166316748] [G loss: 5.611016750335693]
Epoch [238/300] [D loss: 0.008837449364364147] [G loss: 8.30731201171875]
Epoch [239/300] [D loss: 0.011222743429243565] [G loss: 4.619396209716797]
Epoch [240/300] [D loss: 0.0060098664835095406] [G loss: 6.022060394287109]
Epoch [241/300] [D loss: 0.0011382169323042035] [G loss: 7.404472351074219]
Epoch [242/300] [D loss: 0.3661719560623169] [G loss: 7.876453399658203]
Epoch [243/300] [D loss: 0.0019019388128072023] [G loss: 7.3895263671875]
Epoch [244/300] [D loss: 0.006632590666413307] [G loss: 5.541728973388672]
Epoch [245/300] [D loss: 0.008930223993957043] [G loss: 5.2114691734313965]
Epoch [246/300] [D loss: 0.016119416803121567] [G loss: 7.121890068054199]
Epoch [247/300] [D loss: 0.001622633310034871] [G loss: 7.303770065307617]
Epoch [248/300] [D loss: 0.005070182494819164] [G loss: 6.975015640258789]
Epoch [249/300] [D loss: 0.04641895741224289] [G loss: 7.218448638916016]
Epoch [250/300] [D loss: 0.01194002851843834] [G loss: 4.6930975914001465]
Epoch [251/300] [D loss: 0.012792033143341541] [G loss: 4.67077112197876]
Epoch [252/300] [D loss: 0.008810436353087425] [G loss: 5.938291072845459]
Epoch [253/300] [D loss: 0.010516034439206123] [G loss: 4.816621780395508]
Epoch [254/300] [D loss: 0.017264991998672485] [G loss: 8.856822967529297]
Epoch [255/300] [D loss: 0.011463891714811325] [G loss: 6.232043743133545]
Epoch [256/300] [D loss: 0.08137447386980057] [G loss: 2.598818778991699]
Epoch [257/300] [D loss: 0.032363615930080414] [G loss: 4.790830135345459]
Epoch [258/300] [D loss: 0.00863250344991684] [G loss: 7.292766571044922]
Epoch [259/300] [D loss: 0.027235930785536766] [G loss: 6.844869613647461]
Epoch [260/300] [D loss: 0.008849331177771091] [G loss: 5.027510643005371]
Epoch [261/300] [D loss: 0.020822376012802124] [G loss: 4.600456714630127]
Epoch [262/300] [D loss: 1.7667120695114136] [G loss: 3.4651100635528564]
Epoch [263/300] [D loss: 0.022669170051813126] [G loss: 5.7553019523620605]
Epoch [264/300] [D loss: 0.01582598127424717] [G loss: 4.149420261383057]
Epoch [265/300] [D loss: 0.0035504011902958155] [G loss: 6.116427421569824]
Epoch [266/300] [D loss: 0.07644154131412506] [G loss: 2.720405101776123]
Epoch [267/300] [D loss: 0.030415533110499382] [G loss: 4.244810104370117]
Epoch [268/300] [D loss: 0.020068874582648277] [G loss: 6.474517822265625]
Epoch [269/300] [D loss: 0.002136750379577279] [G loss: 9.29329776763916]
Epoch [270/300] [D loss: 0.00978941936045885] [G loss: 5.02622652053833]
Epoch [271/300] [D loss: 0.08784317970275879] [G loss: 6.733256816864014]
Epoch [272/300] [D loss: 0.009109925478696823] [G loss: 5.823270797729492]
Epoch [273/300] [D loss: 0.008865194395184517] [G loss: 5.696066379547119]
Epoch [274/300] [D loss: 0.029590584337711334] [G loss: 8.216507911682129]
Epoch [275/300] [D loss: 0.0636298805475235] [G loss: 8.98292064666748]
Epoch [276/300] [D loss: 0.004769572988152504] [G loss: 6.2220025062561035]
Epoch [277/300] [D loss: 0.003883387427777052] [G loss: 6.5977911949157715]
Epoch [278/300] [D loss: 0.04028937965631485] [G loss: 4.9343485832214355]
Epoch [279/300] [D loss: 0.011857430450618267] [G loss: 6.440511703491211]
Epoch [280/300] [D loss: 0.007019379176199436] [G loss: 5.2130351066589355]
Epoch [281/300] [D loss: 0.022525882348418236] [G loss: 3.9527556896209717]
Epoch [282/300] [D loss: 0.0071130781434476376] [G loss: 6.993907928466797]
Epoch [283/300] [D loss: 0.003977011889219284] [G loss: 7.2447967529296875]
Epoch [284/300] [D loss: 0.07062061131000519] [G loss: 5.2334771156311035]
Epoch [285/300] [D loss: 0.01805986650288105] [G loss: 5.5015082359313965]
Epoch [286/300] [D loss: 0.05663669481873512] [G loss: 6.766615390777588]
Epoch [287/300] [D loss: 0.0032901568338274956] [G loss: 6.28628396987915]
Epoch [288/300] [D loss: 0.3530406653881073] [G loss: 7.906818389892578]
Epoch [289/300] [D loss: 0.004547123331576586] [G loss: 6.108604431152344]
Epoch [290/300] [D loss: 0.010472457855939865] [G loss: 6.213746070861816]
Epoch [291/300] [D loss: 0.016601260751485825] [G loss: 5.763346195220947]
Epoch [292/300] [D loss: 0.04024907946586609] [G loss: 5.658637523651123]
Epoch [293/300] [D loss: 0.07437323033809662] [G loss: 5.68184757232666]
Epoch [294/300] [D loss: 0.08150847256183624] [G loss: 6.040549278259277]
Epoch [295/300] [D loss: 0.0924491435289383] [G loss: 2.502917766571045]
Epoch [296/300] [D loss: 0.0035814237780869007] [G loss: 7.250881195068359]
Epoch [297/300] [D loss: 0.012245922349393368] [G loss: 6.780396461486816]
Epoch [298/300] [D loss: 0.004009227734059095] [G loss: 5.833404064178467]
Epoch [299/300] [D loss: 0.14272907376289368] [G loss: 7.528534889221191]
In [11]:
# Visualize one batch: masked input vs. generated vs. original NO2
def visualize_results():
    
    X, y, mask = next(iter(dataloader))
    X, y, mask = X.to(device), y.to(device), mask.to(device)
    generated_images = generator(X, mask)

    mask = mask.squeeze(1)
    generated_images = generated_images.squeeze(1)
    y = y.squeeze(1)

    final_output = generated_images

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.title('Masked NO2 Data')
    plt.imshow(X[0, 0].cpu().detach().numpy(), cmap='gray')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title('Generated NO2 Data')
    plt.imshow(final_output[0].cpu().detach().numpy(), cmap='gray')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title('Original NO2 Data')
    plt.imshow(y[0].cpu().detach().numpy(), cmap='gray')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig('results_visualizationxxx.png')
    plt.close()
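The helper saves its figure to disk rather than displaying it inline; calling it is a one-liner (assuming `generator` and `dataloader` are on the same device):

visualize_results()  # writes results_visualizationxxx.png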
In [12]:
dataset_test = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')
test_loader = DataLoader(dataset_test, batch_size=64, shuffle=False, num_workers=8)
In [13]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error
In [17]:
def cal_ioa(y_true, y_pred):
    # Means of the observed and predicted series
    mean_observed = np.mean(y_true)
    mean_predicted = np.mean(y_pred)

    # Index of agreement (Willmott-style; note each series is centered on its own mean here)
    numerator = np.sum((y_true - y_pred) ** 2)
    denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)
    IoA = 1 - (numerator / denominator)

    return IoA
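A minimal check with synthetic arrays, chosen so the arithmetic is easy to verify by hand:

# Sketch: IoA is 1.0 for a perfect prediction and falls as agreement degrades
y = np.array([1.0, 2.0, 3.0])
print(cal_ioa(y, y))                          # 1.0
print(cal_ioa(y, np.array([3.0, 2.0, 1.0])))  # 0.0 for this fully reversed case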
In [21]:
device = 'cpu'
generator = generator.to(device)
eva_list_frame = list()
with torch.no_grad():
    for X, y, mask in test_loader:
        X, y, mask = X.to(device), y.to(device), mask.to(device)
        generated_images = generator(X, mask)
        mask = mask.squeeze(1).cpu().detach().numpy()
        rev_mask = (mask == 0) * 1  # 1 where pixels were masked out (i.e. reconstructed)
        generated_images = generated_images.squeeze(1)
        real = y.squeeze(1).cpu().detach().numpy() * max_pixel_value  # de-normalize
        final_output = generated_images.cpu().detach().numpy()
        final_output *= max_pixel_value
        # Score each sample on the reconstructed (previously masked) pixels only
        for i, sample in enumerate(generated_images):
            used_mask = rev_mask[i]
            data_label = real[i][used_mask == 1]
            recon_no2 = final_output[i][used_mask == 1]
            mae = mean_absolute_error(data_label, recon_no2)
            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))
            mape = mean_absolute_percentage_error(data_label, recon_no2)
            r2 = r2_score(data_label, recon_no2)
            ioa = cal_ioa(data_label, recon_no2)
            r = np.corrcoef(data_label, recon_no2)[0, 1]
            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])
In [15]:
import pandas as pd
In [24]:
pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()
Out[24]:
              mae         rmse         mape           r2          ioa            r
count  4739.000000  4739.000000  4739.000000  4739.000000  4739.000000  4739.000000
mean      2.512144     3.430941     0.360515    -0.342186     0.680466     0.578431
std       1.184403     1.580097     0.338929     1.730534     0.267197     0.227716
min       0.895772     1.189229     0.126946   -42.147773    -2.040257    -0.542623
25%       1.699544     2.389879     0.211304    -0.480435     0.606881     0.457522
50%       2.259834     3.125452     0.263535     0.094196     0.749986     0.620338
75%       2.953516     3.983247     0.358762     0.421505     0.840619     0.745614
max      10.477497    14.460713     4.314635     0.922679     0.981525     0.965753
In [11]:
# Save the trained models as whole modules (not just state_dicts)
torch.save(generator, './models/GAN/generator.pt')
torch.save(discriminator, './models/GAN/discriminator.pt')
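Saving whole modules pickles the class definitions, so loading them back requires the same source tree; the state_dict round trip used earlier is the more portable pattern (a sketch, same paths as above):

# Sketch: portable alternative via state_dicts
torch.save(generator.state_dict(), './models/GAN/generator.pth')
generator.load_state_dict(torch.load('./models/GAN/generator.pth', map_location='cpu'))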
In [14]:
test_imgs = [x for x in os.listdir('./test_img/') if 'img' in x]
test_imgs.sort()
test_masks = [x for x in os.listdir('./test_img/') if 'mask' in x]
test_masks.sort()
for img_npy, mask_npy in zip(test_imgs, test_masks):
    img = np.load(f'./test_img/{img_npy}')
    img_in = torch.tensor(np.expand_dims(img, 0), dtype=torch.float32)
    mask = np.load(f'./test_img/{mask_npy}')
    mask_in = torch.tensor(np.expand_dims(mask, 0), dtype=torch.float32)
    out = generator(img_in, mask_in).detach().cpu().numpy() * max_pixel_value
    break  # only the first image/mask pair is needed here
In [15]:
plt.imshow(out[0][0], cmap='RdYlGn_r')
Out[15]:
<matplotlib.image.AxesImage at 0x7f793dd49520>
In [ ]: