{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3a2c33ed-8f78-4ce4-b5cd-7b7ffc5c8273", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "import os\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, Dataset\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "os.environ[\"CUDA_VISIBLE_DEVICE\"] = \"0\" \n", "\n", "\n", "# 设置CUDA设备\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "85459fd1-6835-41cd-b645-553611c358e8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Maximum pixel value in the dataset: 107.49169921875\n" ] } ], "source": [ "max_pixel_value = 107.49169921875\n", "\n", "print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "3fc0918e-103c-40a3-93bc-6171e934a7e4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "checkpoint before Generator is OK\n" ] } ], "source": [ "class NO2Dataset(Dataset):\n", " \n", " def __init__(self, image_dir, mask_dir):\n", " \n", " self.image_dir = image_dir\n", " self.mask_dir = mask_dir\n", " self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n", " self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n", " \n", " def __len__(self):\n", " \n", " return len(self.image_filenames)\n", " \n", " def __getitem__(self, idx):\n", " \n", " image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n", " mask_idx = np.random.choice(self.mask_filenames)\n", " mask_path = os.path.join(self.mask_dir, mask_idx)\n", "\n", " # 加载图像数据 (.npy 文件)\n", " image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n", "\n", " # 加载掩码数据 (.jpg 文件)\n", " mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n", "\n", " # 将掩码数据中非0值设为1,0值保持不变\n", " mask = np.where(mask != 0, 1.0, 0.0)\n", "\n", " # 保持掩码数据形状为 (96, 96, 1)\n", " mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n", "\n", " # 应用掩码\n", " masked_image = image.copy()\n", " masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n", "\n", " # cGAN的输入和目标\n", " X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n", " y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n", "\n", " # 转换形状为 (channels, height, width)\n", " X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n", " y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n", " mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n", "\n", " return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n", "\n", "# 实例化数据集和数据加载器\n", "train_dir = './out_mat/96/train/'\n", "valid_dir = './out_mat/96/valid/'\n", "test_dir = './out_mat/96/test/'\n", "mask_dir = './out_mat/96/mask/20/'\n", "\n", "print(f\"checkpoint before Generator is OK\")\n", "\n", "dataset = NO2Dataset(train_dir, mask_dir)\n", "train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n", "\n", "validset = NO2Dataset(valid_dir, mask_dir)\n", "val_loader = DataLoader(validset, batch_size=64, shuffle=False, num_workers=8)\n", "\n", "testset = NO2Dataset(test_dir, mask_dir)\n", "test_loader = DataLoader(testset, 
{ "cell_type": "code", "execution_count": 16, "id": "a60b7019-f231-4ccb-9195-c459f3a1521d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generator is on: cpu\n", "Discriminator is on: cpu\n" ] } ], "source": [ "# Generator model (encoder-decoder)\n", "class Generator(nn.Module):\n", " \n", " def __init__(self):\n", " super(Generator, self).__init__()\n", " self.encoder = nn.Sequential(\n", " nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.BatchNorm2d(64),\n", " nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.BatchNorm2d(128),\n", " )\n", " self.decoder = nn.Sequential(\n", " nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(64),\n", " nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),\n", " nn.Tanh(),\n", " )\n", "\n", " def forward(self, x, mask):\n", " x_encoded = self.encoder(x)\n", " x_decoded = self.decoder(x_encoded)\n", "\n", "# x_decoded = (x_decoded + 1) / 2\n", "\n", "# x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :]\n", " return x_decoded\n", "\n", "# Discriminator model: scores the 2-channel (candidate, masked input) pair as a 12x12 patch map\n", "class Discriminator(nn.Module):\n", " \n", " def __init__(self):\n", " super(Discriminator, self).__init__()\n", " self.model = nn.Sequential(\n", " nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.BatchNorm2d(64),\n", " nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.BatchNorm2d(128),\n", " nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),\n", " nn.Sigmoid(),\n", " )\n", "\n", " def forward(self, x):\n", " return self.model(x)\n", "\n", "# Move the models to the GPU\n", "generator = Generator().to(device)\n", "discriminator = Discriminator().to(device)\n", "\n", "# Define the optimizers and the adversarial loss\n", "optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n", "optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n", "adversarial_loss = nn.BCELoss().to(device)\n", "\n", "# Confirm which device the models are on\n", "print(f\"Generator is on: {next(generator.parameters()).device}\")\n", "print(f\"Discriminator is on: {next(discriminator.parameters()).device}\")" ] },
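{ "cell_type": "code", "execution_count": null, "id": "generator-smoke-test", "metadata": {}, "outputs": [], "source": [ "# Added smoke-test cell (a minimal sketch, not part of the original run): run the generator\n", "# on a dummy batch to confirm the decoder restores the input resolution\n", "# (96 -> 48 -> 24 in the encoder, 24 -> 48 -> 96 in the decoder) and report parameter counts.\n", "# dummy_x and dummy_mask are placeholder tensors introduced here for illustration only.\n", "dummy_x = torch.randn(2, 1, 96, 96, device=device)\n", "dummy_mask = torch.ones(2, 1, 96, 96, device=device)\n", "with torch.no_grad():\n", "    print('generator output shape:', generator(dummy_x, dummy_mask).shape)\n", "print('generator parameters:', sum(p.numel() for p in generator.parameters()))\n", "print('discriminator parameters:', sum(p.numel() for p in discriminator.parameters()))" ] },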
{ "cell_type": "code", "execution_count": 17, "id": "fd2f0816-7301-4381-99e4-05905ce8b093", "metadata": {}, "outputs": [], "source": [ "def masked_mse_loss(preds, target, mask):\n", " loss = (preds - target) ** 2\n", " loss = loss.mean(dim=1, keepdim=True) # average over channels per pixel (keepdim so the (N, 1, H, W) mask broadcasts correctly)\n", " loss = (loss * (1-mask)).sum() / (1-mask).sum() # average only over the masked-out pixels\n", " return loss" ] }, { "cell_type": "code", "execution_count": 18, "id": "edf26d05-7fd6-404b-8f9d-3d2054593f7b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generator.load_state_dict(torch.load('./models/GAN/generator-1d.pth'))" ] }, { "cell_type": "code", "execution_count": null, "id": "645dd325-fc70-4234-8279-bc8cbc4c5dde", "metadata": {}, "outputs": [], "source": [ "# Start training\n", "epochs = 100\n", "for epoch in range(epochs):\n", " for i, (X, y, mask) in enumerate(train_loader):\n", " # Move the data to the GPU\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " # print(f\"X is on: {X.device}, y is on: {y.device}, mask is on: {mask.device}, i = {i}\") #checkpoint\n", " \n", " # Real/fake labels matching the discriminator's 12x12 patch output\n", " valid = torch.ones((X.size(0), 1, 12, 12)).to(device)\n", " fake = torch.zeros((X.size(0), 1, 12, 12)).to(device)\n", "\n", " # Train the generator: adversarial loss plus weighted masked reconstruction loss\n", " optimizer_G.zero_grad()\n", " generated_images = generator(X, mask)\n", " g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * masked_mse_loss(\n", " generated_images, y, mask)\n", " g_loss.backward()\n", " optimizer_G.step()\n", "\n", " # Train the discriminator\n", " optimizer_D.zero_grad()\n", " real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)\n", " fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)\n", " d_loss = 0.5 * (real_loss + fake_loss)\n", " d_loss.backward()\n", " optimizer_D.step()\n", "\n", " print(f\"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]\")\n", "\n", "# Save the trained models\n", "torch.save(generator.state_dict(), './models/GAN/generator-1d.pth')\n", "torch.save(discriminator.state_dict(), './models/GAN/discriminator-1d.pth')" ] }, { "cell_type": "code", "execution_count": 10, "id": "9ab6849f-740a-49e4-9afd-23aafdc16725", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error" ] }, { "cell_type": "code", "execution_count": 11, "id": "8997686a-5812-4a92-972c-11376dfc1686", "metadata": {}, "outputs": [], "source": [ "def cal_ioa(y_true, y_pred):\n", " # Means of the observed and predicted values\n", " mean_observed = np.mean(y_true)\n", " mean_predicted = np.mean(y_pred)\n", "\n", " # Index of Agreement (IoA)\n", " numerator = np.sum((y_true - y_pred) ** 2)\n", " denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n", " IoA = 1 - (numerator / denominator)\n", "\n", " return IoA" ] }, { "cell_type": "code", "execution_count": 19, "id": "78db8f70-3cdb-444f-8bb5-49126957a0b6", "metadata": {}, "outputs": [], "source": [ "eva_list = list()\n", "device = 'cpu'\n", "generator = generator.to(device)\n", "with torch.no_grad():\n", " for batch_idx, (X, y, mask) in enumerate(test_loader):\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # invert the mask to get the region to be reconstructed\n", " reconstructed = generator(X, mask)\n", " rev_data = torch.squeeze(y * max_pixel_value, dim=1)\n", " rev_recon = torch.squeeze(reconstructed * max_pixel_value, dim=1)\n", " # TODO: only the reconstructed (masked-out) region should be evaluated here\n", " data_label = rev_data * mask_rev\n", " data_label = data_label[mask_rev==1]\n", " recon_no2 = rev_recon * mask_rev\n", " recon_no2 = recon_no2[mask_rev==1]\n", " y_true = rev_data.flatten()\n", " y_pred = rev_recon.flatten()\n", " mae = mean_absolute_error(y_true, y_pred)\n", " rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n", " mape = mean_absolute_percentage_error(y_true, y_pred)\n", " r2 = r2_score(y_true, y_pred)\n", " ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n", " eva_list.append([mae, rmse, mape, r2, ioa])" ] }, { "cell_type": "code", "execution_count": 14, "id": "ffc83338-9b0f-4934-9cba-406a5e1fb0ca", "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 24, "id": "33eb7a86-242a-4488-9d9b-93548e72c98b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
maermsemaper2ioa
count75.00000075.00000075.00000075.00000075.000000
mean1.6856092.8245790.2238520.8074830.894409
std0.5202850.6132990.0668270.1075660.024969
min1.1087562.0409640.1434610.3361930.812887
25%1.3381432.4626480.1761700.7809060.883027
50%1.5098212.6082270.2062740.8504170.900165
75%1.9631033.0675600.2576670.8667050.910917
max3.7294345.3632880.4614650.9122400.935183
\n", "
" ], "text/plain": [ " mae rmse mape r2 ioa\n", "count 75.000000 75.000000 75.000000 75.000000 75.000000\n", "mean 1.685609 2.824579 0.223852 0.807483 0.894409\n", "std 0.520285 0.613299 0.066827 0.107566 0.024969\n", "min 1.108756 2.040964 0.143461 0.336193 0.812887\n", "25% 1.338143 2.462648 0.176170 0.780906 0.883027\n", "50% 1.509821 2.608227 0.206274 0.850417 0.900165\n", "75% 1.963103 3.067560 0.257667 0.866705 0.910917\n", "max 3.729434 5.363288 0.461465 0.912240 0.935183" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()" ] }, { "cell_type": "code", "execution_count": 106, "id": "3844152b-c853-4311-a0a4-f57275ea78fc", "metadata": {}, "outputs": [], "source": [ "rst = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape', ascending=True)" ] }, { "cell_type": "code", "execution_count": 140, "id": "8f98d947-5a20-49f5-84a1-5ae5a0f4693e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
maermsemaper2ioar
35231.8346642.5792960.1039110.8550560.9608410.931334
35441.5001941.9628850.1068160.8496880.9609700.924187
19521.7866392.2905600.1093830.7041220.9288290.869446
6022.2229572.9347340.1127510.7351780.9336390.877028
35312.0931652.7266980.1157550.7605300.9376620.889606
11141.9517482.5914480.1165780.6969700.9145010.843026
19792.0830012.6862310.1167620.5975120.8868770.791842
25682.6305873.6368900.1170440.4919520.8939280.833221
\n", "
" ], "text/plain": [ " mae rmse mape r2 ioa r\n", "3523 1.834664 2.579296 0.103911 0.855056 0.960841 0.931334\n", "3544 1.500194 1.962885 0.106816 0.849688 0.960970 0.924187\n", "1952 1.786639 2.290560 0.109383 0.704122 0.928829 0.869446\n", "602 2.222957 2.934734 0.112751 0.735178 0.933639 0.877028\n", "3531 2.093165 2.726698 0.115755 0.760530 0.937662 0.889606\n", "1114 1.951748 2.591448 0.116578 0.696970 0.914501 0.843026\n", "1979 2.083001 2.686231 0.116762 0.597512 0.886877 0.791842\n", "2568 2.630587 3.636890 0.117044 0.491952 0.893928 0.833221" ] }, "execution_count": 140, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rst.head(8)" ] }, { "cell_type": "code", "execution_count": 141, "id": "a72b4ec5-b72d-43d0-8178-954f8edcb0dd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'1114', '1952', '2568', '3523', '602'}" ] }, "execution_count": 141, "metadata": {}, "output_type": "execute_result" } ], "source": [ "find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])\n", "find_ex" ] }, { "cell_type": "code", "execution_count": 159, "id": "819950e1-d16d-42e9-a0f9-57878ebc8f89", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for j in find_ex:\n", " ori = np.load(f'./test_img/{j}-real.npy')[0]\n", " pred = np.load(f'./test_img/{j}-gan-recom.npy')[0]\n", " mask = np.load(f'./test_img/{j}-mask.npy')\n", " plt.imshow(ori, cmap='RdYlGn_r')\n", " plt.gca().axis('off')\n", " plt.savefig(f'./test_img/out_fig/{j}-truth.png', bbox_inches='tight')\n", " plt.clf()\n", " \n", " plt.imshow(mask, cmap='gray')\n", " plt.gca().axis('off')\n", " plt.savefig(f'./test_img/out_fig/{j}-mask.png', bbox_inches='tight')\n", " plt.clf()\n", " \n", " mask_cp = np.where((1-mask) == 0, np.nan, (1-mask))\n", " plt.imshow(ori * mask_cp, cmap='RdYlGn_r')\n", " plt.gca().axis('off')\n", " plt.savefig(f'./test_img/out_fig/{j}-masked_ori.png', bbox_inches='tight')\n", " plt.clf()\n", " \n", " out = ori * mask + pred * (1 - mask)\n", " plt.imshow(out, cmap='RdYlGn_r')\n", " plt.gca().axis('off')\n", " plt.savefig(f'./test_img/out_fig/{j}-gan_out.png', bbox_inches='tight')\n", " plt.clf()" ] }, { "cell_type": "code", "execution_count": null, "id": "2926f4b3-0a15-48ae-bdaa-4e290e11b49d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" } }, "nbformat": 4, "nbformat_minor": 5 }