MAE_ATMO/torch_GAN_1d_baseline.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "8252a3af-edbb-4dcf-967c-fe206e98ceab",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"os.environ[\"CUDA_VISIBLE_DEVICE\"] = \"0\" "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8f8c3fd5-f70f-45d0-886a-c572895ffcee",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"# 设置CUDA设备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d15a7732-b516-4054-905f-0e7d57e4a38e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum pixel value in the dataset: 107.49169921875\n"
]
}
],
"source": [
"# 定义函数来找到最大值\n",
"def find_max_pixel_value(image_dir):\n",
" max_pixel_value = 0.0\n",
" for filename in os.listdir(image_dir):\n",
" if filename.endswith('.npy'):\n",
" image_path = os.path.join(image_dir, filename)\n",
" image = np.load(image_path).astype(np.float32)\n",
" max_pixel_value = max(max_pixel_value, image[:, :, 0].max())\n",
" return max_pixel_value\n",
"\n",
"# 计算图像数据中的最大像素值\n",
"image_dir = './out_mat/96/train/' \n",
"max_pixel_value = 107.49169921875\n",
"\n",
"print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
]
},
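{
"cell_type": "markdown",
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000001",
"metadata": {},
"source": [
"Optional sanity check (a minimal sketch, assuming `./out_mat/96/train/` is available locally): re-run the scan once and compare the result against the hard-coded constant used for normalization."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000002",
"metadata": {},
"outputs": [],
"source": [
"# Recompute the dataset maximum (reads every training .npy once, so this can be slow)\n",
"# and compare it with the cached constant defined above.\n",
"recomputed_max = find_max_pixel_value(image_dir)\n",
"print(recomputed_max, max_pixel_value)"
]
},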
{
"cell_type": "code",
"execution_count": 5,
"id": "bc01ab26-2bd1-4adb-9d6d-5080e32ac1b5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f7a40059f90>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(42)\n",
"torch.random.manual_seed(42)"
]
},
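{
"cell_type": "markdown",
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000003",
"metadata": {},
"source": [
"Optionally, the CUDA RNG can be seeded as well for stricter GPU-side reproducibility (not part of the original run; results can still differ across hardware and cuDNN versions)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000004",
"metadata": {},
"outputs": [],
"source": [
"# Seed the CUDA random number generators on all visible GPUs\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.manual_seed_all(42)"
]
},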
{
"cell_type": "code",
"execution_count": 6,
"id": "69ac2ad4-0e7c-42b8-b4cf-1149b447c3e4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint before Generator is OK\n"
]
}
],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = np.random.choice(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 8)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = np.concatenate([masked_image[:, :, :1], image[:, :, 1:]], axis=-1) # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (8, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# 实例化数据集和数据加载器\n",
"image_dir = './out_mat/96/train/'\n",
"mask_dir = './out_mat/96/mask/20/'\n",
"\n",
"print(f\"checkpoint before Generator is OK\")\n",
"\n"
]
},
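{
"cell_type": "markdown",
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000005",
"metadata": {},
"source": [
"A minimal shape check (assuming the `image_dir` and `mask_dir` above exist): pull one sample from `NO2Dataset` and confirm that `X`, `y`, and `mask` all come back as `(1, 96, 96)` tensors, since only the NO2 channel is loaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000006",
"metadata": {},
"outputs": [],
"source": [
"# Inspect a single sample to verify tensor shapes before building the dataloader\n",
"_sample_ds = NO2Dataset(image_dir, mask_dir)\n",
"_X, _y, _mask = _sample_ds[0]\n",
"print(_X.shape, _y.shape, _mask.shape)  # expected: torch.Size([1, 96, 96]) three times"
]
},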
{
"cell_type": "code",
"execution_count": 7,
"id": "305f2522-bcb0-46b1-8cb1-ebb5a821db7b",
"metadata": {},
"outputs": [],
"source": [
"dataset = NO2Dataset(image_dir, mask_dir)\n",
"dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
"\n",
"# 生成器模型\n",
"class Generator(nn.Module):\n",
" \n",
" def __init__(self):\n",
" super(Generator, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.BatchNorm2d(64),\n",
" nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.BatchNorm2d(128),\n",
" )\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.BatchNorm2d(64),\n",
" nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),\n",
" nn.Tanh(),\n",
" )\n",
"\n",
" def forward(self, x, mask):\n",
" x_encoded = self.encoder(x)\n",
" x_decoded = self.decoder(x_encoded)\n",
"\n",
" x_decoded = (x_decoded + 1) / 2\n",
"\n",
" x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :]\n",
" return x_output\n",
"\n",
"# 判别器模型\n",
"class Discriminator(nn.Module):\n",
" \n",
" def __init__(self):\n",
" super(Discriminator, self).__init__()\n",
" self.model = nn.Sequential(\n",
" nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.BatchNorm2d(64),\n",
" nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.BatchNorm2d(128),\n",
" nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),\n",
" nn.Sigmoid(),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)\n",
"\n",
"# 将模型加载到GPU\n",
"generator = Generator().to(device)\n",
"discriminator = Discriminator().to(device)"
]
},
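{
"cell_type": "markdown",
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000007",
"metadata": {},
"source": [
"A quick architecture check with random tensors (a sketch, no real data required): the two stride-2 convolutions in the encoder downsample 96x96 to 24x24 and the decoder restores 96x96, while the three stride-2 convolutions in the discriminator reduce 96x96 to a 12x12 patch map, which is why the training loop below builds `(batch, 1, 12, 12)` label tensors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000008",
"metadata": {},
"outputs": [],
"source": [
"# Dummy forward pass to confirm the output shapes assumed by the training loop\n",
"with torch.no_grad():\n",
"    _x = torch.rand(2, 1, 96, 96, device=device)\n",
"    _m = torch.randint(0, 2, (2, 1, 96, 96), device=device).float()\n",
"    _g = generator(_x, _m)\n",
"    _d = discriminator(torch.cat((_g, _x), dim=1))\n",
"print(_g.shape, _d.shape)  # expected: (2, 1, 96, 96) and (2, 1, 12, 12)"
]
},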
{
"cell_type": "code",
"execution_count": 8,
"id": "881b9b78-4e03-406c-8af3-d9a749350508",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generator is on: cuda:0\n",
"Discriminator is on: cuda:0\n"
]
}
],
"source": [
"# 定义优化器和损失函数\n",
"optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
"optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
"adversarial_loss = nn.BCELoss().to(device)\n",
"pixelwise_loss = nn.MSELoss().to(device)\n",
"\n",
"# 确认模型是否在GPU上\n",
"print(f\"Generator is on: {next(generator.parameters()).device}\")\n",
"print(f\"Discriminator is on: {next(discriminator.parameters()).device}\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b85fcbe7-ed61-40a7-8259-19f2fa71b056",
"metadata": {},
"outputs": [],
"source": [
"gen = torch.load('./models/GAN/generator.pth', map_location='cpu')\n",
"generator.load_state_dict(gen)\n",
"generator = generator.to('cpu')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "55a55f48-77ff-4ef7-9b79-9e98798d7c4d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dis = torch.load('./models/GAN/discriminator.pth', map_location='cpu')\n",
"discriminator.load_state_dict(dis)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d15f91ca-1bb1-464c-a937-8bdd13c6a1ee",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [0/300] [D loss: 0.5392551422119141] [G loss: 0.9313964247703552]\n",
"Epoch [1/300] [D loss: 0.45492202043533325] [G loss: 1.5053317546844482]\n",
"Epoch [2/300] [D loss: 0.3420121669769287] [G loss: 1.4923288822174072]\n",
"Epoch [3/300] [D loss: 0.2960708737373352] [G loss: 1.955796480178833]\n",
"Epoch [4/300] [D loss: 0.40790891647338867] [G loss: 2.071624279022217]\n",
"Epoch [5/300] [D loss: 0.2747359275817871] [G loss: 1.917580485343933]\n",
"Epoch [6/300] [D loss: 0.47008591890335083] [G loss: 1.5003858804702759]\n",
"Epoch [7/300] [D loss: 0.19478999078273773] [G loss: 3.949864149093628]\n",
"Epoch [8/300] [D loss: 0.5340784192085266] [G loss: 0.8913870453834534]\n",
"Epoch [9/300] [D loss: 0.3194230794906616] [G loss: 1.861933946609497]\n",
"Epoch [10/300] [D loss: 0.22022968530654907] [G loss: 1.6654534339904785]\n",
"Epoch [11/300] [D loss: 0.3743482828140259] [G loss: 1.626413106918335]\n",
"Epoch [12/300] [D loss: 0.13774043321609497] [G loss: 3.187469482421875]\n",
"Epoch [13/300] [D loss: 0.4275822043418884] [G loss: 2.2269718647003174]\n",
"Epoch [14/300] [D loss: 0.22367843985557556] [G loss: 2.707667827606201]\n",
"Epoch [15/300] [D loss: 0.25350409746170044] [G loss: 1.6780041456222534]\n",
"Epoch [16/300] [D loss: 0.23647311329841614] [G loss: 1.6349072456359863]\n",
"Epoch [17/300] [D loss: 0.5604373812675476] [G loss: 2.025310754776001]\n",
"Epoch [18/300] [D loss: 0.4707084596157074] [G loss: 2.938746452331543]\n",
"Epoch [19/300] [D loss: 0.2135343998670578] [G loss: 1.438072919845581]\n",
"Epoch [20/300] [D loss: 0.08090662956237793] [G loss: 2.77827787399292]\n",
"Epoch [21/300] [D loss: 0.2995302677154541] [G loss: 2.239123821258545]\n",
"Epoch [22/300] [D loss: 0.45718979835510254] [G loss: 3.404195547103882]\n",
"Epoch [23/300] [D loss: 0.13950982689857483] [G loss: 2.7335846424102783]\n",
"Epoch [24/300] [D loss: 0.1814766675233841] [G loss: 2.7603776454925537]\n",
"Epoch [25/300] [D loss: 0.0551619790494442] [G loss: 4.066004276275635]\n",
"Epoch [26/300] [D loss: 0.1498052179813385] [G loss: 2.733922243118286]\n",
"Epoch [27/300] [D loss: 0.11236032098531723] [G loss: 3.215831756591797]\n",
"Epoch [28/300] [D loss: 0.4945942461490631] [G loss: 1.9661915302276611]\n",
"Epoch [29/300] [D loss: 0.0760776624083519] [G loss: 2.7212204933166504]\n",
"Epoch [30/300] [D loss: 0.19911707937717438] [G loss: 1.5356297492980957]\n",
"Epoch [31/300] [D loss: 0.0900304988026619] [G loss: 2.84740948677063]\n",
"Epoch [32/300] [D loss: 0.05910511314868927] [G loss: 4.071162223815918]\n",
"Epoch [33/300] [D loss: 0.0711229220032692] [G loss: 2.716848373413086]\n",
"Epoch [34/300] [D loss: 0.39897823333740234] [G loss: 1.5674937963485718]\n",
"Epoch [35/300] [D loss: 0.05552360787987709] [G loss: 3.754516124725342]\n",
"Epoch [36/300] [D loss: 0.5413599014282227] [G loss: 4.129798889160156]\n",
"Epoch [37/300] [D loss: 0.17664434015750885] [G loss: 4.226540565490723]\n",
"Epoch [38/300] [D loss: 0.15215317904949188] [G loss: 2.6023881435394287]\n",
"Epoch [39/300] [D loss: 0.07798739522695541] [G loss: 4.075980186462402]\n",
"Epoch [40/300] [D loss: 0.03936776518821716] [G loss: 4.37988805770874]\n",
"Epoch [41/300] [D loss: 0.2012120634317398] [G loss: 2.1987271308898926]\n",
"Epoch [42/300] [D loss: 0.05274203419685364] [G loss: 3.8458642959594727]\n",
"Epoch [43/300] [D loss: 0.13967157900333405] [G loss: 3.438344955444336]\n",
"Epoch [44/300] [D loss: 0.05800560116767883] [G loss: 2.941135883331299]\n",
"Epoch [45/300] [D loss: 0.14671097695827484] [G loss: 3.388277292251587]\n",
"Epoch [46/300] [D loss: 0.06439051032066345] [G loss: 2.9789438247680664]\n",
"Epoch [47/300] [D loss: 0.11101078987121582] [G loss: 2.6266937255859375]\n",
"Epoch [48/300] [D loss: 0.028554894030094147] [G loss: 4.042592525482178]\n",
"Epoch [49/300] [D loss: 0.3364626169204712] [G loss: 3.419842004776001]\n",
"Epoch [50/300] [D loss: 0.2501979470252991] [G loss: 3.319307804107666]\n",
"Epoch [51/300] [D loss: 0.2962917387485504] [G loss: 5.088353157043457]\n",
"Epoch [52/300] [D loss: 0.07700179517269135] [G loss: 3.231515884399414]\n",
"Epoch [53/300] [D loss: 0.4093267321586609] [G loss: 1.918235182762146]\n",
"Epoch [54/300] [D loss: 0.12105419486761093] [G loss: 2.3409922122955322]\n",
"Epoch [55/300] [D loss: 0.057456158101558685] [G loss: 4.047771453857422]\n",
"Epoch [56/300] [D loss: 0.250449538230896] [G loss: 1.9442336559295654]\n",
"Epoch [57/300] [D loss: 0.08125491440296173] [G loss: 2.7323458194732666]\n",
"Epoch [58/300] [D loss: 0.06671395897865295] [G loss: 3.081458330154419]\n",
"Epoch [59/300] [D loss: 0.06982511281967163] [G loss: 3.95278000831604]\n",
"Epoch [60/300] [D loss: 0.08973922580480576] [G loss: 3.9550158977508545]\n",
"Epoch [61/300] [D loss: 0.29226893186569214] [G loss: 2.2824535369873047]\n",
"Epoch [62/300] [D loss: 0.06800767779350281] [G loss: 4.67025089263916]\n",
"Epoch [63/300] [D loss: 0.017987174913287163] [G loss: 4.119121551513672]\n",
"Epoch [64/300] [D loss: 0.1278763711452484] [G loss: 4.481695652008057]\n",
"Epoch [65/300] [D loss: 0.12277506291866302] [G loss: 2.0188961029052734]\n",
"Epoch [66/300] [D loss: 0.10042040050029755] [G loss: 4.019499778747559]\n",
"Epoch [67/300] [D loss: 0.15092261135578156] [G loss: 3.0588033199310303]\n",
"Epoch [68/300] [D loss: 0.157196044921875] [G loss: 4.579256534576416]\n",
"Epoch [69/300] [D loss: 0.0256386436522007] [G loss: 4.309335708618164]\n",
"Epoch [70/300] [D loss: 0.011956267058849335] [G loss: 4.763312816619873]\n",
"Epoch [71/300] [D loss: 0.08460590243339539] [G loss: 5.456184387207031]\n",
"Epoch [72/300] [D loss: 0.07495025545358658] [G loss: 3.5078511238098145]\n",
"Epoch [73/300] [D loss: 0.13037167489528656] [G loss: 3.164292812347412]\n",
"Epoch [74/300] [D loss: 0.0830327719449997] [G loss: 5.159647464752197]\n",
"Epoch [75/300] [D loss: 0.4353921115398407] [G loss: 5.0652875900268555]\n",
"Epoch [76/300] [D loss: 0.02432486228644848] [G loss: 3.7066524028778076]\n",
"Epoch [77/300] [D loss: 0.2809848189353943] [G loss: 1.1604290008544922]\n",
"Epoch [78/300] [D loss: 0.7653636932373047] [G loss: 2.5745716094970703]\n",
"Epoch [79/300] [D loss: 0.041840165853500366] [G loss: 4.082228660583496]\n",
"Epoch [80/300] [D loss: 0.03992146998643875] [G loss: 4.9236321449279785]\n",
"Epoch [81/300] [D loss: 0.1003192886710167] [G loss: 2.683060646057129]\n",
"Epoch [82/300] [D loss: 0.1460535228252411] [G loss: 4.597597122192383]\n",
"Epoch [83/300] [D loss: 0.1408858597278595] [G loss: 1.8829160928726196]\n",
"Epoch [84/300] [D loss: 0.048089221119880676] [G loss: 3.1438090801239014]\n",
"Epoch [85/300] [D loss: 0.041934601962566376] [G loss: 3.298645257949829]\n",
"Epoch [86/300] [D loss: 0.1363355964422226] [G loss: 2.6124517917633057]\n",
"Epoch [87/300] [D loss: 0.03299988433718681] [G loss: 3.3402161598205566]\n",
"Epoch [88/300] [D loss: 0.22786922752857208] [G loss: 3.9778051376342773]\n",
"Epoch [89/300] [D loss: 0.021804900839924812] [G loss: 4.595890045166016]\n",
"Epoch [90/300] [D loss: 0.022495444864034653] [G loss: 4.2465901374816895]\n",
"Epoch [91/300] [D loss: 0.02908019907772541] [G loss: 6.379057884216309]\n",
"Epoch [92/300] [D loss: 0.6523040533065796] [G loss: 0.6009750962257385]\n",
"Epoch [93/300] [D loss: 0.007557982578873634] [G loss: 5.837783336639404]\n",
"Epoch [94/300] [D loss: 0.020063551142811775] [G loss: 4.044745445251465]\n",
"Epoch [95/300] [D loss: 0.003706925082951784] [G loss: 8.243224143981934]\n",
"Epoch [96/300] [D loss: 0.021942533552646637] [G loss: 4.662309169769287]\n",
"Epoch [97/300] [D loss: 0.005410192534327507] [G loss: 5.5743536949157715]\n",
"Epoch [98/300] [D loss: 0.07137680053710938] [G loss: 3.261455535888672]\n",
"Epoch [99/300] [D loss: 0.11327817291021347] [G loss: 3.817570686340332]\n",
"Epoch [100/300] [D loss: 0.04488084092736244] [G loss: 4.458094596862793]\n",
"Epoch [101/300] [D loss: 0.05757671222090721] [G loss: 3.695896625518799]\n",
"Epoch [102/300] [D loss: 0.04083157703280449] [G loss: 3.704172134399414]\n",
"Epoch [103/300] [D loss: 0.02816752716898918] [G loss: 4.322700023651123]\n",
"Epoch [104/300] [D loss: 0.026689285412430763] [G loss: 4.115890979766846]\n",
"Epoch [105/300] [D loss: 0.03571446239948273] [G loss: 4.080765724182129]\n",
"Epoch [106/300] [D loss: 0.020453810691833496] [G loss: 5.457651615142822]\n",
"Epoch [107/300] [D loss: 0.03774755448102951] [G loss: 5.34019136428833]\n",
"Epoch [108/300] [D loss: 0.0933525487780571] [G loss: 5.5797905921936035]\n",
"Epoch [109/300] [D loss: 0.024301748722791672] [G loss: 4.042290210723877]\n",
"Epoch [110/300] [D loss: 0.9034162759780884] [G loss: 5.52556848526001]\n",
"Epoch [111/300] [D loss: 0.0911281406879425] [G loss: 6.487083911895752]\n",
"Epoch [112/300] [D loss: 0.13892149925231934] [G loss: 3.0797510147094727]\n",
"Epoch [113/300] [D loss: 0.09627098590135574] [G loss: 3.104957103729248]\n",
"Epoch [114/300] [D loss: 0.007696065586060286] [G loss: 6.618851184844971]\n",
"Epoch [115/300] [D loss: 0.06528083980083466] [G loss: 3.4506514072418213]\n",
"Epoch [116/300] [D loss: 0.03879600390791893] [G loss: 3.3708789348602295]\n",
"Epoch [117/300] [D loss: 0.03395622968673706] [G loss: 6.2684736251831055]\n",
"Epoch [118/300] [D loss: 0.010569067671895027] [G loss: 5.944631099700928]\n",
"Epoch [119/300] [D loss: 0.024817001074552536] [G loss: 6.614266872406006]\n",
"Epoch [120/300] [D loss: 0.013173197396099567] [G loss: 6.226423263549805]\n",
"Epoch [121/300] [D loss: 0.06546411663293839] [G loss: 3.0585291385650635]\n",
"Epoch [122/300] [D loss: 0.01085597462952137] [G loss: 6.437295913696289]\n",
"Epoch [123/300] [D loss: 0.03522876650094986] [G loss: 4.0734052658081055]\n",
"Epoch [124/300] [D loss: 0.06875205039978027] [G loss: 4.0921711921691895]\n",
"Epoch [125/300] [D loss: 0.006707158405333757] [G loss: 5.244316577911377]\n",
"Epoch [126/300] [D loss: 0.03866109997034073] [G loss: 3.368199110031128]\n",
"Epoch [127/300] [D loss: 0.041117191314697266] [G loss: 4.484440326690674]\n",
"Epoch [128/300] [D loss: 0.0829429179430008] [G loss: 4.554262638092041]\n",
"Epoch [129/300] [D loss: 0.03219084441661835] [G loss: 5.4280924797058105]\n",
"Epoch [130/300] [D loss: 0.11037464439868927] [G loss: 5.89276647567749]\n",
"Epoch [131/300] [D loss: 0.029911085963249207] [G loss: 4.116299629211426]\n",
"Epoch [132/300] [D loss: 0.14276768267154694] [G loss: 2.059661626815796]\n",
"Epoch [133/300] [D loss: 0.06751281768083572] [G loss: 4.1591362953186035]\n",
"Epoch [134/300] [D loss: 0.06710615009069443] [G loss: 3.1725471019744873]\n",
"Epoch [135/300] [D loss: 0.015449777245521545] [G loss: 5.900448799133301]\n",
"Epoch [136/300] [D loss: 0.0017297605518251657] [G loss: 6.8876633644104]\n",
"Epoch [137/300] [D loss: 0.10661254078149796] [G loss: 3.035740613937378]\n",
"Epoch [138/300] [D loss: 0.04841696843504906] [G loss: 3.2598555088043213]\n",
"Epoch [139/300] [D loss: 0.13029193878173828] [G loss: 3.732114791870117]\n",
"Epoch [140/300] [D loss: 0.01422959566116333] [G loss: 4.98042106628418]\n",
"Epoch [141/300] [D loss: 0.15487617254257202] [G loss: 5.367415428161621]\n",
"Epoch [142/300] [D loss: 0.07540086656808853] [G loss: 4.3357768058776855]\n",
"Epoch [143/300] [D loss: 0.014456328004598618] [G loss: 4.569247245788574]\n",
"Epoch [144/300] [D loss: 0.012367785908281803] [G loss: 5.9672956466674805]\n",
"Epoch [145/300] [D loss: 0.05262265354394913] [G loss: 5.160377502441406]\n",
"Epoch [146/300] [D loss: 0.08042960613965988] [G loss: 3.7927441596984863]\n",
"Epoch [147/300] [D loss: 0.19245359301567078] [G loss: 3.8005473613739014]\n",
"Epoch [148/300] [D loss: 0.052174512296915054] [G loss: 5.132132053375244]\n",
"Epoch [149/300] [D loss: 0.4083835482597351] [G loss: 3.095195770263672]\n",
"Epoch [150/300] [D loss: 0.007787104230374098] [G loss: 7.455079078674316]\n",
"Epoch [151/300] [D loss: 0.011952079832553864] [G loss: 5.102141857147217]\n",
"Epoch [152/300] [D loss: 0.1612093597650528] [G loss: 3.7608675956726074]\n",
"Epoch [153/300] [D loss: 0.03018610179424286] [G loss: 3.8288230895996094]\n",
"Epoch [154/300] [D loss: 0.06719933450222015] [G loss: 4.006799697875977]\n",
"Epoch [155/300] [D loss: 0.0286514051258564] [G loss: 4.619848728179932]\n",
"Epoch [156/300] [D loss: 0.024552451446652412] [G loss: 4.437436580657959]\n",
"Epoch [157/300] [D loss: 0.011825334280729294] [G loss: 4.815029144287109]\n",
"Epoch [158/300] [D loss: 0.061660464853048325] [G loss: 7.883100509643555]\n",
"Epoch [159/300] [D loss: 0.041454415768384933] [G loss: 6.650402545928955]\n",
"Epoch [160/300] [D loss: 0.39040958881378174] [G loss: 8.09695053100586]\n",
"Epoch [161/300] [D loss: 0.0026854330208152533] [G loss: 8.107271194458008]\n",
"Epoch [162/300] [D loss: 0.16259369254112244] [G loss: 5.87791109085083]\n",
"Epoch [163/300] [D loss: 0.03663758188486099] [G loss: 4.121287822723389]\n",
"Epoch [164/300] [D loss: 0.009695476852357388] [G loss: 8.566814422607422]\n",
"Epoch [165/300] [D loss: 0.010842864401638508] [G loss: 7.692420959472656]\n",
"Epoch [166/300] [D loss: 0.010091769509017467] [G loss: 5.9158101081848145]\n",
"Epoch [167/300] [D loss: 0.005709683522582054] [G loss: 5.492888450622559]\n",
"Epoch [168/300] [D loss: 0.16688843071460724] [G loss: 3.3484747409820557]\n",
"Epoch [169/300] [D loss: 0.007227647118270397] [G loss: 6.33713960647583]\n",
"Epoch [170/300] [D loss: 0.007962928153574467] [G loss: 7.612416744232178]\n",
"Epoch [171/300] [D loss: 0.012646579183638096] [G loss: 4.420655250549316]\n",
"Epoch [172/300] [D loss: 0.01767764426767826] [G loss: 4.4174957275390625]\n",
"Epoch [173/300] [D loss: 0.006378074176609516] [G loss: 7.643772125244141]\n",
"Epoch [174/300] [D loss: 0.009910110384225845] [G loss: 5.333507061004639]\n",
"Epoch [175/300] [D loss: 0.004518002271652222] [G loss: 6.36816930770874]\n",
"Epoch [176/300] [D loss: 0.08845338225364685] [G loss: 4.761691570281982]\n",
"Epoch [177/300] [D loss: 0.038503680378198624] [G loss: 3.653679370880127]\n",
"Epoch [178/300] [D loss: 0.0021649880800396204] [G loss: 6.513932704925537]\n",
"Epoch [179/300] [D loss: 0.0054839057847857475] [G loss: 5.804437637329102]\n",
"Epoch [180/300] [D loss: 0.005088070873171091] [G loss: 5.903375148773193]\n",
"Epoch [181/300] [D loss: 0.024380924180150032] [G loss: 6.934257984161377]\n",
"Epoch [182/300] [D loss: 0.003647219855338335] [G loss: 9.193355560302734]\n",
"Epoch [183/300] [D loss: 0.8360736966133118] [G loss: 8.123100280761719]\n",
"Epoch [184/300] [D loss: 0.014819988049566746] [G loss: 4.3469648361206055]\n",
"Epoch [185/300] [D loss: 0.009622478857636452] [G loss: 5.201544761657715]\n",
"Epoch [186/300] [D loss: 0.023895107209682465] [G loss: 3.903581380844116]\n",
"Epoch [187/300] [D loss: 0.013679596595466137] [G loss: 8.605210304260254]\n",
"Epoch [188/300] [D loss: 0.0036324947141110897] [G loss: 6.411885738372803]\n",
"Epoch [189/300] [D loss: 0.006745172664523125] [G loss: 5.29392147064209]\n",
"Epoch [190/300] [D loss: 0.0007813140982761979] [G loss: 8.193427085876465]\n",
"Epoch [191/300] [D loss: 0.021813858300447464] [G loss: 4.648034572601318]\n",
"Epoch [192/300] [D loss: 0.025777161121368408] [G loss: 4.67152738571167]\n",
"Epoch [193/300] [D loss: 0.06395631283521652] [G loss: 7.985042095184326]\n",
"Epoch [194/300] [D loss: 0.034654516726732254] [G loss: 3.360792398452759]\n",
"Epoch [195/300] [D loss: 0.26737672090530396] [G loss: 6.765297889709473]\n",
"Epoch [196/300] [D loss: 0.010468905791640282] [G loss: 5.34564208984375]\n",
"Epoch [197/300] [D loss: 0.014369252137839794] [G loss: 5.097072124481201]\n",
"Epoch [198/300] [D loss: 0.003273996990174055] [G loss: 6.472024440765381]\n",
"Epoch [199/300] [D loss: 0.005874062888324261] [G loss: 8.4591646194458]\n",
"Epoch [200/300] [D loss: 0.005507076624780893] [G loss: 5.7223286628723145]\n",
"Epoch [201/300] [D loss: 0.16853176057338715] [G loss: 1.9387050867080688]\n",
"Epoch [202/300] [D loss: 0.0023364669177681208] [G loss: 8.370942115783691]\n",
"Epoch [203/300] [D loss: 0.003936069551855326] [G loss: 7.522141933441162]\n",
"Epoch [204/300] [D loss: 0.01826675795018673] [G loss: 4.6409101486206055]\n",
"Epoch [205/300] [D loss: 0.018070252612233162] [G loss: 6.2785234451293945]\n",
"Epoch [206/300] [D loss: 0.06540463864803314] [G loss: 4.250749111175537]\n",
"Epoch [207/300] [D loss: 0.005754987709224224] [G loss: 5.474653720855713]\n",
"Epoch [208/300] [D loss: 0.0024513285607099533] [G loss: 6.821662425994873]\n",
"Epoch [209/300] [D loss: 0.005051593761891127] [G loss: 8.622801780700684]\n",
"Epoch [210/300] [D loss: 0.2648685872554779] [G loss: 1.4338374137878418]\n",
"Epoch [211/300] [D loss: 0.06582126766443253] [G loss: 4.042891502380371]\n",
"Epoch [212/300] [D loss: 0.033716216683387756] [G loss: 3.6866607666015625]\n",
"Epoch [213/300] [D loss: 0.008300993591547012] [G loss: 5.592546463012695]\n",
"Epoch [214/300] [D loss: 0.10640338063240051] [G loss: 3.440943479537964]\n",
"Epoch [215/300] [D loss: 0.018705546855926514] [G loss: 8.040839195251465]\n",
"Epoch [216/300] [D loss: 0.32254651188850403] [G loss: 1.023318886756897]\n",
"Epoch [217/300] [D loss: 0.006875279359519482] [G loss: 5.205789566040039]\n",
"Epoch [218/300] [D loss: 0.01632297970354557] [G loss: 6.327811241149902]\n",
"Epoch [219/300] [D loss: 0.020900549367070198] [G loss: 6.634525299072266]\n",
"Epoch [220/300] [D loss: 0.011139878071844578] [G loss: 7.300896644592285]\n",
"Epoch [221/300] [D loss: 0.01837160252034664] [G loss: 5.964895248413086]\n",
"Epoch [222/300] [D loss: 0.016974858939647675] [G loss: 4.413552284240723]\n",
"Epoch [223/300] [D loss: 0.3439306914806366] [G loss: 5.5219573974609375]\n",
"Epoch [224/300] [D loss: 0.047548823058605194] [G loss: 6.586645603179932]\n",
"Epoch [225/300] [D loss: 0.03183538839221001] [G loss: 4.398618221282959]\n",
"Epoch [226/300] [D loss: 0.0033374489285051823] [G loss: 7.412342071533203]\n",
"Epoch [227/300] [D loss: 0.018537862226366997] [G loss: 5.484577655792236]\n",
"Epoch [228/300] [D loss: 0.03582551330327988] [G loss: 3.6857614517211914]\n",
"Epoch [229/300] [D loss: 0.11226078867912292] [G loss: 2.819861888885498]\n",
"Epoch [230/300] [D loss: 0.002012553857639432] [G loss: 7.154722690582275]\n",
"Epoch [231/300] [D loss: 0.00868014432489872] [G loss: 8.001018524169922]\n",
"Epoch [232/300] [D loss: 0.0419110469520092] [G loss: 6.980061054229736]\n",
"Epoch [233/300] [D loss: 0.006477241404354572] [G loss: 5.782578945159912]\n",
"Epoch [234/300] [D loss: 0.0016205032588914037] [G loss: 10.428010940551758]\n",
"Epoch [235/300] [D loss: 0.02312217839062214] [G loss: 4.159178733825684]\n",
"Epoch [236/300] [D loss: 0.36001917719841003] [G loss: 2.4811325073242188]\n",
"Epoch [237/300] [D loss: 0.005733223166316748] [G loss: 5.611016750335693]\n",
"Epoch [238/300] [D loss: 0.008837449364364147] [G loss: 8.30731201171875]\n",
"Epoch [239/300] [D loss: 0.011222743429243565] [G loss: 4.619396209716797]\n",
"Epoch [240/300] [D loss: 0.0060098664835095406] [G loss: 6.022060394287109]\n",
"Epoch [241/300] [D loss: 0.0011382169323042035] [G loss: 7.404472351074219]\n",
"Epoch [242/300] [D loss: 0.3661719560623169] [G loss: 7.876453399658203]\n",
"Epoch [243/300] [D loss: 0.0019019388128072023] [G loss: 7.3895263671875]\n",
"Epoch [244/300] [D loss: 0.006632590666413307] [G loss: 5.541728973388672]\n",
"Epoch [245/300] [D loss: 0.008930223993957043] [G loss: 5.2114691734313965]\n",
"Epoch [246/300] [D loss: 0.016119416803121567] [G loss: 7.121890068054199]\n",
"Epoch [247/300] [D loss: 0.001622633310034871] [G loss: 7.303770065307617]\n",
"Epoch [248/300] [D loss: 0.005070182494819164] [G loss: 6.975015640258789]\n",
"Epoch [249/300] [D loss: 0.04641895741224289] [G loss: 7.218448638916016]\n",
"Epoch [250/300] [D loss: 0.01194002851843834] [G loss: 4.6930975914001465]\n",
"Epoch [251/300] [D loss: 0.012792033143341541] [G loss: 4.67077112197876]\n",
"Epoch [252/300] [D loss: 0.008810436353087425] [G loss: 5.938291072845459]\n",
"Epoch [253/300] [D loss: 0.010516034439206123] [G loss: 4.816621780395508]\n",
"Epoch [254/300] [D loss: 0.017264991998672485] [G loss: 8.856822967529297]\n",
"Epoch [255/300] [D loss: 0.011463891714811325] [G loss: 6.232043743133545]\n",
"Epoch [256/300] [D loss: 0.08137447386980057] [G loss: 2.598818778991699]\n",
"Epoch [257/300] [D loss: 0.032363615930080414] [G loss: 4.790830135345459]\n",
"Epoch [258/300] [D loss: 0.00863250344991684] [G loss: 7.292766571044922]\n",
"Epoch [259/300] [D loss: 0.027235930785536766] [G loss: 6.844869613647461]\n",
"Epoch [260/300] [D loss: 0.008849331177771091] [G loss: 5.027510643005371]\n",
"Epoch [261/300] [D loss: 0.020822376012802124] [G loss: 4.600456714630127]\n",
"Epoch [262/300] [D loss: 1.7667120695114136] [G loss: 3.4651100635528564]\n",
"Epoch [263/300] [D loss: 0.022669170051813126] [G loss: 5.7553019523620605]\n",
"Epoch [264/300] [D loss: 0.01582598127424717] [G loss: 4.149420261383057]\n",
"Epoch [265/300] [D loss: 0.0035504011902958155] [G loss: 6.116427421569824]\n",
"Epoch [266/300] [D loss: 0.07644154131412506] [G loss: 2.720405101776123]\n",
"Epoch [267/300] [D loss: 0.030415533110499382] [G loss: 4.244810104370117]\n",
"Epoch [268/300] [D loss: 0.020068874582648277] [G loss: 6.474517822265625]\n",
"Epoch [269/300] [D loss: 0.002136750379577279] [G loss: 9.29329776763916]\n",
"Epoch [270/300] [D loss: 0.00978941936045885] [G loss: 5.02622652053833]\n",
"Epoch [271/300] [D loss: 0.08784317970275879] [G loss: 6.733256816864014]\n",
"Epoch [272/300] [D loss: 0.009109925478696823] [G loss: 5.823270797729492]\n",
"Epoch [273/300] [D loss: 0.008865194395184517] [G loss: 5.696066379547119]\n",
"Epoch [274/300] [D loss: 0.029590584337711334] [G loss: 8.216507911682129]\n",
"Epoch [275/300] [D loss: 0.0636298805475235] [G loss: 8.98292064666748]\n",
"Epoch [276/300] [D loss: 0.004769572988152504] [G loss: 6.2220025062561035]\n",
"Epoch [277/300] [D loss: 0.003883387427777052] [G loss: 6.5977911949157715]\n",
"Epoch [278/300] [D loss: 0.04028937965631485] [G loss: 4.9343485832214355]\n",
"Epoch [279/300] [D loss: 0.011857430450618267] [G loss: 6.440511703491211]\n",
"Epoch [280/300] [D loss: 0.007019379176199436] [G loss: 5.2130351066589355]\n",
"Epoch [281/300] [D loss: 0.022525882348418236] [G loss: 3.9527556896209717]\n",
"Epoch [282/300] [D loss: 0.0071130781434476376] [G loss: 6.993907928466797]\n",
"Epoch [283/300] [D loss: 0.003977011889219284] [G loss: 7.2447967529296875]\n",
"Epoch [284/300] [D loss: 0.07062061131000519] [G loss: 5.2334771156311035]\n",
"Epoch [285/300] [D loss: 0.01805986650288105] [G loss: 5.5015082359313965]\n",
"Epoch [286/300] [D loss: 0.05663669481873512] [G loss: 6.766615390777588]\n",
"Epoch [287/300] [D loss: 0.0032901568338274956] [G loss: 6.28628396987915]\n",
"Epoch [288/300] [D loss: 0.3530406653881073] [G loss: 7.906818389892578]\n",
"Epoch [289/300] [D loss: 0.004547123331576586] [G loss: 6.108604431152344]\n",
"Epoch [290/300] [D loss: 0.010472457855939865] [G loss: 6.213746070861816]\n",
"Epoch [291/300] [D loss: 0.016601260751485825] [G loss: 5.763346195220947]\n",
"Epoch [292/300] [D loss: 0.04024907946586609] [G loss: 5.658637523651123]\n",
"Epoch [293/300] [D loss: 0.07437323033809662] [G loss: 5.68184757232666]\n",
"Epoch [294/300] [D loss: 0.08150847256183624] [G loss: 6.040549278259277]\n",
"Epoch [295/300] [D loss: 0.0924491435289383] [G loss: 2.502917766571045]\n",
"Epoch [296/300] [D loss: 0.0035814237780869007] [G loss: 7.250881195068359]\n",
"Epoch [297/300] [D loss: 0.012245922349393368] [G loss: 6.780396461486816]\n",
"Epoch [298/300] [D loss: 0.004009227734059095] [G loss: 5.833404064178467]\n",
"Epoch [299/300] [D loss: 0.14272907376289368] [G loss: 7.528534889221191]\n"
]
}
],
"source": [
"# 开始训练\n",
"epochs = 300\n",
"for epoch in range(epochs):\n",
" for i, (X, y, mask) in enumerate(dataloader):\n",
" # 将数据移到 GPU 上\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" \n",
" valid = torch.ones((X.size(0), 1, 12, 12)).to(device)\n",
" fake = torch.zeros((X.size(0), 1, 12, 12)).to(device)\n",
"\n",
" # 生成器生成图像\n",
" optimizer_G.zero_grad()\n",
" generated_images = generator(X, mask)\n",
" g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * pixelwise_loss(\n",
" generated_images, y)\n",
" g_loss.backward()\n",
" optimizer_G.step()\n",
"\n",
" # 判别器训练\n",
" optimizer_D.zero_grad()\n",
" real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)\n",
" fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)\n",
" d_loss = 0.5 * (real_loss + fake_loss)\n",
" d_loss.backward()\n",
" optimizer_D.step()\n",
"\n",
" print(f\"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]\")\n",
"\n",
"# 保存训练好的模型\n",
"torch.save(generator.state_dict(), './models/GAN/generator.pth')\n",
"torch.save(discriminator.state_dict(), './models/GAN/discriminator.pth')"
]
},
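{
"cell_type": "markdown",
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000009",
"metadata": {},
"source": [
"For reference, the objectives implemented in the loop above, with $\\hat{y} = G(X, m)$ the inpainted NO2 field and $[\\cdot,\\cdot]$ channel-wise concatenation:\n",
"\n",
"$$\\mathcal{L}_G = \\mathrm{BCE}\\big(D([\\hat{y}, X]),\\ \\mathbf{1}\\big) + 100\\,\\mathrm{MSE}(\\hat{y}, y)$$\n",
"\n",
"$$\\mathcal{L}_D = \\tfrac{1}{2}\\Big[\\mathrm{BCE}\\big(D([y, X]),\\ \\mathbf{1}\\big) + \\mathrm{BCE}\\big(D([\\hat{y}, X]),\\ \\mathbf{0}\\big)\\Big]$$\n",
"\n",
"where the $\\mathbf{1}$/$\\mathbf{0}$ targets are the $12\\times 12$ patch label maps (`valid`/`fake`) and the factor 100 weights the pixel-wise reconstruction term."
]
},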
{
"cell_type": "code",
"execution_count": 11,
"id": "18482f18-a9cd-49cb-a63a-85725cc4088a",
"metadata": {},
"outputs": [],
"source": [
"# 结果评估与可视化\n",
"def visualize_results():\n",
" \n",
" X, y, mask = next(iter(dataloader))\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" generated_images = generator(X, mask)\n",
"\n",
" mask = mask.squeeze(1)\n",
" generated_images = generated_images.squeeze(1)\n",
" y = y.squeeze(1)\n",
"\n",
" final_output = generated_images\n",
"\n",
" plt.figure(figsize=(15, 5))\n",
" plt.subplot(1, 3, 1)\n",
" plt.title('Masked NO2 Data')\n",
" plt.imshow(X[0, 0].cpu().detach().numpy(), cmap='gray')\n",
" plt.axis('off')\n",
"\n",
" plt.subplot(1, 3, 2)\n",
" plt.title('Generated NO2 Data')\n",
" plt.imshow(final_output[0].cpu().detach().numpy(), cmap='gray')\n",
" plt.axis('off')\n",
"\n",
" plt.subplot(1, 3, 3)\n",
" plt.title('Original NO2 Data')\n",
" plt.imshow(y[0].cpu().detach().numpy(), cmap='gray')\n",
" plt.axis('off')\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig('results_visualizationxxx.png')\n",
" plt.close()\n"
]
},
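{
"cell_type": "markdown",
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000010",
"metadata": {},
"source": [
"`visualize_results` is only defined above and never called. A call such as the following writes the three-panel comparison to `results_visualizationxxx.png` (assuming the generator and the batches produced by `dataloader` are on the same device):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000011",
"metadata": {},
"outputs": [],
"source": [
"# Produce and save one comparison figure from a random training batch\n",
"visualize_results()"
]
},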
{
"cell_type": "code",
"execution_count": 12,
"id": "a02c7b46-4c53-4fff-b130-a82412f9cf06",
"metadata": {},
"outputs": [],
"source": [
"dataset_test = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')\n",
"test_loader = DataLoader(dataset_test, batch_size=64, shuffle=False, num_workers=8)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "13e2180c-5615-4610-a041-6da5f5c69a5d",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3a5533e4-f24d-41ae-8de0-0ab2a383d38f",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
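{
"cell_type": "markdown",
"id": "3f9b1a20-5c7e-4d2b-9a41-aa0000000012",
"metadata": {},
"source": [
"As implemented above, the index of agreement is\n",
"\n",
"$$\\mathrm{IoA} = 1 - \\frac{\\sum_i (y_i - \\hat{y}_i)^2}{\\sum_i \\big(|y_i - \\bar{y}| + |\\hat{y}_i - \\bar{\\hat{y}}|\\big)^2},$$\n",
"\n",
"where $y_i$ are observations with mean $\\bar{y}$ and $\\hat{y}_i$ are predictions with mean $\\bar{\\hat{y}}$; values close to 1 indicate good agreement."
]
},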
{
"cell_type": "code",
"execution_count": 21,
"id": "c1dd78b5-c66e-45f4-ab0f-f4b9c6f08cd2",
"metadata": {},
"outputs": [],
"source": [
"device = 'cpu'\n",
"generator = generator.to(device)\n",
"eva_list = list()\n",
"with torch.no_grad():\n",
" for X, y, mask in test_loader:\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" generated_images = generator(X, mask)\n",
" mask = mask.squeeze(1).cpu().detach().numpy()\n",
" rev_mask = (mask==0)* 1\n",
" generated_images = generated_images.squeeze(1)\n",
" real = y.squeeze(1).cpu().detach().numpy() * max_pixel_value\n",
" final_output = generated_images.cpu().detach().numpy()\n",
" final_output *= max_pixel_value\n",
" # y_pred = final_output[rev_mask==1].tolist()\n",
" # y_real = real[rev_mask==1].tolist()\n",
" for i, sample in enumerate(generated_images):\n",
" used_mask = rev_mask[i]\n",
" data_label = real[i] * used_mask\n",
" recon_no2 = final_output[i] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label, recon_no2)\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1ca21b33-753f-49ee-93d8-ede92e100b5a",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "edc3aaa3-c9b3-4094-9dea-e981f582ac09",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>2.512144</td>\n",
" <td>3.430941</td>\n",
" <td>0.360515</td>\n",
" <td>-0.342186</td>\n",
" <td>0.680466</td>\n",
" <td>0.578431</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>1.184403</td>\n",
" <td>1.580097</td>\n",
" <td>0.338929</td>\n",
" <td>1.730534</td>\n",
" <td>0.267197</td>\n",
" <td>0.227716</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.895772</td>\n",
" <td>1.189229</td>\n",
" <td>0.126946</td>\n",
" <td>-42.147773</td>\n",
" <td>-2.040257</td>\n",
" <td>-0.542623</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>1.699544</td>\n",
" <td>2.389879</td>\n",
" <td>0.211304</td>\n",
" <td>-0.480435</td>\n",
" <td>0.606881</td>\n",
" <td>0.457522</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>2.259834</td>\n",
" <td>3.125452</td>\n",
" <td>0.263535</td>\n",
" <td>0.094196</td>\n",
" <td>0.749986</td>\n",
" <td>0.620338</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>2.953516</td>\n",
" <td>3.983247</td>\n",
" <td>0.358762</td>\n",
" <td>0.421505</td>\n",
" <td>0.840619</td>\n",
" <td>0.745614</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>10.477497</td>\n",
" <td>14.460713</td>\n",
" <td>4.314635</td>\n",
" <td>0.922679</td>\n",
" <td>0.981525</td>\n",
" <td>0.965753</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
"mean 2.512144 3.430941 0.360515 -0.342186 0.680466 \n",
"std 1.184403 1.580097 0.338929 1.730534 0.267197 \n",
"min 0.895772 1.189229 0.126946 -42.147773 -2.040257 \n",
"25% 1.699544 2.389879 0.211304 -0.480435 0.606881 \n",
"50% 2.259834 3.125452 0.263535 0.094196 0.749986 \n",
"75% 2.953516 3.983247 0.358762 0.421505 0.840619 \n",
"max 10.477497 14.460713 4.314635 0.922679 0.981525 \n",
"\n",
" r \n",
"count 4739.000000 \n",
"mean 0.578431 \n",
"std 0.227716 \n",
"min -0.542623 \n",
"25% 0.457522 \n",
"50% 0.620338 \n",
"75% 0.745614 \n",
"max 0.965753 "
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7db4bb93-198e-4163-a926-f3fabebe4510",
"metadata": {},
"outputs": [],
"source": [
"# 保存训练好的模型\n",
"torch.save(generator, './models/GAN/generator.pt')\n",
"torch.save(discriminator, './models/GAN/discriminator.pt')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "598577c8-7643-41a7-ad1e-ae9c5f2664f2",
"metadata": {},
"outputs": [],
"source": [
"test_imgs = [x for x in os.listdir('./test_img/') if 'img' in x]\n",
"test_imgs.sort()\n",
"test_masks = [x for x in os.listdir('./test_img/') if 'mask' in x]\n",
"test_masks.sort()\n",
"for img_npy, mask_npy in zip(test_imgs, test_masks):\n",
" img = np.load(f'./test_img/{file}')\n",
" img_in = torch.tensor(np.expand_dims(img, 0), dtype=torch.float32)\n",
" mask = np.load(f'./test_img/{file}')\n",
" mask_in = torch.tensor(np.expand_dims(mask, 0), dtype=torch.float32)\n",
" out = generator(img_in, mask_in).detach().cpu().numpy() * max_pixel_value\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "11a7a089-5691-455c-9b70-1c7a306be913",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f793dd49520>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGgCAYAAADsNrNZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAACG9UlEQVR4nO2dd3gVZfbHT3olNySQhBaIiDRRFAQRu6i7lrVgXVQsuzZQsZdVd9VVLKuirqtrw95YxYYdFQRBmiBFAghIKAkJ6T259/7+AO+83++Qebmr+5uo5/M8PM+cvFPeeWfmDnO+7zknJhwOh0VRFEVR/p+J9bsDiqIoym8TfQEpiqIovqAvIEVRFMUX9AWkKIqi+IK+gBRFURRf0BeQoiiK4gv6AlIURVF8QV9AiqIoii/oC0hRFEXxBX0BKYqiKL7wP3sBPfroo9KrVy9JTk6W4cOHy7x58/5Xh1IURVF+gcT8L3LBvfbaa3LOOefI448/LsOHD5dJkybJlClTpLCwUHJycjy3DYVCsnnzZunQoYPExMT83F1TFEVR/seEw2GpqamRrl27Smysx3dO+H/AsGHDwuPGjYvYwWAw3LVr1/DEiROt2xYVFYVFRP/pP/2n//TfL/xfUVGR5+99vPzMNDc3y8KFC+XGG2+M/C02NlZGjRolc+bMca3f1NQkTU1NETu844Os74PHSFxKgoiIlFY3wTax8fhGTYhz7C7pidAWR19RnVLxlONisb1jcgIeiz7CclPjIsvljUFoi6eV61rCYKclYHtyHNq1LSGwQ8bm1CQp8bgt9zOEh5ZUWn9bo7PD6qZWaGsO4sHSEuLAbqR2HrOGFhyXjCRn+y21zdAWpH620B+yU/B6tdCJJdO9YG6floj9DtK2fJ48Zh2ScPsE+o+ceQ34+mQn48p0qV3Xj4+dRBs0GeeVnoj7rm3Gg/N92Mo7J+pbsT3D2H9lE+6b91XVTNeLzrua+saUNTj3ite13L5uC9h1tO/cNLwPzevPz0ddM96jDA9ZcwiPVd2E25t9b2wNtdm2/djYHkf3VUUDPY90bHP//GzW0bMXG8P3GZ5YAn2dtITafvYbWvl5wX31DCRHllsbWuTLC9+SDh06iBc/+wuorKxMgsGg5Obmwt9zc3Nl5cqVrvUnTpwot912m+vvcSkJkRdQLD3d/AKKNa5gfCrehPwCSuB2ujMTU7xfQMlpzgVJjMOLnUArt9ALKIleQPwj0+zxAoqh5zgpyhcQr58Y5+wwgZ6AMP0wJ/IPOd2IPGZBeggSjR/yBDqPWOqn69h0vfiNlchvBaPd1W8eFDoWvwwTLS8gc0xj6bySU6J7AfGx+T8nMcYKKfQCaqUfNNd9aHkBhegFZO6/kS4YvxSSEug/BHTeTU3eL6DEWOdeSaTnOoaOlUDPcnw8PX9p+J/PRI8XUEuC9wuIrwffl/Fx+JJIiHeO1doabLNNxN1vfgHxb1YcHTvWeP7i6AXEv428L6GXRhy9gIL0AjL3H0fPPe+Lf3tFxCqj/OwvoGi58cYb5aqrrorY1dXV0qNHD6luCrbpO2xtxIvfMTMlstxMd05WCl6g0nr8X9SeOWlg84/MN8V1YHfKd97ogzvjDf/BOlx3TXkD2EcWBMAekI3DP2szfh3M/KEqsnxgPm6bTi+zr7fgsddVNoI9qDOeZ4pxYwUtMmAV/W+vM91oPGZCD0VZfau0BX+F8PXj8+Bjc9/NB46/xPhrqr6V/gNAfSmhr7V0eqG1GDaPQUUj7qtbOm1LzzL/4PHXsPkFW09tHeiFVN7Y9rYi+DUl4n45bqhpNdr4RYnb8g/7mkp8vkyPgYhIRRP9B8L49eX/+fOxEqmjBZnJYPP1M6+J1z0o4v7fPcP75nEprXfulcQ4/lHH+5DPq7zOu2+uF5JhV5H3IpCE9zg/T3weDSF+GeKxzHHhMWqhfZUb1y/Y4H1OP/Kzv4A6deokcXFxUlJSAn8vKSmRvLw81/pJSUmSlJT0c3dDURRFaef87NOwExMTZciQITJ9+vTI30KhkEyfPl1GjBjxcx9OURRF+YXyP3HBXXXVVTJ27FgZOnSoDBs2TCZNmiR1dXVy3nnn/S8OpyiKovwC+Z+8gE4//XQpLS2VW2+9VYqLi2Xw4MHy4YcfuiYm7Coh8j12NjQfEZEUw9mbn4HuvMJtqMM8efQAsN9Ysxls9o/vnYvayW0Z3SPLL8UWQxv7ak/r3xHs61tQx3kzDvWN5Dj0n99/WEFk+ahK9Hc/ENwKdudU1KMO6Ir95plT35Y6Mws7ka6ynnSXFNesQxa5wZRajxlG7B9nGlrRd5xC4gprJagwIBk0iYC1ENexadJICmlZ7B83IZlMOqXw7CJsd+lmBPfVPDRNvpQa0roCdK15thJPfqmiiQJZxkw27jdJqNIiNBOqA/6k1LR4a0bmbDLW2GzwfRZIxmObx+Jrx1oI60usnbBuw+2m9sLrNtAgcr9dmhFtz5NnzPaUeO97NN2ybz5vGhaaIIH7CiS1PcnHpiv/yP9sEsL48eNl/Pjx/6vdK4qiKL9wNBecoiiK4gv6AlIURVF8wfc4oLZIjIuR2B/9kxRg1jcbNSDTd8xz1dnHyZoPs6QE42kOzcdI3k87OO0vLKiCNtZhTtm9L9gPrF0F9r9nloL95707gX1018Mjy49VvQ1tm2rRj2zGIYiIvLgPzjjc9+NPwe6bnRpZbqY5++yL59gCjiJ3BfPSmMcZGhL7v3lb1ptsmhFjakasF9W0eMdysGZkC+7lgE8Tzk6wrso78LFrGo5xIKltnY2DQTmTQW2zt//da98irBHhvjpSpoOt9XherPmwvuSOp3GWXdqiRYdxxayQFmYmTnBnwSCdhoJHud3GrmoeO4Nv8WYaw9qWtvtGcp9Lb7IFubKG5H0eFLjN2q+h34aadi0OSL+AFEVRFF/QF5CiKIriC+3WBZeTmhjJLWSbimt+kvIn5uvH7w/2tsZqsN9cs4n2jTu/LZQB9pqMzpHlw/Iroe3zDbVg996GbrEBWelg33Egfv5+ugGnP4cXfBZZPmD37tA2d8sasNllsCEVP4GzKF+b6WbjdDd7kIuTXW7sFmPXiRculxvty+YaSaep0dwX0wXBSSDZZcZToTlxI0+F5vM0t28iN9iCYpz+z65hPq/iWnwUR3TDafTmoWN5qm2i97RqJj2M67MLb2WFc+90ScPx7kDpvnh6/4ZqvO/Y5eqe/uxsn8Kp+jzSLG3fVnYddueRS87tumo7/Y2I+7yCobafAffUZ+97vpxSjUXxeLncdewOT+Tcb5Yp355T46mJ3d8/2Dor+gWkKIqi+IS+gBRFURRf0BeQoiiK4gvtVgM6rGcHSd5RXI6nY9bSNMVDDH95cwj9p+d/NB/sRw7fE+w/7YlTpQdkrwW775x1YBemOvrIJYOOgLa9O83DbT/DY6865VSwS+NRIzq0O/qhB/9nWmT5831OgraTd8cp3M+vKAf7/I9mg31cb5xO/vF6R69iP3FRFRYAzKIyBpxmhv3jjcG2p+LaUp7Y0
t67avawg9ywuTAYudZdvnieVc0FB7nsQSej1AefM2NL5cIaEU+z90rdw6U5eEhY4/mOtEmvFENlKGVJS9Bb/0sm3WxTDRc7ZD3EOW8uLcDrusMD0C73KOfA48vT/VlntmmRPL3Z1IRcaXqSWbfkZ0A82/kZMAtIoursxqV1UfkFl8ZKY2yOGz+bjFm2JUTFA9tCv4AURVEUX9AXkKIoiuIL+gJSFEVRfKHdakD7dk6X1A7bSyukxmOJhUNS+uDKWfmRxdu+fhWa3j0RtZPK398Fdv5fDwP7KSqJsLQE/cprLnojsrz7nRXQNjsNfetzP8eqsKsfvx3sPlPOB/vt1o1gL+7WJbK85ZQHoO2EV28A+4MfPgT7gG7eafHzA055h03VqPn0COB4cwoa1hRqyImdRv50UzvhbTm9hy0VD/vHWZcx+8raCJeD5vN2lyOmmAqKoahvcVIvudL9h7z95VxanOG+dU5z1uf4Jr4+VZSKh2NWOGWNl67G48/XL5m25dgp1hQ45gxKqLd6pytqDuI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(out[0][0], cmap='RdYlGn_r')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44b063cf-d295-4dd8-b21b-1631ddae8ca1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}