870 lines
122 KiB
Plaintext
870 lines
122 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"id": "3a2c33ed-8f78-4ce4-b5cd-7b7ffc5c8273",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Using device: cuda\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import os\n",
"\n",
"# The CUDA runtime reads CUDA_VISIBLE_DEVICES (plural); a misspelled name is silently\n",
"# ignored. Set it before importing torch so it is in place before any CUDA initialisation.\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Use the (first visible) CUDA device when available, otherwise the CPU.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"id": "85459fd1-6835-41cd-b645-553611c358e8",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Maximum pixel value in the dataset: 107.49169921875\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Normalisation constant: images are divided by it when loaded and model outputs are\n",
"# multiplied by it to undo the scaling.\n",
"# NOTE(review): hardcoded rather than computed from the files - re-derive it if the data change.\n",
"max_pixel_value = 107.49169921875\n",
"\n",
"print(\"Maximum pixel value in the dataset: \" + str(max_pixel_value))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"id": "3fc0918e-103c-40a3-93bc-6171e934a7e4",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"checkpoint before Generator is OK\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class NO2Dataset(Dataset):\n",
"    \"\"\"Masked-NO2 inpainting samples built from .npy images and .jpg masks.\n",
"\n",
"    Channel 0 of each image is the NO2 field (the inpainting target); the other\n",
"    channels are passed through unchanged as auxiliary inputs. Masks are paired with\n",
"    images by index (cycling through the sorted mask list), binarised (non-zero -> 1,\n",
"    zero -> 0) and multiplied into channel 0, so pixels where the mask is 0 are hidden.\n",
"\n",
"    __getitem__ returns float32 tensors (X, y, mask) in channels-first layout:\n",
"    X is the masked input (C, H, W), y the unmasked NO2 channel (1, H, W) and mask\n",
"    the binary mask (1, H, W).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, image_dir, mask_dir):\n",
"        self.image_dir = image_dir\n",
"        self.mask_dir = mask_dir\n",
"        # Sorted so the image <-> mask pairing (idx % n_masks) is reproducible;\n",
"        # os.listdir returns entries in arbitrary, filesystem-dependent order.\n",
"        self.image_filenames = sorted(f for f in os.listdir(image_dir) if f.endswith('.npy'))  # .npy images only\n",
"        self.mask_filenames = sorted(f for f in os.listdir(mask_dir) if f.endswith('.jpg'))  # .jpg masks only\n",
"        if not self.mask_filenames:\n",
"            # __getitem__ would otherwise fail with ZeroDivisionError on idx % 0\n",
"            raise ValueError(f\"no .jpg mask files found in {mask_dir!r}\")\n",
"\n",
"    def __len__(self):\n",
"        return len(self.image_filenames)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
"        mask_idx = idx % len(self.mask_filenames)\n",
"        mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])\n",
"\n",
"        # Load the image and normalise by the global maximum; expected (96, 96, 8).\n",
"        image = np.load(image_path).astype(np.float32) / max_pixel_value\n",
"\n",
"        # Load the grayscale mask and binarise it: non-zero -> 1 (kept), 0 -> 0 (hidden).\n",
"        mask = np.array(Image.open(mask_path).convert('L'))\n",
"        mask = (mask != 0).astype(np.float32)[:, :, np.newaxis]  # (H, W, 1)\n",
"\n",
"        # Model input: hide the masked NO2 pixels (channel 0); other channels unchanged.\n",
"        X = image.copy()\n",
"        X[:, :, 0] *= mask[:, :, 0]\n",
"        # Target: the unmasked NO2 channel, (H, W, 1).\n",
"        y = image[:, :, 0:1]\n",
"\n",
"        # (H, W, C) -> (C, H, W) as expected by PyTorch convolutions.\n",
"        X = np.transpose(X, (2, 0, 1))\n",
"        y = np.transpose(y, (2, 0, 1))\n",
"        mask = np.transpose(mask, (2, 0, 1))\n",
"\n",
"        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# Instantiate datasets and data loaders\n",
"train_dir = './out_mat/96/train/'\n",
"valid_dir = './out_mat/96/valid/'\n",
"test_dir = './out_mat/96/test/'\n",
"mask_dir = './out_mat/96/mask/20/'\n",
"\n",
"print(f\"checkpoint before Generator is OK\")\n",
"\n",
"dataset = NO2Dataset(train_dir, mask_dir)\n",
"train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
"\n",
"validset = NO2Dataset(valid_dir, mask_dir)\n",
"val_loader = DataLoader(validset, batch_size=64, shuffle=False, num_workers=8)\n",
"\n",
"testset = NO2Dataset(test_dir, mask_dir)\n",
"test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=8)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"id": "a60b7019-f231-4ccb-9195-c459f3a1521d",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Generator is on: cuda:0\n",
|
|||
|
"Discriminator is on: cuda:0\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Generator: convolutional encoder-decoder that predicts the NO2 channel\n",
"class Generator(nn.Module):\n",
"    \"\"\"Encoder-decoder for NO2 inpainting.\n",
"\n",
"    Two stride-2 convolutions downsample by 4 and two transposed convolutions\n",
"    upsample back, so H and W should be divisible by 4 (96x96 in this notebook).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=8):\n",
"        super(Generator, self).__init__()\n",
"        self.encoder = nn.Sequential(\n",
"            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),\n",
"            nn.LeakyReLU(0.2, inplace=True),\n",
"            nn.BatchNorm2d(64),\n",
"            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
"            nn.LeakyReLU(0.2, inplace=True),\n",
"            nn.BatchNorm2d(128),\n",
"        )\n",
"        self.decoder = nn.Sequential(\n",
"            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.BatchNorm2d(64),\n",
"            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),\n",
"            nn.Tanh(),\n",
"        )\n",
"\n",
"    def forward(self, x, mask=None):\n",
"        \"\"\"Reconstruct the NO2 channel.\n",
"\n",
"        Args:\n",
"            x: (B, in_channels, H, W) input; channel 0 is the masked NO2 field.\n",
"            mask: optional (B, 1, H, W) binary mask, 1 = observed, 0 = missing. When\n",
"                given, observed pixels are copied from x[:, :1] and only the missing\n",
"                pixels come from the network.\n",
"\n",
"        Returns:\n",
"            (B, 1, H, W) tensor; network-produced values lie in [0, 1].\n",
"        \"\"\"\n",
"        x_decoded = self.decoder(self.encoder(x))\n",
"        # Map the Tanh range [-1, 1] onto [0, 1] to match the max-normalised targets.\n",
"        x_decoded = (x_decoded + 1) / 2\n",
"        if mask is None:\n",
"            return x_decoded\n",
"        # Keep the known pixels, fill only the holes with the prediction.\n",
"        return (1 - mask) * x_decoded + mask * x[:, :1, :, :]\n",
"\n",
"# Discriminator: conditional, scores (NO2, input) pairs patch by patch\n",
"class Discriminator(nn.Module):\n",
"    \"\"\"Conditional discriminator returning a grid of real/fake probabilities.\n",
"\n",
"    Its input is torch.cat((no2, X), dim=1): 1 NO2 channel + 8 input channels = 9 by\n",
"    default. Three stride-2 convolutions give a (B, 1, H/8, W/8) map (12x12 for 96x96).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels=9):\n",
"        super(Discriminator, self).__init__()\n",
"        self.model = nn.Sequential(\n",
"            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),\n",
"            nn.LeakyReLU(0.2, inplace=True),\n",
"            nn.BatchNorm2d(64),\n",
"            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
"            nn.LeakyReLU(0.2, inplace=True),\n",
"            nn.BatchNorm2d(128),\n",
"            nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),\n",
"            nn.Sigmoid(),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"# Move the models to the selected device\n",
"generator = Generator().to(device)\n",
"discriminator = Discriminator().to(device)\n",
"\n",
"# Optimisers and loss functions (Sigmoid output + BCELoss for the adversarial term)\n",
"optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
"optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
"adversarial_loss = nn.BCELoss().to(device)\n",
"pixelwise_loss = nn.MSELoss().to(device)\n",
"\n",
"# Confirm where the models live\n",
"print(f\"Generator is on: {next(generator.parameters()).device}\")\n",
"print(f\"Discriminator is on: {next(discriminator.parameters()).device}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "645dd325-fc70-4234-8279-bc8cbc4c5dde",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Training loop\n",
"epochs = 150\n",
"os.makedirs('./models/GAN', exist_ok=True)  # torch.save does not create missing directories\n",
"\n",
"for epoch in range(epochs):\n",
"    # Re-enable training mode in case an evaluation cell switched the models to eval().\n",
"    generator.train()\n",
"    discriminator.train()\n",
"    for i, (X, y, mask) in enumerate(train_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"\n",
"        # Generator step: fool the discriminator and match the target pixels.\n",
"        optimizer_G.zero_grad()\n",
"        generated_images = generator(X, mask)\n",
"        pred_fake = discriminator(torch.cat((generated_images, X), dim=1))\n",
"        # Labels take the discriminator's output shape instead of a hardcoded 12x12.\n",
"        valid = torch.ones_like(pred_fake)\n",
"        fake = torch.zeros_like(pred_fake)\n",
"        g_loss = adversarial_loss(pred_fake, valid) + 100 * pixelwise_loss(generated_images, y)\n",
"        g_loss.backward()\n",
"        optimizer_G.step()\n",
"\n",
"        # Discriminator step: real pairs -> 1, generated pairs (detached) -> 0.\n",
"        optimizer_D.zero_grad()\n",
"        real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)\n",
"        fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)\n",
"        d_loss = 0.5 * (real_loss + fake_loss)\n",
"        d_loss.backward()\n",
"        optimizer_D.step()\n",
"\n",
"    # Report the last batch's losses once per epoch instead of flooding the output.\n",
"    print(f\"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]\")\n",
"\n",
"# Save the trained models\n",
"torch.save(generator.state_dict(), './models/GAN/generator.pth')\n",
"torch.save(discriminator.state_dict(), './models/GAN/discriminator.pth')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"id": "2d889a53-5415-4895-99ff-fc63745884a5",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Regression metrics used to score the reconstructed NO2 values\n",
"from sklearn.metrics import (\n",
"    mean_absolute_error,\n",
"    mean_absolute_percentage_error,\n",
"    mean_squared_error,\n",
"    r2_score,\n",
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"id": "37f2df19-492c-4231-a388-13182ce515db",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def cal_ioa(y_true, y_pred):\n",
"    \"\"\"Willmott's index of agreement d (0 = no agreement, 1 = perfect).\n",
"\n",
"    d = 1 - sum((O - P)^2) / sum((|P - mean(O)| + |O - mean(O)|)^2)\n",
"\n",
"    Both deviations in the denominator are taken from the *observed* mean, which by\n",
"    the triangle inequality keeps d in [0, 1]; measuring the prediction term from the\n",
"    predicted mean instead lets d become negative.\n",
"\n",
"    Args:\n",
"        y_true: observed values (array-like).\n",
"        y_pred: predicted values, same shape as y_true.\n",
"\n",
"    Returns:\n",
"        d as a float; nan for empty input, 1.0 when the denominator is zero\n",
"        (every observation and prediction equals the observed mean).\n",
"    \"\"\"\n",
"    y_true = np.asarray(y_true, dtype=np.float64)\n",
"    y_pred = np.asarray(y_pred, dtype=np.float64)\n",
"    if y_true.size == 0:\n",
"        return float('nan')\n",
"\n",
"    mean_observed = y_true.mean()\n",
"    numerator = np.sum((y_true - y_pred) ** 2)\n",
"    denominator = np.sum((np.abs(y_pred - mean_observed) + np.abs(y_true - mean_observed)) ** 2)\n",
"    if denominator == 0:\n",
"        return 1.0\n",
"    return float(1 - numerator / denominator)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"id": "0b93a2e7-c4fb-4611-9967-4e33f9982ad5",
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Per-batch evaluation on the test set, scored only on the inpainted (mask == 0) pixels.\n",
"eva_list = list()\n",
"# Evaluate on the CPU; note this rebinds the global `device` used by later cells.\n",
"device = 'cpu'\n",
"generator = generator.to(device)\n",
"# Eval mode: BatchNorm must use its running statistics, not each test batch's statistics.\n",
"generator.eval()\n",
"with torch.no_grad():\n",
"    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        # Inverted mask: True marks the pixels the generator had to fill in.\n",
"        mask_rev = torch.squeeze(mask, dim=1) == 0\n",
"        reconstructed = generator(X, mask)\n",
"        # Undo the normalisation and drop the channel axis -> (B, H, W).\n",
"        rev_data = torch.squeeze(y * max_pixel_value, dim=1)\n",
"        rev_recon = torch.squeeze(reconstructed * max_pixel_value, dim=1)\n",
"        # Score only the inpainted region: observed pixels were never hidden from the\n",
"        # model and would inflate every metric.\n",
"        y_true = rev_data[mask_rev].numpy()\n",
"        y_pred = rev_recon[mask_rev].numpy()\n",
"        if y_true.size < 2:\n",
"            # r2_score is undefined for fewer than two points\n",
"            continue\n",
"        mae = mean_absolute_error(y_true, y_pred)\n",
"        rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
"        mape = mean_absolute_percentage_error(y_true, y_pred)\n",
"        r2 = r2_score(y_true, y_pred)\n",
"        ioa = cal_ioa(y_true, y_pred)\n",
"        eva_list.append([mae, rmse, mape, r2, ioa])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 22,
|
|||
|
"id": "e7b3323f-7116-4d4e-8483-2fd605e2fb57",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>0.399366</td>\n",
|
|||
|
" <td>1.246761</td>\n",
|
|||
|
" <td>0.047188</td>\n",
|
|||
|
" <td>0.963991</td>\n",
|
|||
|
" <td>0.939587</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>0.071295</td>\n",
|
|||
|
" <td>0.220616</td>\n",
|
|||
|
" <td>0.005035</td>\n",
|
|||
|
" <td>0.018081</td>\n",
|
|||
|
" <td>0.026807</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>0.348072</td>\n",
|
|||
|
" <td>1.073966</td>\n",
|
|||
|
" <td>0.040716</td>\n",
|
|||
|
" <td>0.813181</td>\n",
|
|||
|
" <td>0.719442</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>0.372762</td>\n",
|
|||
|
" <td>1.154684</td>\n",
|
|||
|
" <td>0.043751</td>\n",
|
|||
|
" <td>0.963354</td>\n",
|
|||
|
" <td>0.938713</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>0.388768</td>\n",
|
|||
|
" <td>1.207949</td>\n",
|
|||
|
" <td>0.045860</td>\n",
|
|||
|
" <td>0.966356</td>\n",
|
|||
|
" <td>0.943430</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>0.402351</td>\n",
|
|||
|
" <td>1.274836</td>\n",
|
|||
|
" <td>0.050051</td>\n",
|
|||
|
" <td>0.968919</td>\n",
|
|||
|
" <td>0.947026</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>0.959251</td>\n",
|
|||
|
" <td>2.966476</td>\n",
|
|||
|
" <td>0.066256</td>\n",
|
|||
|
" <td>0.972840</td>\n",
|
|||
|
" <td>0.956998</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa\n",
|
|||
|
"count 75.000000 75.000000 75.000000 75.000000 75.000000\n",
|
|||
|
"mean 0.399366 1.246761 0.047188 0.963991 0.939587\n",
|
|||
|
"std 0.071295 0.220616 0.005035 0.018081 0.026807\n",
|
|||
|
"min 0.348072 1.073966 0.040716 0.813181 0.719442\n",
|
|||
|
"25% 0.372762 1.154684 0.043751 0.963354 0.938713\n",
|
|||
|
"50% 0.388768 1.207949 0.045860 0.966356 0.943430\n",
|
|||
|
"75% 0.402351 1.274836 0.050051 0.968919 0.947026\n",
|
|||
|
"max 0.959251 2.966476 0.066256 0.972840 0.956998"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 22,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Distribution of the per-batch test metrics\n",
"pd.DataFrame(data=eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"id": "3c881732-b18f-4b6f-802a-1204d0ffa70f",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Per-sample evaluation on the inpainted pixels; also keeps the lowest-MAPE sample.\n",
"eva_list_frame = list()\n",
"best_mape = float('inf')  # any finite MAPE beats this, so best_* is set whenever a sample is scored\n",
"best_img = None\n",
"best_mask = None\n",
"best_recov = None\n",
"generator.eval()  # BatchNorm must use running statistics during evaluation\n",
"with torch.no_grad():\n",
"    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1  # 1 marks the inpainted region\n",
"        reconstructed = generator(X, mask)\n",
"        rev_data = y * max_pixel_value\n",
"        rev_recon = reconstructed * max_pixel_value\n",
"        for i, sample in enumerate(rev_data):\n",
"            used_mask = mask_rev[i]\n",
"            hole = used_mask == 1\n",
"            data_label = sample[0][hole].cpu().numpy()\n",
"            recon_no2 = rev_recon[i][0][hole].cpu().numpy()\n",
"            if data_label.size < 2:\n",
"                # r2_score and corrcoef are undefined for fewer than two points\n",
"                continue\n",
"            mae = mean_absolute_error(data_label, recon_no2)\n",
"            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
"            mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
"            r2 = r2_score(data_label, recon_no2)\n",
"            ioa = cal_ioa(data_label, recon_no2)\n",
"            r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
"            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])\n",
"            if mape < best_mape:\n",
"                best_recov = rev_recon[i][0].cpu().numpy()\n",
"                best_mask = used_mask.cpu().numpy()\n",
"                best_img = sample[0].cpu().numpy()\n",
"                best_mape = mape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 24,
|
|||
|
"id": "d3f6851c-eba3-48d5-bf6e-f94290b3d56e",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" <th>r</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>1.977856</td>\n",
|
|||
|
" <td>2.505394</td>\n",
|
|||
|
" <td>0.236766</td>\n",
|
|||
|
" <td>0.444234</td>\n",
|
|||
|
" <td>0.826185</td>\n",
|
|||
|
" <td>0.795505</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>0.966137</td>\n",
|
|||
|
" <td>1.156947</td>\n",
|
|||
|
" <td>0.075139</td>\n",
|
|||
|
" <td>0.309037</td>\n",
|
|||
|
" <td>0.112186</td>\n",
|
|||
|
" <td>0.109227</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>0.588599</td>\n",
|
|||
|
" <td>0.782554</td>\n",
|
|||
|
" <td>0.106112</td>\n",
|
|||
|
" <td>-5.779783</td>\n",
|
|||
|
" <td>-2.754070</td>\n",
|
|||
|
" <td>0.284676</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>1.195541</td>\n",
|
|||
|
" <td>1.551567</td>\n",
|
|||
|
" <td>0.187231</td>\n",
|
|||
|
" <td>0.300401</td>\n",
|
|||
|
" <td>0.781712</td>\n",
|
|||
|
" <td>0.735376</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>1.606092</td>\n",
|
|||
|
" <td>2.094027</td>\n",
|
|||
|
" <td>0.220013</td>\n",
|
|||
|
" <td>0.506733</td>\n",
|
|||
|
" <td>0.849549</td>\n",
|
|||
|
" <td>0.822590</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>2.658243</td>\n",
|
|||
|
" <td>3.338708</td>\n",
|
|||
|
" <td>0.266574</td>\n",
|
|||
|
" <td>0.658528</td>\n",
|
|||
|
" <td>0.899010</td>\n",
|
|||
|
" <td>0.876813</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>9.428754</td>\n",
|
|||
|
" <td>9.982598</td>\n",
|
|||
|
" <td>0.903847</td>\n",
|
|||
|
" <td>0.889351</td>\n",
|
|||
|
" <td>0.969285</td>\n",
|
|||
|
" <td>0.960868</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa \\\n",
|
|||
|
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
|
|||
|
"mean 1.977856 2.505394 0.236766 0.444234 0.826185 \n",
|
|||
|
"std 0.966137 1.156947 0.075139 0.309037 0.112186 \n",
|
|||
|
"min 0.588599 0.782554 0.106112 -5.779783 -2.754070 \n",
|
|||
|
"25% 1.195541 1.551567 0.187231 0.300401 0.781712 \n",
|
|||
|
"50% 1.606092 2.094027 0.220013 0.506733 0.849549 \n",
|
|||
|
"75% 2.658243 3.338708 0.266574 0.658528 0.899010 \n",
|
|||
|
"max 9.428754 9.982598 0.903847 0.889351 0.969285 \n",
|
|||
|
"\n",
|
|||
|
" r \n",
|
|||
|
"count 4739.000000 \n",
|
|||
|
"mean 0.795505 \n",
|
|||
|
"std 0.109227 \n",
|
|||
|
"min 0.284676 \n",
|
|||
|
"25% 0.735376 \n",
|
|||
|
"50% 0.822590 \n",
|
|||
|
"75% 0.876813 \n",
|
|||
|
"max 0.960868 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 24,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Distribution of the per-sample test metrics (r = Pearson correlation)\n",
"pd.DataFrame(data=eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 26,
|
|||
|
"id": "02827026-b34d-4859-a663-f799b88d4b54",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Single-sample loader used to export a few test examples to ./test_img/.\n",
"# The seeded generator makes the shuffled selection reproducible across runs.\n",
"real_test = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"real_loader = DataLoader(real_test, batch_size=1, shuffle=True, num_workers=4,\n",
"                         generator=torch.Generator().manual_seed(0))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"id": "80fe1990-44c1-43c6-9c89-fba8f0f1b0ee",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])\n",
|
|||
|
"torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])\n",
|
|||
|
"torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])\n",
|
|||
|
"torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])\n",
|
|||
|
"torch.Size([1, 8, 96, 96]) torch.Size([1, 1, 96, 96]) torch.Size([1, 1, 96, 96])\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Export the first five samples (input, mask, ground truth) as .npy files.\n",
"os.makedirs('./test_img', exist_ok=True)  # np.save does not create missing directories\n",
"for batch_idx, (X, y, mask) in enumerate(real_loader):\n",
"    print(X.shape, y.shape, mask.shape)\n",
"    np.save(f'./test_img/{batch_idx}-img.npy', X[0].numpy())\n",
"    np.save(f'./test_img/{batch_idx}-mask.npy', mask[0].numpy())\n",
"    np.save(f'./test_img/{batch_idx}-real.npy', y[0].numpy())\n",
"    if batch_idx >= 4:\n",
"        break"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 72,
|
|||
|
"id": "65241f09-7c50-48e1-a701-a3c4ba5e060c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Run the generator on each exported sample. Each mask is found by file name\n",
"# ('<k>-img.npy' -> '<k>-mask.npy') rather than by position in a second sorted list.\n",
"# All reconstructions are kept in `outs`; `out` ends up holding the last one.\n",
"test_imgs = sorted(x for x in os.listdir('./test_img/') if 'img' in x)\n",
"test_masks = sorted(x for x in os.listdir('./test_img/') if 'mask' in x)\n",
"gen_device = next(generator.parameters()).device  # inputs must live where the model does\n",
"generator.eval()\n",
"outs = []\n",
"with torch.no_grad():\n",
"    for img_npy in test_imgs:\n",
"        mask_npy = img_npy.replace('-img', '-mask')\n",
"        img = np.load(f'./test_img/{img_npy}')\n",
"        img_in = torch.tensor(np.expand_dims(img, 0), dtype=torch.float32).to(gen_device)\n",
"        mask = np.load(f'./test_img/{mask_npy}')\n",
"        mask_in = torch.tensor(np.expand_dims(mask, 0), dtype=torch.float32).to(gen_device)\n",
"        out = generator(img_in, mask_in).detach().cpu().numpy() * max_pixel_value\n",
"        outs.append(out)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 73,
|
|||
|
"id": "d8594321-a526-4476-b72a-b377acdf10d7",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# pyplot is also imported in the first cell; re-imported so this section runs on its own\n",
"from matplotlib import pyplot as plt"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 74,
|
|||
|
"id": "449919c5-a05d-42d5-9461-cecb860f8d5d",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<matplotlib.image.AxesImage at 0x7f65196d8940>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 74,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGgCAYAAADsNrNZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABt1klEQVR4nO29e3Rd1XXvP3WOzkuvI8uyJctIWAES83ACweAYaNoGt/xS2sKF297cS25pml+edoLx/SWNewMdISFO2t9IKPmRpMlIKbk3hITevDNCbq5JSGjMywSCS7BNMPgpWbYsHb3OQ+fs3x8yOmt+l85a2pboPsD3M4bG0Dpr77XWXnsfLe35XXPOhiAIAiGEEEL+nYlFPQBCCCGvTrgAEUIIiQQuQIQQQiKBCxAhhJBI4AJECCEkErgAEUIIiQQuQIQQQiKBCxAhhJBI4AJECCEkErgAEUIIiYSXbAG64447ZNWqVZJOp2XdunXyyCOPvFRdEUIIeRnS8FLEgvvGN74hf/EXfyFf/OIXZd26dXLbbbfJvffeK7t375bly5c7z61UKnL48GFpbW2VhoaGxR4aIYSQl5ggCGRsbEx6enokFnO85wQvARdffHGwcePG2XK5XA56enqCbdu2ec89cOBAICL84Q9/+MOfl/nPgQMHnH/vG2WRKRaLsnPnTtm6devsZ7FYTDZs2CA7duywji8UClIoFGbLwckXskd3f1paWjMzbZbz6pwAXtpKlWp9sTKt6qYrFTi27Bx/3PPWVTLam4a2mhMpVa4Euu+K6HEnY3r6G2NxVY5J7bHEYJz5cgn61n2NFfUcTk1X5/yf/m1I1T13Qh/b2ZRQ5WeHJ1X5pktXqHKuoK+7bAwFx5WH2xGHS/7u3hFx0ZbSc/jCaHXs7zl/mfNc7Gtpxm2RxkfDvJSyviyZLOkPjk3pCy3pKZJSRR+P7ZlMTetKvA5kAjorQ19QlJijPRxnutE9Z9hXJhGvcaTIVAnnSJ+bK+j6coDz0FCzfgrmoOSaYBFpS9ce51y4/m7gOMOcKyKSgBtsFj2XYV2nq625KBrn+54Ts356siQ/e/e3pbW11dn+oi9Ax44dk3K5LF1dXerzrq4ueeaZZ6zjt23bJh/72Mesz1taM9La9uICpK80gD/sRaNYLLsXIFygkMYG9xeqaCw6uAC1JNKqvOAFyPFgYl3jtD4X/9BXirrcYExDoimp6uKwKjTCAhSb0uVMiz6/mJj/AiSeBSie0X0hjWk9h7FitcE0jAvBvjKeBcj1hcM/BAEsQKmYvtAYLECxEAtQOeQCVPIsQNiXsz04OJkItwAlHQtQGRYg7Ksxrr+7DZ4FyKxvhDmoeP5yNy7iAoTjDHOuiEgirufYvD8xzwIUlPV1u9qa+/xqB2EW/BfxySiLvgCFZevWrbJly5bZci6Xk97eXok3NEq8YWZ4MVwUoDw9Dd9mB7jA+P7oY9+NxqIzXtJvCr4Fx/VGMxfmAmf94XYcKyIyDWPB6zLbGy3oL/aVZ7ar8gP7x1T5DV0tzrHg36R8qfbY8Vj8L+s/rl6iyv/v/3leHw/3/sP/12tq9oVfNtd/+jNjw2eh9hyWy+5FAf/zxP/ucSz4hmR+ubFt33/B1h8K/EcIGjT/6y3CH7AidDYF849t4QJkH199ALAvfGvx/aEeL9Z+Q7La8tx8vE7EGouxEiRj+Efe3RfOWRIWCdcbVLIBz8W+9N83vC7fopI0Tsd7Z2FUV+b5p27RF6DOzk6Jx+MyODioPh8cHJTu7m7r+FQqJalUyvqcEELIK5tF34adTCblwgsvlO3bt89+VqlUZPv27bJ+/frF7o4QQsjLlJfEBLdlyxa5/vrrZe3atXLxxRfLbbfdJhMTE/KOd7zjpeiOEELIy5CXZAH6T//pP8nQ0JDcfPPNMjAwIOeff77cd9991sYEF6YGlIxpcb9UKapyU2PT7O/58pSz3UoA4j1oQMm4nhLUVsz6pk
CbDm39qLZmMFfbuEHC1IyKcKylN0Hbvs0WJlnYSfb0Ma1tHRorqPLQpJ7/cqB3uqB24tKp0WSdBxs1akL/z4ZVtRsT94YHtG9j29g3ygAl2NCSdii4LtcHEXtOyu7NmWrs1k5Cj23et+kAJ8LUYlBX8QnP8QbUL3RXxVJtTQk1IGy77NDFsC0REVNKycCEY1sIzhlqRhnY/WfWow6Dc+ZrC/VC3P23mKDehOgdd7W1QhF9XRWfwHqSl2wTwqZNm2TTpk0vVfOEEEJe5jAWHCGEkEjgAkQIISQSIvcDqkUinpJkfEb7eVELepGGBrS/VuunA61PoMaD2FoK2qi1lmJqQI0eQ7/t9e+OVuDqC+twnOj34/MbMusfPZxTdX/y2k5Vfs0SrXWdv0w7eKItHqelqbG2foGgbjY57fPH0GWXlmLpLh7/GQQkICmZDoEekzf25fVgR4HK1Zbj2LmOH83rZ8nlE+Pz4sd6n6+IpetMO+o814Xjtn1gHHg8OFEb8Ugl6jlsSbr/5qB2gtoXvhdge65nzfI/i/k0O7d/2gkMVeLoy2S+fkB8AyKEEBIJXIAIIYREQt2a4P7zD78/G4Ps/rXnq7oTK05TZTPeUWuiQ9Xhlu1A0MSmtxyjeS8Zx23B1TLGmUMsM5knEOp0AFutDXMGmtiwLSv2m1WubbL7+9/vgWOdwwxtTjLDf0HUH/+5HtsWmtVch1vbw6Heup3Qdqpx/iYebKviCQCKYCgeE98WYXsw7tA7iNkexmH04Q1hY5lBa2/D9vYFIZ7QvGReJ5rUkp6Yj16TW6z21mrf/VkCceas7f7wAd5e85nHZxqfKzRhW49KDOdQV7vMoK7nbr7bsPkGRAghJBK4ABFCCIkELkCEEEIioW41oOvOyc7mmXnqyq+oujXf/gtVHut/3ezvZQhBg5oO0gBrMGpEyLShKUGUEmubdTLmztGD+HIZhSFf1toX9m1ut3TpDSLuPDgi9lZOZ0IzT+gdlCcwLQuOFW3g5vFo//ZpOLgFFfUlvE5z7L6Eclj2hRyyw+cENet8IfbDYrbvC4fjC59jhaGBcC5mX5ZO49GqMMWCONKd2OFwdD2mUPDNIYbPMceK98OdlcoO6dQEz2mY9Bv4zHrSNQlG+XFpPr6wPSYNvlhUJ+EbECGEkEjgAkQIISQSuAARQgiJhLrVgNZ1rZKWtplQPP03DKu6b5z9VVW+cPhds7+f1vxaVZeON6lyvjypytPgJ5QvY5ptDIFT20aan9ZtTYMGNAn11vmgIU1NV8tow8YQQ+jnMwkhUQrTtTUKvCYMh5MGmzTax32+OiZoz/alyUbdJgFmfzQ1m2PxmaEticFKI+ELVV+dh3FPyCAMsV9xpCmf63gxbPVxTHsdoNORs2nrfMS8v6iNFKHxDNaHjG/k9WEyj4Xn0pdiQaUx9+h5VlgfT70dBihm1Ll1MJwi9EdD8PjWRLU91BJ9em5+GvVDd8p1U9vCa7a0R4d2WAu+ARFCCIkELkCEEEIigQsQIYSQSKhbDag/u0ba2ppFRKThvAFVd87qX6nymftGq4XxB1Rd5dIrnP2UQePBMtpUUTMyyRV1OvCKuFNwY4pnpOLQaSancdy6HjUfvA7TnIuSgOUz5MkIXBKfzdu0DetzU764ZGiLd0e619cSQl+wzhWRmMeHLKXGgv4w+lj0Z0L/DJ/tPoxvjzdVNepwGLJf+eZYQpmzb1MXm2ss9vG175GVEgEOtdIawD2IG35B2I+vbZ82Zfk/GXOG2hTqeRjrDfv2fSfGHPphwRuLD+YB7ifqV2nD3wldhFxZJ+Yb149vQIQQQiKBCxAhhJBI4AJECCEkEupWAzo0/qzkYhkREen8xhOqrrVTp4ge+Nv7Zn9f8aV3qjpM522m756rvjGmIzehf81IoepHhJrOxDTGX1NFy86/EH+aSY9oYPc176bnyFXj9hPCNNiVhtqdoXkbY6hh20hIWcfZt29O7HTttePpob/SpMfPx3
cdlkZkzItLfxCx9Y6WpPv/TDw/Yxxvz5FHEIR4bGH8gmztCjQdz7PhiufmS/c9VXH3hbHfrL6NG2Zfs267GW4uPgs
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Reconstructed (de-normalised) NO2 field for the last sample processed above.\n",
"fig, ax = plt.subplots()\n",
"im = ax.imshow(out[0, 0], cmap='RdYlGn_r')\n",
"fig.colorbar(im, ax=ax)  # without a colorbar the colours cannot be read as values\n",
"ax.set_title('Reconstructed NO2')\n",
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 75,
|
|||
|
"id": "c9affe45-bf88-4227-9eeb-55bd0dd8532f",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Ground-truth files exported earlier, in the same sorted order as test_imgs\n",
"test_real = sorted(x for x in os.listdir('./test_img/') if 'real' in x)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 77,
|
|||
|
"id": "ac9cb241-e3bb-44ed-aadd-b87439ae3d9b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Ground truth for the fifth exported sample, scaled back by the normalisation constant\n",
"y_real = max_pixel_value * np.load(os.path.join('./test_img/', test_real[4]))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 78,
|
|||
|
"id": "4e8425d2-e9a9-4200-940f-3aa14c36367a",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<matplotlib.image.AxesImage at 0x7f65196b6370>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 78,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGgCAYAAADsNrNZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABvRElEQVR4nO29fbBdV3nf/5z3c4/um16QZMUSFim/mrf8AjYYYaZNg6aelDRQPGmZMa1DMqVJ5GCjmRCcxu7gxIgkM4lLxoHCpA5MITSeKSQlUzL8REMHYjB2CsEhsZ3ixAYsGSPpXt173s/Zvz8k37PWZ92z1t2W3H1sfz8zmtHS3nvttdc+9y6d57ue51vKsiwzIYQQ4v8y5aIHIIQQ4vmJFiAhhBCFoAVICCFEIWgBEkIIUQhagIQQQhSCFiAhhBCFoAVICCFEIWgBEkIIUQhagIQQQhSCFiAhhBCF8IwtQHfeeadddtll1mw27aqrrrJ77733mbqVEEKIZyGlZ6IW3H/9r//V/s2/+Tf2oQ99yK666iq744477O6777YHH3zQdu/eHb12PB7bd7/7XVtYWLBSqXSxhyaEEOIZJssyO3v2rO3bt8/K5cj3nOwZ4DWveU125MiRjfZoNMr27duXHTt2LHntY489lpmZ/uiP/uiP/jzL/zz22GPR3/dVu8j0+327//777eabb974t3K5bIcPH7Z77rknOL/X61mv19toZ+e/kP39x95mi636uX/cNudfVK347YVtk7/PNb1DpWrdP7fWiD/AcOC3s7HfrjtjYd/rp/12GdPLb3TDntfMBrj3cOj0hf9F9Id+G3NUquDe8y/w260l52T0Xa15zfXRutceZf69R2O/vdr/Ps4fTRqYz86o69+6hHcLxpZ57bL5c1p25rhW9j8LhN+wdzYvSZzvz1PmPAvnpDv05+yJzgmv3cecDUYjrz0c+/M0du7VGfa9Y+VEpGANn7P+aDzlzHPEuhvi0rmqPydjBFRGaG+r+J8t9312Rv6cdIb+tWe6/s0HGEuj4g+8N3L6Rl9ttMt45u2NuDrB891bj/yuA8Y4XsOt+D7D49P7Ipwj9lXFg/D99fyPZRT32u5634698RO2sLAQveaiL0BPPvmkjUYj27Nnj/fve/bssb/5m78Jzj927Ji9973vDf59sVV3FiAsGlyA5p1fNK3UAhT/pWRD9J1nASp30OYChLc/9F9+1se9owsQFqt5PDcXoAUs4q3W9HHhuSoj/lKJL0BZvz31/AzzWcE6Wi0nFqCMvzimL0D15ALkP/diszXlzM3Pjy1ANfymblf9sXAB6uOXb2wBKg2nP/NmjAf+nFUSCxB/ubrwF1or5wLUwmfLXYBKQ/58+Nd2q/jsJBagkvO55RyMEgtQszm7C1Cee3GO2FctsQCVnuYCtHF94rN50RegvNx888129OjRjfbq6qrt37//3C/Y6vnB44fR8EvL6s7/qnguF5CB/z9u4y/PITof4g04//ssze/yj/EXOdscS4rq5PWU6v7/HDO0S1xYuTjy3u7YOM4RJxhdJZ6jVvbvPRh2p5xpVsdCudmHOH4+f1E7v3TMP8ZxpahXsKhjz4676IwwZ1ysOG4uQGWcP8z8/2C4z8VfUFysUgtS+MvT/4eu81uth1/UA/z2XOv7925U/b4G+PFp49uY+wuQfa/0Uv/Z8I+fDr4hTY6nvvEYvkl3h/HPYQ2LnfvRaPJYAi4KrRo/Z9PHwncX+7ZkFs4Df1028R+KlvNrpo3/fXDxG2eTm41i/4txuOgL0K5du6xSqdjJkye9fz958qTt3bs3OL/RaFijkQiLCSGEeM5x0bdh1+t1u+KKK+z48eMb/zYej+348eN26NChi307IYQQz1KekRDc0aNH7frrr7crr7zSXvOa19gdd9xh6+vr9va3v/2ZuJ0QQohnIc/IAvSv/tW/su9973t266232okTJ+yHf/iH7bOf/WywMSFG1h
9adj4eWVr1dxRxo4E96ew+27HkHcrKiM1DzM+6fkza2K5iitygaXUVxyjG+9pHiX1VfE2iNA8R3N3tVPePUWMINJ/UZguvMwaO/XbT/HuvYucaBXiK//3SdA2I76M38jcwxMT4zRiOJ6JDrRw/l331x/HnGoz93WfNymRe+D4qpfiPFnf7DS21MWB6TJ26Cs+lrsaNBOPSdJ1npR/Xl1KaHXWCNuTFWnm6TkMo3g/QeTeiyLegTYX6Rbwd6DTV6doLNSDqRZRHFqHn8rOx0ufncvJ3yk3cVED9iJpR0791MC/uc40y/9r2YLqutkUJ6JnbhHDDDTfYDTfc8Ex1L4QQ4lmOasEJIYQoBC1AQgghCqHwPKBpjFf6Nu6fizGWmsiZCPKCnGQD6kNtP36aMTjJBE/oNNRlsvYk2TRDRjo31QeaT2Peb1N74XE3oZOaDhNNmfQatKnzOMfRV5Boml1YXlCjOpnD1LnMh+kM16LnU2spl6b3H+oy8ecinAdXM0ppPq42tRWYmOrlNyWSPdmmnsF8mTyJp8y9Yd9nmfSKvqkxrPWzqcdSWf5B8miO/JtyQrtiPhO7jrWXElUUmO8UJCEj+3Oh5uu7sWRtanR58+z4Of1+tz3lzHgS61Y1IH0DEkIIUQhagIQQQhTCzIbghifWbHh+j2D1Ur+g3Xgd207d8NKjftHHICQX1JHD1meE0WIhu6DeGr8ZM8SW2hrN/ppOSK6Sc5t1IgQ3cMJJQ4S5GGpiSRuGyUhQhqY0GSu3WXMrNMvdcEt3qrxOsD09R1+saVfF1nZ323UKhhoZ+hhmIxzHlnCG4Gx6CK6Gd8vjDBexZhpxDwdbthMhnHCLd/R0L+zG0BRDcPxR7CKqyevdkBy3TTdiFgEWbrsmHIsbsmMIlO9nZ9Ovy8hUA5ZpYsjNbdcTtRNX+359Sv7sNlF4mGPh+3QJw5CTfxhuMQanb0BCCCEKQQuQEEKIQtACJIQQohBmVgMan+rauH4uFsrNsiXEFysvmJSuL7WwLZdbtgntFqgRkaZTubu1zT+GcjmB9UPKnoHXuzpP6jkQZ+5n/hZx2gUMnONDlJihDsO4MfWKKnSYCvUnd1wodzNEWR9eS0uEVNmfucpEN+uMfG2LfRGW2qmVoC9Fygb1E+WJuL21PfDvFWhCNn2rLjUGmoaxRE1e3Mt76It9UyNIlX6hNODKNtzW26zGdQSWguE2bHcsZfSd1xeHLNT9DhrOr41QN/P/YRHXtmr+Z37efHcAbqWmTuOSdxt2e+Dr6cNs+hZ99l3Br0r3ozFOvLuNPrd0lhBCCHGR0QIkhBCiELQACSGEKISZ1YAG7YENzpdn7z3g2zE0Fv3YfH1lEses7PPL2Yye9PfBV7b7OkC5CTfWReo6qFfu6gDUcOYW/TYtElrLfpuxXGhAbhmaVLn/DJoDS9h0kX/j5qlQr+C9mP/CmHaqDI17L57Le1F/auDeHCv7i9lux3KEzEKNiHoUx+bem9pW2Dfi+H3/fZRhCc2S/GPHtoCzPQ5KJSUsuRP6hvuxZP5Ml+X/A4sEv52So2LpIqF1uN+er/v/QPvwmjMPqVI61IRCq+u4htRw3i9tQJjHwxwvaqiEms9ifZJH1B351u3sm/D8Hs6nJYb7/pnP1B5Ot0BnTtY09A1ICCFEIWgBEkIIUQhagIQQQhTCzGpA7ZNtq9TOxU7Xn/Dj5fV5P2a6be8kjunqQWZm9UtgcQBKzTW04/HYoHacy8Lu6KXUCZh3YkPf4tvVTqi7sG5ZzCrALMz1cQn1JMZ2p1+7KUFuyGRs7DtVX41jS+lNrs5DTYf3Zl+sDZe6l9s/9SW+n7mqP5Z6mdbjvnjCnIuYnQN1lpRVNfWMmGVCWBMtri8NcH6slpiZr/OkxsWxpPoeR56D9fBS92bpOM7pwHnf87gZ9TzWequP/c9Zs+LrzvzZd+u78Rg1HlKHrhnk9pT8sbWq0Mi9vqa/gHZ8GJP7b+00IYQQ4uKiBU
gIIUQhaAESQghRCDOrAZ1+9KwNzhcbOnXaPzYe+7k9l3YmAcfanB8/3b3kxzCzph+PzbrIgYGFN4O/peW9kwY8etY
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"plt.imshow(y_real[0], cmap='RdYlGn_r')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 79,
|
|||
|
"id": "f662e7a6-2edd-4602-a33b-cda8ac2fae00",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(96, 96)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 79,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"(y_real[0] * mask[0] + out[0][0] * (1-mask[0])).shape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 80,
|
|||
|
"id": "78b86ce5-296a-498a-89d0-8ee6d9cdac27",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"d = y_real[0] * mask[0] + out[0][0] * (1-mask[0])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 81,
|
|||
|
"id": "756f6f85-6b07-4295-b8e4-a6dc77752a56",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<matplotlib.image.AxesImage at 0x7f65196220a0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 81,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGgCAYAAADsNrNZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA7FklEQVR4nO3de3hU9bkv8Ddzn2SSyY3MJJBAQCxXBbkZoBeV1q24K5Vjazfu4mXXYxuqmHOs4hZ9RCHano1sfVA3VqnueqWnqNV64YSKRZFLNAhSwi0aSJhcSDKTy2QmmVnnj6Qz631DggMJv0ny/TxPnmfeWTNrfllJ+LHWu37vm6BpmkYAAADnmUH1AAAAYHjCBAQAAEpgAgIAACUwAQEAgBKYgAAAQAlMQAAAoAQmIAAAUAITEAAAKIEJCAAAlMAEBAAASgzYBLR+/XoaM2YM2Ww2mjNnDu3atWugPgoAAAahhIGoBffaa6/Rz372M3rmmWdozpw5tG7dOtq0aROVl5dTVlZWn+8Nh8NUXV1NycnJlJCQ0N9DAwCAAaZpGjU3N1NOTg4ZDH2c52gDYPbs2VphYWEkDoVCWk5OjlZcXHzG9x4/flwjInzhC1/4wtcg/zp+/Hif/96bqJ8Fg0EqLS2lFStWRJ4zGAy0YMEC2rFjR4/XBwIBCgQCkVj7xwnZjdOILMb+Hh6cxjU/uIDFX3sDLM5MNLP4cIOfxSf2nhyYgRHRjILRLA6JE/Y0O/8V/trbHnl87LPqARsXAPQhGCL6QxklJyf3+bJ+n4Dq6+spFAqRy+Viz7tcLjp48GCP1xcXF9NDDz3Uc0cWI5Gl34cHp2FOtLDYGAyz2CQmIIO/k+9gAH9ORvHZJCYgk5iADMFQNMDvD4BSZ0qjKP8LXbFiBRUVFUVin89Hubm5Ckc0/LR0hFh85Vgniz8+0cLiqSOSWFw5MMMiIqJdf/uKxen5aSw+GuKTZVOldwBHAwD9qd8noMzMTDIajVRTU8Oer6mpIbfb3eP1VquVrFZrfw8DAADiXL/fhm2xWGjGjBlUUlISeS4cDlNJSQkVFBT098cBAMAgNSCX4IqKimjp0qU0c+ZMmj17Nq1bt45aW1vp5ptvHoiPAwCAQWhAJqCf/OQnVFdXRw888AB5PB6aNm0avffeez1uTID44DDzuw0PNfC74Cp9PLabOgZ8TL1pqGhU9tkA0L8G7CaEZcuW0bJlywZq9wAAMMihFhwAACiBCQgAAJRQvg4I1Cv18HU+V1+QzuLrvpXB4nWv7BvwMQHA0IczIAAAUAITEAAAKBG3l+Auu2xspAbZe9MvZtuMv/mTiiENWZWiaOczKOIJAOcBzoAAAEAJTEAAAKAEJiAAAFAibnNAN0xIIbujq0/NvoXPsW2hzT9jsXHdn8/buAAAoH/gDAgAAJTABAQAAEpgAgIAACXiNgc0K2ssOVJsREQ05s4Gtu21iS+y+FDD/4w8vnDF6wM/OAAAOGc4AwIAACUwAQEAgBKYgAAAQIm4zQGNdU6llJQkIiLSJlWxbZMmfM7icUdPRR6Hbihg24yv7higEQIAwLnAGRAAACiBCQgAAJTABAQAAErEbQ7oRMsRSjbYiYgoY9Neti0508piz0MfRB5nb7hV7Ak5IACAeIQzIAAAUAITEAAAKIEJCAAAlIjbHJCnrZaajV25nrbtJ9g2S3efoH8YsWh85HH51ev4juakDsTwAADgHOEMCAAAlMAEBAAASmACAgAAJeI2B+QL+ikUDBMR0axfz2HbOo40sjjc0B55PPbfLmLbjrx0gMUXXOToz2ECAMBZwhkQAAAogQkIAACUiNtLcDOyplFKSiIRESVMP8y2+d54h8UJhoTI44wV3+fbXj04QCOEbyptTGrksTWZl1EyGvn/garKTp6PIQFAHMAZEAAAKIEJCAAAlMAEBAAASsRtDui5L7eRrbvkzr3fuoJtG/Ef/8Li9g3vRR7XPfw+25b/xz
v4jh/8XT+OcnhIdvNb17WwxmKDif8/JhQMsbizvTPyOGNEEttmMfD3dk7OYnHNl7WxDRYABg2cAQEAgBKYgAAAQAlMQAAAoETc5oBGOoxkdxiJiEirFGt56nkpnn2/nBV5PDVjBtv2pXc/ixvXFbH4RMsxFk+9/42zGe6gljVxBIttdnOfr78gzc7iquYAi73+DhZbzMboY5HzCWk8n5Th4OuExn8vn8XbP6zoc2wAMHjgDAgAAJTABAQAAEpgAgIAACXiNgdU2xYmm6GrHUPCt6awbUdGfsXiBq8n+jjg4dvaW1icaOLbk8xoz5AiWpy36NbtEBGlJ/LtMufjsBhZ3Ci2ky4H1NAWZJtkLbh0O/+VtJv59sKf8N+F9a/xHB8ADB44AwIAACUwAQEAgBKYgAAAQIm4zQEd9XaSpbNrflxU8gbb9u1RiSw+2RqtPfbHwzvYth9fmMnipoCPxcd8vNbYmlsvZrHTwufo74wcHXn8l6/4mpR7freXBiOnlf8azM91sjjZ0vf/U5raee23qVm83luLrjbcrurmPvdlETkhr9j30cZ2AoChAWdAAACgBCYgAABQAhMQAAAoEbc5oH01rWRM7Fozckk2X6vzRR1fZ5KhWzvi7wyzbXIdUEsHzyEcbuL7agny2mSuxAQWbzxwNPJ47cv7eh3/YFK6/WsW5119IYtdibw+2yVZPG4UeZr/e5jn2Wa5ozmhfCdfU7StkueEZG04ucbovQ+OEAAMDTgDAgAAJTABAQCAEpiAAABAibjNAVXUNpPB1tWX5uuGNrYt0caHHdKlDewmPqeOdPDXGhJ4LbJMO88xdIR5DundilYWe1r4+4eiEE/DUHULrw3XGOAvqGvlx2SEqB2XaI7m0bzivdPcPL/n7+D5JDmW7IvdLD65l9f2A4DBA2dAAACgBCYgAABQIm4vwWlhjbRw1/WX1CR+SceYwG+N1t+q6w3wy0VecVu1vIQ2NpXfUvxlPb/c5+/gl+T0ZWV++/NpbNvdz5bRUPDWu4dYvPa2aSyu9/NjMimdl0baV89bcluN0Z/XOCf/lWvu4D+fL+rEJbgw354kbsvOn5HDX697eeVn1QQA8QtnQAAAoAQmIAAAUAITEAAAKBG3OaBrprrJ0p37keVZXIk8D/CXo02Rx/K1iSaeL/peLs9XZNr5HDzOmczig408p1SQHc1HJZl4bmqo+ryWlyuamNH3931hGv+1GpMSfX15I99XhZcf3wY/jyuaeOkki4H/vE618f2FdaWYMsals20JBv67UH/4VI+xA8D5gzMgAABQAhMQAAAoEdMEVFxcTLNmzaLk5GTKysqiRYsWUXl5OXtNe3s7FRYWUkZGBjkcDlq8eDHV1NT066ABAGDwiykHtG3bNiosLKRZs2ZRZ2cn3XffffSDH/yADhw4QElJXSX377rrLnrnnXdo06ZN5HQ6admyZXTdddfRxx9/HNPAPjjWQAZ7Vymewhkutm1HFV+r4w1E144ExDogua5nz0m+huWCdDuLRyfzQ6LP+RARVbVEP2tMCv+soUqWw7nvudhaj2/6X7Mjj7OTeP5O5oBk+4UZohVHVTNfxxUUpZNa2qP7s4v1Yy2N/m84YgA4H2KagN577z0W//73v6esrCwqLS2l73znO+T1eum5556jl19+mS6//HIiItq4cSNNnDiRPv30U7r00kv7b+QAADConVMOyOv1EhFRenrX3UalpaXU0dFBCxYsiLxmwoQJlJeXRzt27DjtPgKBAPl8PvYFAABD31lPQOFwmJYvX07z5s2jKVOmEBGRx+Mhi8VCqamp7LUul4s8ntNXLS4uLian0xn5ys3NPdshAQDAIHLW64AKCwtp//79tH379nMawIoVK6ioqCgS+3w+ys3NpUtHpZC5u6z/9hO8JYLME/h17QDCoiX33JwkFsvaY8suvpjFdX6+NiTVyt9/47o/9vq9DFUvv/n3c3r/9f+xK/L4n35wAdtW18brxjmtff
9Kylp/si5gvi6np6/bR0TUKtYBAYBaZzUBLVu2jN5++2366KOPaNSoUZHn3W43BYNBampqYmdBNTU15Ha7T7MnIqv
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"plt.imshow(out[0][0] * (1-mask[0]), cmap='RdYlGn_r')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "04034e72-02be-42ec-9aca-a5eb28141453",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3 (ipykernel)",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.16"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 5
|
|||
|
}
|