93 KiB
93 KiB
In [2]:
import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset from PIL import Image import matplotlib.pyplot as plt import numpy as np import pandas as pd os.environ["CUDA_VISIBLE_DEVICE"] = "0"
In [3]:
# 设置CUDA设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}")
Using device: cuda
In [12]:
# 定义函数来找到最大值 def find_max_pixel_value(image_dir): max_pixel_value = 0.0 for filename in os.listdir(image_dir): if filename.endswith('.npy'): image_path = os.path.join(image_dir, filename) image = np.load(image_path).astype(np.float32) max_pixel_value = max(max_pixel_value, image[:, :, 0].max()) return max_pixel_value # 计算图像数据中的最大像素值 image_dir = './out_mat/96/train/' max_pixel_value = 107.49169921875 print(f"Maximum pixel value in the dataset: {max_pixel_value}")
Maximum pixel value in the dataset: 107.49169921875
In [5]:
np.random.seed(42) torch.random.manual_seed(42)
Out[5]:
<torch._C.Generator at 0x7f7a40059f90>
In [6]:
class NO2Dataset(Dataset): def __init__(self, image_dir, mask_dir): self.image_dir = image_dir self.mask_dir = mask_dir self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件 self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件 def __len__(self): return len(self.image_filenames) def __getitem__(self, idx): image_path = os.path.join(self.image_dir, self.image_filenames[idx]) mask_idx = np.random.choice(self.mask_filenames) mask_path = os.path.join(self.mask_dir, mask_idx) # 加载图像数据 (.npy 文件) image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 8) # 加载掩码数据 (.jpg 文件) mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32) # 将掩码数据中非0值设为1,0值保持不变 mask = np.where(mask != 0, 1.0, 0.0) # 保持掩码数据形状为 (96, 96, 1) mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1) # 应用掩码 masked_image = image.copy() masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据 # cGAN的输入和目标 X = np.concatenate([masked_image[:, :, :1], image[:, :, 1:]], axis=-1) # 形状为 (96, 96, 8) y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1) # 转换形状为 (channels, height, width) X = np.transpose(X, (2, 0, 1)) # 转换为 (8, 96, 96) y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96) mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96) return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32) # 实例化数据集和数据加载器 image_dir = './out_mat/96/train/' mask_dir = './out_mat/96/mask/20/' print(f"checkpoint before Generator is OK")
checkpoint before Generator is OK
In [7]:
dataset = NO2Dataset(image_dir, mask_dir) dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8) # 生成器模型 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.encoder = nn.Sequential( nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.BatchNorm2d(64), nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.BatchNorm2d(128), ) self.decoder = nn.Sequential( nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(inplace=True), nn.BatchNorm2d(64), nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1), nn.Tanh(), ) def forward(self, x, mask): x_encoded = self.encoder(x) x_decoded = self.decoder(x_encoded) x_decoded = (x_decoded + 1) / 2 x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :] return x_output # 判别器模型 class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.BatchNorm2d(64), nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True), nn.BatchNorm2d(128), nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1), nn.Sigmoid(), ) def forward(self, x): return self.model(x) # 将模型加载到GPU generator = Generator().to(device) discriminator = Discriminator().to(device)
In [8]:
# 定义优化器和损失函数 optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)) adversarial_loss = nn.BCELoss().to(device) pixelwise_loss = nn.MSELoss().to(device) # 确认模型是否在GPU上 print(f"Generator is on: {next(generator.parameters()).device}") print(f"Discriminator is on: {next(discriminator.parameters()).device}")
Generator is on: cuda:0 Discriminator is on: cuda:0
In [9]:
gen = torch.load('./models/GAN/generator.pth', map_location='cpu') generator.load_state_dict(gen) generator = generator.to('cpu')
In [10]:
dis = torch.load('./models/GAN/discriminator.pth', map_location='cpu') discriminator.load_state_dict(dis)
Out[10]:
<All keys matched successfully>
In [10]:
# 开始训练 epochs = 300 for epoch in range(epochs): for i, (X, y, mask) in enumerate(dataloader): # 将数据移到 GPU 上 X, y, mask = X.to(device), y.to(device), mask.to(device) valid = torch.ones((X.size(0), 1, 12, 12)).to(device) fake = torch.zeros((X.size(0), 1, 12, 12)).to(device) # 生成器生成图像 optimizer_G.zero_grad() generated_images = generator(X, mask) g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * pixelwise_loss( generated_images, y) g_loss.backward() optimizer_G.step() # 判别器训练 optimizer_D.zero_grad() real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid) fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake) d_loss = 0.5 * (real_loss + fake_loss) d_loss.backward() optimizer_D.step() print(f"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]") # 保存训练好的模型 torch.save(generator.state_dict(), './models/GAN/generator.pth') torch.save(discriminator.state_dict(), './models/GAN/discriminator.pth')
/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.) return F.conv2d(input, weight, bias, self.stride,
Epoch [0/300] [D loss: 0.5392551422119141] [G loss: 0.9313964247703552] Epoch [1/300] [D loss: 0.45492202043533325] [G loss: 1.5053317546844482] Epoch [2/300] [D loss: 0.3420121669769287] [G loss: 1.4923288822174072] Epoch [3/300] [D loss: 0.2960708737373352] [G loss: 1.955796480178833] Epoch [4/300] [D loss: 0.40790891647338867] [G loss: 2.071624279022217] Epoch [5/300] [D loss: 0.2747359275817871] [G loss: 1.917580485343933] Epoch [6/300] [D loss: 0.47008591890335083] [G loss: 1.5003858804702759] Epoch [7/300] [D loss: 0.19478999078273773] [G loss: 3.949864149093628] Epoch [8/300] [D loss: 0.5340784192085266] [G loss: 0.8913870453834534] Epoch [9/300] [D loss: 0.3194230794906616] [G loss: 1.861933946609497] Epoch [10/300] [D loss: 0.22022968530654907] [G loss: 1.6654534339904785] Epoch [11/300] [D loss: 0.3743482828140259] [G loss: 1.626413106918335] Epoch [12/300] [D loss: 0.13774043321609497] [G loss: 3.187469482421875] Epoch [13/300] [D loss: 0.4275822043418884] [G loss: 2.2269718647003174] Epoch [14/300] [D loss: 0.22367843985557556] [G loss: 2.707667827606201] Epoch [15/300] [D loss: 0.25350409746170044] [G loss: 1.6780041456222534] Epoch [16/300] [D loss: 0.23647311329841614] [G loss: 1.6349072456359863] Epoch [17/300] [D loss: 0.5604373812675476] [G loss: 2.025310754776001] Epoch [18/300] [D loss: 0.4707084596157074] [G loss: 2.938746452331543] Epoch [19/300] [D loss: 0.2135343998670578] [G loss: 1.438072919845581] Epoch [20/300] [D loss: 0.08090662956237793] [G loss: 2.77827787399292] Epoch [21/300] [D loss: 0.2995302677154541] [G loss: 2.239123821258545] Epoch [22/300] [D loss: 0.45718979835510254] [G loss: 3.404195547103882] Epoch [23/300] [D loss: 0.13950982689857483] [G loss: 2.7335846424102783] Epoch [24/300] [D loss: 0.1814766675233841] [G loss: 2.7603776454925537] Epoch [25/300] [D loss: 0.0551619790494442] [G loss: 4.066004276275635] Epoch [26/300] [D loss: 0.1498052179813385] [G loss: 2.733922243118286] Epoch [27/300] [D loss: 0.11236032098531723] [G loss: 3.215831756591797] Epoch [28/300] [D loss: 0.4945942461490631] [G loss: 1.9661915302276611] Epoch [29/300] [D loss: 0.0760776624083519] [G loss: 2.7212204933166504] Epoch [30/300] [D loss: 0.19911707937717438] [G loss: 1.5356297492980957] Epoch [31/300] [D loss: 0.0900304988026619] [G loss: 2.84740948677063] Epoch [32/300] [D loss: 0.05910511314868927] [G loss: 4.071162223815918] Epoch [33/300] [D loss: 0.0711229220032692] [G loss: 2.716848373413086] Epoch [34/300] [D loss: 0.39897823333740234] [G loss: 1.5674937963485718] Epoch [35/300] [D loss: 0.05552360787987709] [G loss: 3.754516124725342] Epoch [36/300] [D loss: 0.5413599014282227] [G loss: 4.129798889160156] Epoch [37/300] [D loss: 0.17664434015750885] [G loss: 4.226540565490723] Epoch [38/300] [D loss: 0.15215317904949188] [G loss: 2.6023881435394287] Epoch [39/300] [D loss: 0.07798739522695541] [G loss: 4.075980186462402] Epoch [40/300] [D loss: 0.03936776518821716] [G loss: 4.37988805770874] Epoch [41/300] [D loss: 0.2012120634317398] [G loss: 2.1987271308898926] Epoch [42/300] [D loss: 0.05274203419685364] [G loss: 3.8458642959594727] Epoch [43/300] [D loss: 0.13967157900333405] [G loss: 3.438344955444336] Epoch [44/300] [D loss: 0.05800560116767883] [G loss: 2.941135883331299] Epoch [45/300] [D loss: 0.14671097695827484] [G loss: 3.388277292251587] Epoch [46/300] [D loss: 0.06439051032066345] [G loss: 2.9789438247680664] Epoch [47/300] [D loss: 0.11101078987121582] [G loss: 2.6266937255859375] Epoch [48/300] [D loss: 0.028554894030094147] [G loss: 4.042592525482178] Epoch [49/300] [D loss: 0.3364626169204712] [G loss: 3.419842004776001] Epoch [50/300] [D loss: 0.2501979470252991] [G loss: 3.319307804107666] Epoch [51/300] [D loss: 0.2962917387485504] [G loss: 5.088353157043457] Epoch [52/300] [D loss: 0.07700179517269135] [G loss: 3.231515884399414] Epoch [53/300] [D loss: 0.4093267321586609] [G loss: 1.918235182762146] Epoch [54/300] [D loss: 0.12105419486761093] [G loss: 2.3409922122955322] Epoch [55/300] [D loss: 0.057456158101558685] [G loss: 4.047771453857422] Epoch [56/300] [D loss: 0.250449538230896] [G loss: 1.9442336559295654] Epoch [57/300] [D loss: 0.08125491440296173] [G loss: 2.7323458194732666] Epoch [58/300] [D loss: 0.06671395897865295] [G loss: 3.081458330154419] Epoch [59/300] [D loss: 0.06982511281967163] [G loss: 3.95278000831604] Epoch [60/300] [D loss: 0.08973922580480576] [G loss: 3.9550158977508545] Epoch [61/300] [D loss: 0.29226893186569214] [G loss: 2.2824535369873047] Epoch [62/300] [D loss: 0.06800767779350281] [G loss: 4.67025089263916] Epoch [63/300] [D loss: 0.017987174913287163] [G loss: 4.119121551513672] Epoch [64/300] [D loss: 0.1278763711452484] [G loss: 4.481695652008057] Epoch [65/300] [D loss: 0.12277506291866302] [G loss: 2.0188961029052734] Epoch [66/300] [D loss: 0.10042040050029755] [G loss: 4.019499778747559] Epoch [67/300] [D loss: 0.15092261135578156] [G loss: 3.0588033199310303] Epoch [68/300] [D loss: 0.157196044921875] [G loss: 4.579256534576416] Epoch [69/300] [D loss: 0.0256386436522007] [G loss: 4.309335708618164] Epoch [70/300] [D loss: 0.011956267058849335] [G loss: 4.763312816619873] Epoch [71/300] [D loss: 0.08460590243339539] [G loss: 5.456184387207031] Epoch [72/300] [D loss: 0.07495025545358658] [G loss: 3.5078511238098145] Epoch [73/300] [D loss: 0.13037167489528656] [G loss: 3.164292812347412] Epoch [74/300] [D loss: 0.0830327719449997] [G loss: 5.159647464752197] Epoch [75/300] [D loss: 0.4353921115398407] [G loss: 5.0652875900268555] Epoch [76/300] [D loss: 0.02432486228644848] [G loss: 3.7066524028778076] Epoch [77/300] [D loss: 0.2809848189353943] [G loss: 1.1604290008544922] Epoch [78/300] [D loss: 0.7653636932373047] [G loss: 2.5745716094970703] Epoch [79/300] [D loss: 0.041840165853500366] [G loss: 4.082228660583496] Epoch [80/300] [D loss: 0.03992146998643875] [G loss: 4.9236321449279785] Epoch [81/300] [D loss: 0.1003192886710167] [G loss: 2.683060646057129] Epoch [82/300] [D loss: 0.1460535228252411] [G loss: 4.597597122192383] Epoch [83/300] [D loss: 0.1408858597278595] [G loss: 1.8829160928726196] Epoch [84/300] [D loss: 0.048089221119880676] [G loss: 3.1438090801239014] Epoch [85/300] [D loss: 0.041934601962566376] [G loss: 3.298645257949829] Epoch [86/300] [D loss: 0.1363355964422226] [G loss: 2.6124517917633057] Epoch [87/300] [D loss: 0.03299988433718681] [G loss: 3.3402161598205566] Epoch [88/300] [D loss: 0.22786922752857208] [G loss: 3.9778051376342773] Epoch [89/300] [D loss: 0.021804900839924812] [G loss: 4.595890045166016] Epoch [90/300] [D loss: 0.022495444864034653] [G loss: 4.2465901374816895] Epoch [91/300] [D loss: 0.02908019907772541] [G loss: 6.379057884216309] Epoch [92/300] [D loss: 0.6523040533065796] [G loss: 0.6009750962257385] Epoch [93/300] [D loss: 0.007557982578873634] [G loss: 5.837783336639404] Epoch [94/300] [D loss: 0.020063551142811775] [G loss: 4.044745445251465] Epoch [95/300] [D loss: 0.003706925082951784] [G loss: 8.243224143981934] Epoch [96/300] [D loss: 0.021942533552646637] [G loss: 4.662309169769287] Epoch [97/300] [D loss: 0.005410192534327507] [G loss: 5.5743536949157715] Epoch [98/300] [D loss: 0.07137680053710938] [G loss: 3.261455535888672] Epoch [99/300] [D loss: 0.11327817291021347] [G loss: 3.817570686340332] Epoch [100/300] [D loss: 0.04488084092736244] [G loss: 4.458094596862793] Epoch [101/300] [D loss: 0.05757671222090721] [G loss: 3.695896625518799] Epoch [102/300] [D loss: 0.04083157703280449] [G loss: 3.704172134399414] Epoch [103/300] [D loss: 0.02816752716898918] [G loss: 4.322700023651123] Epoch [104/300] [D loss: 0.026689285412430763] [G loss: 4.115890979766846] Epoch [105/300] [D loss: 0.03571446239948273] [G loss: 4.080765724182129] Epoch [106/300] [D loss: 0.020453810691833496] [G loss: 5.457651615142822] Epoch [107/300] [D loss: 0.03774755448102951] [G loss: 5.34019136428833] Epoch [108/300] [D loss: 0.0933525487780571] [G loss: 5.5797905921936035] Epoch [109/300] [D loss: 0.024301748722791672] [G loss: 4.042290210723877] Epoch [110/300] [D loss: 0.9034162759780884] [G loss: 5.52556848526001] Epoch [111/300] [D loss: 0.0911281406879425] [G loss: 6.487083911895752] Epoch [112/300] [D loss: 0.13892149925231934] [G loss: 3.0797510147094727] Epoch [113/300] [D loss: 0.09627098590135574] [G loss: 3.104957103729248] Epoch [114/300] [D loss: 0.007696065586060286] [G loss: 6.618851184844971] Epoch [115/300] [D loss: 0.06528083980083466] [G loss: 3.4506514072418213] Epoch [116/300] [D loss: 0.03879600390791893] [G loss: 3.3708789348602295] Epoch [117/300] [D loss: 0.03395622968673706] [G loss: 6.2684736251831055] Epoch [118/300] [D loss: 0.010569067671895027] [G loss: 5.944631099700928] Epoch [119/300] [D loss: 0.024817001074552536] [G loss: 6.614266872406006] Epoch [120/300] [D loss: 0.013173197396099567] [G loss: 6.226423263549805] Epoch [121/300] [D loss: 0.06546411663293839] [G loss: 3.0585291385650635] Epoch [122/300] [D loss: 0.01085597462952137] [G loss: 6.437295913696289] Epoch [123/300] [D loss: 0.03522876650094986] [G loss: 4.0734052658081055] Epoch [124/300] [D loss: 0.06875205039978027] [G loss: 4.0921711921691895] Epoch [125/300] [D loss: 0.006707158405333757] [G loss: 5.244316577911377] Epoch [126/300] [D loss: 0.03866109997034073] [G loss: 3.368199110031128] Epoch [127/300] [D loss: 0.041117191314697266] [G loss: 4.484440326690674] Epoch [128/300] [D loss: 0.0829429179430008] [G loss: 4.554262638092041] Epoch [129/300] [D loss: 0.03219084441661835] [G loss: 5.4280924797058105] Epoch [130/300] [D loss: 0.11037464439868927] [G loss: 5.89276647567749] Epoch [131/300] [D loss: 0.029911085963249207] [G loss: 4.116299629211426] Epoch [132/300] [D loss: 0.14276768267154694] [G loss: 2.059661626815796] Epoch [133/300] [D loss: 0.06751281768083572] [G loss: 4.1591362953186035] Epoch [134/300] [D loss: 0.06710615009069443] [G loss: 3.1725471019744873] Epoch [135/300] [D loss: 0.015449777245521545] [G loss: 5.900448799133301] Epoch [136/300] [D loss: 0.0017297605518251657] [G loss: 6.8876633644104] Epoch [137/300] [D loss: 0.10661254078149796] [G loss: 3.035740613937378] Epoch [138/300] [D loss: 0.04841696843504906] [G loss: 3.2598555088043213] Epoch [139/300] [D loss: 0.13029193878173828] [G loss: 3.732114791870117] Epoch [140/300] [D loss: 0.01422959566116333] [G loss: 4.98042106628418] Epoch [141/300] [D loss: 0.15487617254257202] [G loss: 5.367415428161621] Epoch [142/300] [D loss: 0.07540086656808853] [G loss: 4.3357768058776855] Epoch [143/300] [D loss: 0.014456328004598618] [G loss: 4.569247245788574] Epoch [144/300] [D loss: 0.012367785908281803] [G loss: 5.9672956466674805] Epoch [145/300] [D loss: 0.05262265354394913] [G loss: 5.160377502441406] Epoch [146/300] [D loss: 0.08042960613965988] [G loss: 3.7927441596984863] Epoch [147/300] [D loss: 0.19245359301567078] [G loss: 3.8005473613739014] Epoch [148/300] [D loss: 0.052174512296915054] [G loss: 5.132132053375244] Epoch [149/300] [D loss: 0.4083835482597351] [G loss: 3.095195770263672] Epoch [150/300] [D loss: 0.007787104230374098] [G loss: 7.455079078674316] Epoch [151/300] [D loss: 0.011952079832553864] [G loss: 5.102141857147217] Epoch [152/300] [D loss: 0.1612093597650528] [G loss: 3.7608675956726074] Epoch [153/300] [D loss: 0.03018610179424286] [G loss: 3.8288230895996094] Epoch [154/300] [D loss: 0.06719933450222015] [G loss: 4.006799697875977] Epoch [155/300] [D loss: 0.0286514051258564] [G loss: 4.619848728179932] Epoch [156/300] [D loss: 0.024552451446652412] [G loss: 4.437436580657959] Epoch [157/300] [D loss: 0.011825334280729294] [G loss: 4.815029144287109] Epoch [158/300] [D loss: 0.061660464853048325] [G loss: 7.883100509643555] Epoch [159/300] [D loss: 0.041454415768384933] [G loss: 6.650402545928955] Epoch [160/300] [D loss: 0.39040958881378174] [G loss: 8.09695053100586] Epoch [161/300] [D loss: 0.0026854330208152533] [G loss: 8.107271194458008] Epoch [162/300] [D loss: 0.16259369254112244] [G loss: 5.87791109085083] Epoch [163/300] [D loss: 0.03663758188486099] [G loss: 4.121287822723389] Epoch [164/300] [D loss: 0.009695476852357388] [G loss: 8.566814422607422] Epoch [165/300] [D loss: 0.010842864401638508] [G loss: 7.692420959472656] Epoch [166/300] [D loss: 0.010091769509017467] [G loss: 5.9158101081848145] Epoch [167/300] [D loss: 0.005709683522582054] [G loss: 5.492888450622559] Epoch [168/300] [D loss: 0.16688843071460724] [G loss: 3.3484747409820557] Epoch [169/300] [D loss: 0.007227647118270397] [G loss: 6.33713960647583] Epoch [170/300] [D loss: 0.007962928153574467] [G loss: 7.612416744232178] Epoch [171/300] [D loss: 0.012646579183638096] [G loss: 4.420655250549316] Epoch [172/300] [D loss: 0.01767764426767826] [G loss: 4.4174957275390625] Epoch [173/300] [D loss: 0.006378074176609516] [G loss: 7.643772125244141] Epoch [174/300] [D loss: 0.009910110384225845] [G loss: 5.333507061004639] Epoch [175/300] [D loss: 0.004518002271652222] [G loss: 6.36816930770874] Epoch [176/300] [D loss: 0.08845338225364685] [G loss: 4.761691570281982] Epoch [177/300] [D loss: 0.038503680378198624] [G loss: 3.653679370880127] Epoch [178/300] [D loss: 0.0021649880800396204] [G loss: 6.513932704925537] Epoch [179/300] [D loss: 0.0054839057847857475] [G loss: 5.804437637329102] Epoch [180/300] [D loss: 0.005088070873171091] [G loss: 5.903375148773193] Epoch [181/300] [D loss: 0.024380924180150032] [G loss: 6.934257984161377] Epoch [182/300] [D loss: 0.003647219855338335] [G loss: 9.193355560302734] Epoch [183/300] [D loss: 0.8360736966133118] [G loss: 8.123100280761719] Epoch [184/300] [D loss: 0.014819988049566746] [G loss: 4.3469648361206055] Epoch [185/300] [D loss: 0.009622478857636452] [G loss: 5.201544761657715] Epoch [186/300] [D loss: 0.023895107209682465] [G loss: 3.903581380844116] Epoch [187/300] [D loss: 0.013679596595466137] [G loss: 8.605210304260254] Epoch [188/300] [D loss: 0.0036324947141110897] [G loss: 6.411885738372803] Epoch [189/300] [D loss: 0.006745172664523125] [G loss: 5.29392147064209] Epoch [190/300] [D loss: 0.0007813140982761979] [G loss: 8.193427085876465] Epoch [191/300] [D loss: 0.021813858300447464] [G loss: 4.648034572601318] Epoch [192/300] [D loss: 0.025777161121368408] [G loss: 4.67152738571167] Epoch [193/300] [D loss: 0.06395631283521652] [G loss: 7.985042095184326] Epoch [194/300] [D loss: 0.034654516726732254] [G loss: 3.360792398452759] Epoch [195/300] [D loss: 0.26737672090530396] [G loss: 6.765297889709473] Epoch [196/300] [D loss: 0.010468905791640282] [G loss: 5.34564208984375] Epoch [197/300] [D loss: 0.014369252137839794] [G loss: 5.097072124481201] Epoch [198/300] [D loss: 0.003273996990174055] [G loss: 6.472024440765381] Epoch [199/300] [D loss: 0.005874062888324261] [G loss: 8.4591646194458] Epoch [200/300] [D loss: 0.005507076624780893] [G loss: 5.7223286628723145] Epoch [201/300] [D loss: 0.16853176057338715] [G loss: 1.9387050867080688] Epoch [202/300] [D loss: 0.0023364669177681208] [G loss: 8.370942115783691] Epoch [203/300] [D loss: 0.003936069551855326] [G loss: 7.522141933441162] Epoch [204/300] [D loss: 0.01826675795018673] [G loss: 4.6409101486206055] Epoch [205/300] [D loss: 0.018070252612233162] [G loss: 6.2785234451293945] Epoch [206/300] [D loss: 0.06540463864803314] [G loss: 4.250749111175537] Epoch [207/300] [D loss: 0.005754987709224224] [G loss: 5.474653720855713] Epoch [208/300] [D loss: 0.0024513285607099533] [G loss: 6.821662425994873] Epoch [209/300] [D loss: 0.005051593761891127] [G loss: 8.622801780700684] Epoch [210/300] [D loss: 0.2648685872554779] [G loss: 1.4338374137878418] Epoch [211/300] [D loss: 0.06582126766443253] [G loss: 4.042891502380371] Epoch [212/300] [D loss: 0.033716216683387756] [G loss: 3.6866607666015625] Epoch [213/300] [D loss: 0.008300993591547012] [G loss: 5.592546463012695] Epoch [214/300] [D loss: 0.10640338063240051] [G loss: 3.440943479537964] Epoch [215/300] [D loss: 0.018705546855926514] [G loss: 8.040839195251465] Epoch [216/300] [D loss: 0.32254651188850403] [G loss: 1.023318886756897] Epoch [217/300] [D loss: 0.006875279359519482] [G loss: 5.205789566040039] Epoch [218/300] [D loss: 0.01632297970354557] [G loss: 6.327811241149902] Epoch [219/300] [D loss: 0.020900549367070198] [G loss: 6.634525299072266] Epoch [220/300] [D loss: 0.011139878071844578] [G loss: 7.300896644592285] Epoch [221/300] [D loss: 0.01837160252034664] [G loss: 5.964895248413086] Epoch [222/300] [D loss: 0.016974858939647675] [G loss: 4.413552284240723] Epoch [223/300] [D loss: 0.3439306914806366] [G loss: 5.5219573974609375] Epoch [224/300] [D loss: 0.047548823058605194] [G loss: 6.586645603179932] Epoch [225/300] [D loss: 0.03183538839221001] [G loss: 4.398618221282959] Epoch [226/300] [D loss: 0.0033374489285051823] [G loss: 7.412342071533203] Epoch [227/300] [D loss: 0.018537862226366997] [G loss: 5.484577655792236] Epoch [228/300] [D loss: 0.03582551330327988] [G loss: 3.6857614517211914] Epoch [229/300] [D loss: 0.11226078867912292] [G loss: 2.819861888885498] Epoch [230/300] [D loss: 0.002012553857639432] [G loss: 7.154722690582275] Epoch [231/300] [D loss: 0.00868014432489872] [G loss: 8.001018524169922] Epoch [232/300] [D loss: 0.0419110469520092] [G loss: 6.980061054229736] Epoch [233/300] [D loss: 0.006477241404354572] [G loss: 5.782578945159912] Epoch [234/300] [D loss: 0.0016205032588914037] [G loss: 10.428010940551758] Epoch [235/300] [D loss: 0.02312217839062214] [G loss: 4.159178733825684] Epoch [236/300] [D loss: 0.36001917719841003] [G loss: 2.4811325073242188] Epoch [237/300] [D loss: 0.005733223166316748] [G loss: 5.611016750335693] Epoch [238/300] [D loss: 0.008837449364364147] [G loss: 8.30731201171875] Epoch [239/300] [D loss: 0.011222743429243565] [G loss: 4.619396209716797] Epoch [240/300] [D loss: 0.0060098664835095406] [G loss: 6.022060394287109] Epoch [241/300] [D loss: 0.0011382169323042035] [G loss: 7.404472351074219] Epoch [242/300] [D loss: 0.3661719560623169] [G loss: 7.876453399658203] Epoch [243/300] [D loss: 0.0019019388128072023] [G loss: 7.3895263671875] Epoch [244/300] [D loss: 0.006632590666413307] [G loss: 5.541728973388672] Epoch [245/300] [D loss: 0.008930223993957043] [G loss: 5.2114691734313965] Epoch [246/300] [D loss: 0.016119416803121567] [G loss: 7.121890068054199] Epoch [247/300] [D loss: 0.001622633310034871] [G loss: 7.303770065307617] Epoch [248/300] [D loss: 0.005070182494819164] [G loss: 6.975015640258789] Epoch [249/300] [D loss: 0.04641895741224289] [G loss: 7.218448638916016] Epoch [250/300] [D loss: 0.01194002851843834] [G loss: 4.6930975914001465] Epoch [251/300] [D loss: 0.012792033143341541] [G loss: 4.67077112197876] Epoch [252/300] [D loss: 0.008810436353087425] [G loss: 5.938291072845459] Epoch [253/300] [D loss: 0.010516034439206123] [G loss: 4.816621780395508] Epoch [254/300] [D loss: 0.017264991998672485] [G loss: 8.856822967529297] Epoch [255/300] [D loss: 0.011463891714811325] [G loss: 6.232043743133545] Epoch [256/300] [D loss: 0.08137447386980057] [G loss: 2.598818778991699] Epoch [257/300] [D loss: 0.032363615930080414] [G loss: 4.790830135345459] Epoch [258/300] [D loss: 0.00863250344991684] [G loss: 7.292766571044922] Epoch [259/300] [D loss: 0.027235930785536766] [G loss: 6.844869613647461] Epoch [260/300] [D loss: 0.008849331177771091] [G loss: 5.027510643005371] Epoch [261/300] [D loss: 0.020822376012802124] [G loss: 4.600456714630127] Epoch [262/300] [D loss: 1.7667120695114136] [G loss: 3.4651100635528564] Epoch [263/300] [D loss: 0.022669170051813126] [G loss: 5.7553019523620605] Epoch [264/300] [D loss: 0.01582598127424717] [G loss: 4.149420261383057] Epoch [265/300] [D loss: 0.0035504011902958155] [G loss: 6.116427421569824] Epoch [266/300] [D loss: 0.07644154131412506] [G loss: 2.720405101776123] Epoch [267/300] [D loss: 0.030415533110499382] [G loss: 4.244810104370117] Epoch [268/300] [D loss: 0.020068874582648277] [G loss: 6.474517822265625] Epoch [269/300] [D loss: 0.002136750379577279] [G loss: 9.29329776763916] Epoch [270/300] [D loss: 0.00978941936045885] [G loss: 5.02622652053833] Epoch [271/300] [D loss: 0.08784317970275879] [G loss: 6.733256816864014] Epoch [272/300] [D loss: 0.009109925478696823] [G loss: 5.823270797729492] Epoch [273/300] [D loss: 0.008865194395184517] [G loss: 5.696066379547119] Epoch [274/300] [D loss: 0.029590584337711334] [G loss: 8.216507911682129] Epoch [275/300] [D loss: 0.0636298805475235] [G loss: 8.98292064666748] Epoch [276/300] [D loss: 0.004769572988152504] [G loss: 6.2220025062561035] Epoch [277/300] [D loss: 0.003883387427777052] [G loss: 6.5977911949157715] Epoch [278/300] [D loss: 0.04028937965631485] [G loss: 4.9343485832214355] Epoch [279/300] [D loss: 0.011857430450618267] [G loss: 6.440511703491211] Epoch [280/300] [D loss: 0.007019379176199436] [G loss: 5.2130351066589355] Epoch [281/300] [D loss: 0.022525882348418236] [G loss: 3.9527556896209717] Epoch [282/300] [D loss: 0.0071130781434476376] [G loss: 6.993907928466797] Epoch [283/300] [D loss: 0.003977011889219284] [G loss: 7.2447967529296875] Epoch [284/300] [D loss: 0.07062061131000519] [G loss: 5.2334771156311035] Epoch [285/300] [D loss: 0.01805986650288105] [G loss: 5.5015082359313965] Epoch [286/300] [D loss: 0.05663669481873512] [G loss: 6.766615390777588] Epoch [287/300] [D loss: 0.0032901568338274956] [G loss: 6.28628396987915] Epoch [288/300] [D loss: 0.3530406653881073] [G loss: 7.906818389892578] Epoch [289/300] [D loss: 0.004547123331576586] [G loss: 6.108604431152344] Epoch [290/300] [D loss: 0.010472457855939865] [G loss: 6.213746070861816] Epoch [291/300] [D loss: 0.016601260751485825] [G loss: 5.763346195220947] Epoch [292/300] [D loss: 0.04024907946586609] [G loss: 5.658637523651123] Epoch [293/300] [D loss: 0.07437323033809662] [G loss: 5.68184757232666] Epoch [294/300] [D loss: 0.08150847256183624] [G loss: 6.040549278259277] Epoch [295/300] [D loss: 0.0924491435289383] [G loss: 2.502917766571045] Epoch [296/300] [D loss: 0.0035814237780869007] [G loss: 7.250881195068359] Epoch [297/300] [D loss: 0.012245922349393368] [G loss: 6.780396461486816] Epoch [298/300] [D loss: 0.004009227734059095] [G loss: 5.833404064178467] Epoch [299/300] [D loss: 0.14272907376289368] [G loss: 7.528534889221191]
In [11]:
# 结果评估与可视化 def visualize_results(): X, y, mask = next(iter(dataloader)) X, y, mask = X.to(device), y.to(device), mask.to(device) generated_images = generator(X, mask) mask = mask.squeeze(1) generated_images = generated_images.squeeze(1) y = y.squeeze(1) final_output = generated_images plt.figure(figsize=(15, 5)) plt.subplot(1, 3, 1) plt.title('Masked NO2 Data') plt.imshow(X[0, 0].cpu().detach().numpy(), cmap='gray') plt.axis('off') plt.subplot(1, 3, 2) plt.title('Generated NO2 Data') plt.imshow(final_output[0].cpu().detach().numpy(), cmap='gray') plt.axis('off') plt.subplot(1, 3, 3) plt.title('Original NO2 Data') plt.imshow(y[0].cpu().detach().numpy(), cmap='gray') plt.axis('off') plt.tight_layout() plt.savefig('results_visualizationxxx.png') plt.close()
In [12]:
dataset_test = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/') test_loader = DataLoader(dataset_test, batch_size=64, shuffle=False, num_workers=8)
In [13]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error
In [17]:
def cal_ioa(y_true, y_pred): # 计算平均值 mean_observed = np.mean(y_true) mean_predicted = np.mean(y_pred) # 计算IoA numerator = np.sum((y_true - y_pred) ** 2) denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2) IoA = 1 - (numerator / denominator) return IoA
In [21]:
device = 'cpu' generator = generator.to(device) eva_list = list() with torch.no_grad(): for X, y, mask in test_loader: X, y, mask = X.to(device), y.to(device), mask.to(device) generated_images = generator(X, mask) mask = mask.squeeze(1).cpu().detach().numpy() rev_mask = (mask==0)* 1 generated_images = generated_images.squeeze(1) real = y.squeeze(1).cpu().detach().numpy() * max_pixel_value final_output = generated_images.cpu().detach().numpy() final_output *= max_pixel_value # y_pred = final_output[rev_mask==1].tolist() # y_real = real[rev_mask==1].tolist() for i, sample in enumerate(generated_images): used_mask = rev_mask[i] data_label = real[i] * used_mask recon_no2 = final_output[i] * used_mask data_label = data_label[used_mask==1] recon_no2 = recon_no2[used_mask==1] mae = mean_absolute_error(data_label, recon_no2) rmse = np.sqrt(mean_squared_error(data_label, recon_no2)) mape = mean_absolute_percentage_error(data_label, recon_no2) r2 = r2_score(data_label, recon_no2) ioa = cal_ioa(data_label, recon_no2) r = np.corrcoef(data_label, recon_no2)[0, 1] eva_list_frame.append([mae, rmse, mape, r2, ioa, r])
In [15]:
import pandas as pd
In [24]:
pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()
Out[24]:
mae | rmse | mape | r2 | ioa | r | |
---|---|---|---|---|---|---|
count | 4739.000000 | 4739.000000 | 4739.000000 | 4739.000000 | 4739.000000 | 4739.000000 |
mean | 2.512144 | 3.430941 | 0.360515 | -0.342186 | 0.680466 | 0.578431 |
std | 1.184403 | 1.580097 | 0.338929 | 1.730534 | 0.267197 | 0.227716 |
min | 0.895772 | 1.189229 | 0.126946 | -42.147773 | -2.040257 | -0.542623 |
25% | 1.699544 | 2.389879 | 0.211304 | -0.480435 | 0.606881 | 0.457522 |
50% | 2.259834 | 3.125452 | 0.263535 | 0.094196 | 0.749986 | 0.620338 |
75% | 2.953516 | 3.983247 | 0.358762 | 0.421505 | 0.840619 | 0.745614 |
max | 10.477497 | 14.460713 | 4.314635 | 0.922679 | 0.981525 | 0.965753 |
In [11]:
# 保存训练好的模型 torch.save(generator, './models/GAN/generator.pt') torch.save(discriminator, './models/GAN/discriminator.pt')
In [14]:
test_imgs = [x for x in os.listdir('./test_img/') if 'img' in x] test_imgs.sort() test_masks = [x for x in os.listdir('./test_img/') if 'mask' in x] test_masks.sort() for img_npy, mask_npy in zip(test_imgs, test_masks): img = np.load(f'./test_img/{file}') img_in = torch.tensor(np.expand_dims(img, 0), dtype=torch.float32) mask = np.load(f'./test_img/{file}') mask_in = torch.tensor(np.expand_dims(mask, 0), dtype=torch.float32) out = generator(img_in, mask_in).detach().cpu().numpy() * max_pixel_value break
In [15]:
plt.imshow(out[0][0], cmap='RdYlGn_r')
Out[15]:
<matplotlib.image.AxesImage at 0x7f793dd49520>
In [ ]: