{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3a2c33ed-8f78-4ce4-b5cd-7b7ffc5c8273",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Expose only GPU 0 to PyTorch (the variable name is CUDA_VISIBLE_DEVICES)\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "# Select the CUDA device if one is available\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "85459fd1-6835-41cd-b645-553611c358e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum pixel value in the dataset: 107.49169921875\n"
     ]
    }
   ],
   "source": [
    "# Precomputed maximum NO2 value of the dataset, used below to normalize the inputs\n",
    "max_pixel_value = 107.49169921875\n",
    "\n",
    "print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
   ]
  },
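  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4-0001-4000-8000-000000000001",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (an illustrative sketch, not part of the original run): recompute\n",
    "# the dataset maximum from the training .npy files instead of hard-coding it. It assumes\n",
    "# the './out_mat/96/train/' layout used below, with the NO2 field in the first channel;\n",
    "# compute_max_pixel_value is a hypothetical helper introduced here only for illustration.\n",
    "def compute_max_pixel_value(image_dir='./out_mat/96/train/'):\n",
    "    running_max = 0.0\n",
    "    for name in os.listdir(image_dir):\n",
    "        if name.endswith('.npy'):\n",
    "            arr = np.load(os.path.join(image_dir, name)).astype(np.float32)[:, :, :1]\n",
    "            running_max = max(running_max, float(arr.max()))\n",
    "    return running_max\n",
    "\n",
    "# max_pixel_value = compute_max_pixel_value()  # uncomment to verify the constant above"
   ]
  },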
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3fc0918e-103c-40a3-93bc-6171e934a7e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "checkpoint before Generator is OK\n"
     ]
    }
   ],
   "source": [
    "class NO2Dataset(Dataset):\n",
    "\n",
    "    def __init__(self, image_dir, mask_dir):\n",
    "        self.image_dir = image_dir\n",
    "        self.mask_dir = mask_dir\n",
    "        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # only .npy files\n",
    "        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # only .jpg files\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_filenames)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
    "        mask_idx = np.random.choice(self.mask_filenames)  # pair each image with a random mask\n",
    "        mask_path = os.path.join(self.mask_dir, mask_idx)\n",
    "\n",
    "        # Load the image data (.npy file), normalized by the dataset maximum; shape (96, 96, 1)\n",
    "        image = np.load(image_path).astype(np.float32)[:, :, :1] / max_pixel_value\n",
    "\n",
    "        # Load the mask (.jpg file) as a single grayscale channel\n",
    "        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
    "\n",
    "        # Binarize the mask: non-zero values become 1, zeros stay 0\n",
    "        mask = np.where(mask != 0, 1.0, 0.0)\n",
    "\n",
    "        # Keep the mask shape as (96, 96, 1)\n",
    "        mask = mask[:, :, np.newaxis]\n",
    "\n",
    "        # Apply the mask to hide part of the NO2 field\n",
    "        masked_image = image.copy()\n",
    "        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()\n",
    "\n",
    "        # cGAN input and target\n",
    "        X = masked_image[:, :, :1]  # masked NO2 input, shape (96, 96, 1)\n",
    "        y = image[:, :, 0:1]        # full NO2 target, shape (96, 96, 1)\n",
    "\n",
    "        # Convert to (channels, height, width)\n",
    "        X = np.transpose(X, (2, 0, 1))        # (1, 96, 96)\n",
    "        y = np.transpose(y, (2, 0, 1))        # (1, 96, 96)\n",
    "        mask = np.transpose(mask, (2, 0, 1))  # (1, 96, 96)\n",
    "\n",
    "        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
    "\n",
    "# Instantiate the datasets and data loaders\n",
    "train_dir = './out_mat/96/train/'\n",
    "valid_dir = './out_mat/96/valid/'\n",
    "test_dir = './out_mat/96/test/'\n",
    "mask_dir = './out_mat/96/mask/20/'\n",
    "\n",
    "print(\"checkpoint before Generator is OK\")\n",
    "\n",
    "dataset = NO2Dataset(train_dir, mask_dir)\n",
    "train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
    "\n",
    "validset = NO2Dataset(valid_dir, mask_dir)\n",
    "val_loader = DataLoader(validset, batch_size=64, shuffle=False, num_workers=8)\n",
    "\n",
    "testset = NO2Dataset(test_dir, mask_dir)\n",
    "test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=8)"
   ]
  },
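  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4-0002-4000-8000-000000000002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check (an illustrative sketch, not part of the original run): pull one batch\n",
    "# from train_loader and confirm the shapes and value ranges that the generator and\n",
    "# discriminator below expect, i.e. (batch, 1, 96, 96) tensors with a binary 0/1 mask.\n",
    "X_sample, y_sample, mask_sample = next(iter(train_loader))\n",
    "print('X:', X_sample.shape, 'y:', y_sample.shape, 'mask:', mask_sample.shape)\n",
    "print('X range:', float(X_sample.min()), float(X_sample.max()))\n",
    "print('mask values:', mask_sample.unique())"
   ]
  },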
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a60b7019-f231-4ccb-9195-c459f3a1521d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generator is on: cpu\n",
      "Discriminator is on: cpu\n"
     ]
    }
   ],
   "source": [
    "# Generator model\n",
    "class Generator(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Generator, self).__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.BatchNorm2d(128),\n",
    "        )\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x, mask):\n",
    "        x_encoded = self.encoder(x)\n",
    "        x_decoded = self.decoder(x_encoded)\n",
    "\n",
    "        # x_decoded = (x_decoded + 1) / 2\n",
    "\n",
    "        # x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :]\n",
    "        return x_decoded\n",
    "\n",
    "# Discriminator model\n",
    "class Discriminator(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n",
    "\n",
    "# Move the models to the selected device (falls back to CPU if CUDA is unavailable)\n",
    "generator = Generator().to(device)\n",
    "discriminator = Discriminator().to(device)\n",
    "\n",
    "# Define the optimizers and the adversarial loss\n",
    "optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "adversarial_loss = nn.BCELoss().to(device)\n",
    "\n",
    "# Confirm which device the models are on\n",
    "print(f\"Generator is on: {next(generator.parameters()).device}\")\n",
    "print(f\"Discriminator is on: {next(discriminator.parameters()).device}\")"
   ]
  },
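  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4-0003-4000-8000-000000000003",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape check (an illustrative sketch, not part of the original run): a dummy forward pass\n",
    "# through both networks. Three stride-2 convolutions map a 96x96 input to a 12x12 patch map,\n",
    "# which is why the training loop below builds its valid/fake labels as (batch, 1, 12, 12).\n",
    "with torch.no_grad():\n",
    "    dummy_x = torch.randn(2, 1, 96, 96, device=device)\n",
    "    dummy_mask = torch.ones(2, 1, 96, 96, device=device)\n",
    "    fake_no2 = generator(dummy_x, dummy_mask)\n",
    "    patch_scores = discriminator(torch.cat((fake_no2, dummy_x), dim=1))\n",
    "print('generator output:', tuple(fake_no2.shape))\n",
    "print('discriminator output:', tuple(patch_scores.shape))"
   ]
  },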
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "fd2f0816-7301-4381-99e4-05905ce8b093",
   "metadata": {},
   "outputs": [],
   "source": [
    "def masked_mse_loss(preds, target, mask):\n",
    "    # Squared error per pixel; mask == 0 marks the hidden (to-be-inpainted) pixels\n",
    "    loss = (preds - target) ** 2\n",
    "    # Average the error only over the masked-out pixels\n",
    "    loss = (loss * (1 - mask)).sum() / (1 - mask).sum()\n",
    "    return loss"
   ]
  },
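  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4-0004-4000-8000-000000000004",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tiny worked example (an illustrative sketch, not part of the original run): with a constant\n",
    "# error of 2 on every pixel, the masked MSE over the hidden (mask == 0) region should be\n",
    "# exactly 2 ** 2 = 4, no matter how many pixels are hidden.\n",
    "toy_pred = torch.zeros(1, 1, 4, 4)\n",
    "toy_target = torch.full((1, 1, 4, 4), 2.0)\n",
    "toy_mask = torch.ones(1, 1, 4, 4)\n",
    "toy_mask[:, :, :2, :] = 0  # hide the top half\n",
    "print(masked_mse_loss(toy_pred, toy_target, toy_mask))  # expected: tensor(4.)"
   ]
  },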
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "edf26d05-7fd6-404b-8f9d-3d2054593f7b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Resume from a previously trained generator checkpoint\n",
    "generator.load_state_dict(torch.load('./models/GAN/generator-1d.pth'))"
   ]
  },
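  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4-0005-4000-8000-000000000005",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Portability note (sketch): if this checkpoint is loaded on a machine without a GPU,\n",
    "# torch.load needs a map_location so CUDA-saved tensors land on the available device.\n",
    "# state = torch.load('./models/GAN/generator-1d.pth', map_location=device)\n",
    "# generator.load_state_dict(state)"
   ]
  },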
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "645dd325-fc70-4234-8279-bc8cbc4c5dde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop\n",
    "epochs = 100\n",
    "for epoch in range(epochs):\n",
    "    for i, (X, y, mask) in enumerate(train_loader):\n",
    "        # Move the batch to the selected device\n",
    "        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
    "        # print(f\"X is on: {X.device}, y is on: {y.device}, mask is on: {mask.device}, i = {i}\")  # checkpoint\n",
    "\n",
    "        # Patch-level real/fake labels matching the discriminator's (batch, 1, 12, 12) output\n",
    "        valid = torch.ones((X.size(0), 1, 12, 12)).to(device)\n",
    "        fake = torch.zeros((X.size(0), 1, 12, 12)).to(device)\n",
    "\n",
    "        # Generator step: fool the discriminator and reconstruct the masked region\n",
    "        optimizer_G.zero_grad()\n",
    "        generated_images = generator(X, mask)\n",
    "        g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * masked_mse_loss(\n",
    "            generated_images, y, mask)\n",
    "        g_loss.backward()\n",
    "        optimizer_G.step()\n",
    "\n",
    "        # Discriminator step\n",
    "        optimizer_D.zero_grad()\n",
    "        real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)\n",
    "        fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)\n",
    "        d_loss = 0.5 * (real_loss + fake_loss)\n",
    "        d_loss.backward()\n",
    "        optimizer_D.step()\n",
    "\n",
    "    print(f\"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]\")\n",
    "\n",
    "# Save the trained models\n",
    "torch.save(generator.state_dict(), './models/GAN/generator-1d.pth')\n",
    "torch.save(discriminator.state_dict(), './models/GAN/discriminator-1d.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9ab6849f-740a-49e4-9afd-23aafdc16725",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8997686a-5812-4a92-972c-11376dfc1686",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cal_ioa(y_true, y_pred):\n",
    "    # Means of the observed and predicted values\n",
    "    mean_observed = np.mean(y_true)\n",
    "    mean_predicted = np.mean(y_pred)\n",
    "\n",
    "    # Index of Agreement (IoA)\n",
    "    numerator = np.sum((y_true - y_pred) ** 2)\n",
    "    denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
    "    IoA = 1 - (numerator / denominator)\n",
    "\n",
    "    return IoA"
   ]
  },
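  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4-0006-4000-8000-000000000006",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check of cal_ioa (an illustrative sketch, not part of the original run): IoA is 1\n",
    "# for a perfect prediction and drops toward 0 as the predictions drift away from the data.\n",
    "obs = np.array([1.0, 2.0, 3.0, 4.0])\n",
    "print(cal_ioa(obs, obs))               # perfect agreement -> 1.0\n",
    "print(cal_ioa(obs, obs + 1.0))         # constant bias -> lower IoA\n",
    "print(cal_ioa(obs, obs[::-1].copy()))  # reversed order -> much lower IoA"
   ]
  },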
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "78db8f70-3cdb-444f-8bb5-49126957a0b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the test set (on CPU), converting predictions back to the original NO2 scale.\n",
    "# MAE/RMSE/MAPE/R2 are computed over the full image; IoA uses only the inpainted region.\n",
    "eva_list = list()\n",
    "device = 'cpu'\n",
    "generator = generator.to(device)\n",
    "with torch.no_grad():\n",
    "    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
    "        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
    "        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1  # invert the mask to select the inpainted region\n",
    "        reconstructed = generator(X, mask)\n",
    "        rev_data = torch.squeeze(y * max_pixel_value, dim=1)\n",
    "        rev_recon = torch.squeeze(reconstructed * max_pixel_value, dim=1)\n",
    "        # TODO: restrict all of the metrics below to the inpainted region\n",
    "        data_label = rev_data * mask_rev\n",
    "        data_label = data_label[mask_rev == 1]\n",
    "        recon_no2 = rev_recon * mask_rev\n",
    "        recon_no2 = recon_no2[mask_rev == 1]\n",
    "        y_true = rev_data.flatten()\n",
    "        y_pred = rev_recon.flatten()\n",
    "        mae = mean_absolute_error(y_true, y_pred)\n",
    "        rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
    "        mape = mean_absolute_percentage_error(y_true, y_pred)\n",
    "        r2 = r2_score(y_true, y_pred)\n",
    "        ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
    "        eva_list.append([mae, rmse, mape, r2, ioa])"
   ]
  },
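  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4-0007-4000-8000-000000000007",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One way to address the TODO above (an illustrative sketch, not the evaluation that produced\n",
    "# the tables below): score every metric only on the inpainted pixels, and switch the\n",
    "# generator to eval() so BatchNorm uses its running statistics during inference.\n",
    "# eva_list_masked is a hypothetical name introduced here for illustration.\n",
    "eva_list_masked = []\n",
    "generator.eval()\n",
    "with torch.no_grad():\n",
    "    for X, y, mask in test_loader:\n",
    "        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
    "        hole = torch.squeeze(mask, dim=1) == 0  # boolean mask of the inpainted region\n",
    "        recon = generator(X, mask)\n",
    "        y_true = torch.squeeze(y * max_pixel_value, dim=1)[hole].numpy()\n",
    "        y_pred = torch.squeeze(recon * max_pixel_value, dim=1)[hole].numpy()\n",
    "        eva_list_masked.append([\n",
    "            mean_absolute_error(y_true, y_pred),\n",
    "            np.sqrt(mean_squared_error(y_true, y_pred)),\n",
    "            mean_absolute_percentage_error(y_true, y_pred),\n",
    "            r2_score(y_true, y_pred),\n",
    "            cal_ioa(y_true, y_pred),\n",
    "        ])\n",
    "generator.train()"
   ]
  },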
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ffc83338-9b0f-4934-9cba-406a5e1fb0ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "33eb7a86-242a-4488-9d9b-93548e72c98b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mae</th>\n",
       "      <th>rmse</th>\n",
       "      <th>mape</th>\n",
       "      <th>r2</th>\n",
       "      <th>ioa</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>1.685609</td>\n",
       "      <td>2.824579</td>\n",
       "      <td>0.223852</td>\n",
       "      <td>0.807483</td>\n",
       "      <td>0.894409</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.520285</td>\n",
       "      <td>0.613299</td>\n",
       "      <td>0.066827</td>\n",
       "      <td>0.107566</td>\n",
       "      <td>0.024969</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>1.108756</td>\n",
       "      <td>2.040964</td>\n",
       "      <td>0.143461</td>\n",
       "      <td>0.336193</td>\n",
       "      <td>0.812887</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>1.338143</td>\n",
       "      <td>2.462648</td>\n",
       "      <td>0.176170</td>\n",
       "      <td>0.780906</td>\n",
       "      <td>0.883027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>1.509821</td>\n",
       "      <td>2.608227</td>\n",
       "      <td>0.206274</td>\n",
       "      <td>0.850417</td>\n",
       "      <td>0.900165</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1.963103</td>\n",
       "      <td>3.067560</td>\n",
       "      <td>0.257667</td>\n",
       "      <td>0.866705</td>\n",
       "      <td>0.910917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>3.729434</td>\n",
       "      <td>5.363288</td>\n",
       "      <td>0.461465</td>\n",
       "      <td>0.912240</td>\n",
       "      <td>0.935183</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             mae       rmse       mape         r2        ioa\n",
       "count  75.000000  75.000000  75.000000  75.000000  75.000000\n",
       "mean    1.685609   2.824579   0.223852   0.807483   0.894409\n",
       "std     0.520285   0.613299   0.066827   0.107566   0.024969\n",
       "min     1.108756   2.040964   0.143461   0.336193   0.812887\n",
       "25%     1.338143   2.462648   0.176170   0.780906   0.883027\n",
       "50%     1.509821   2.608227   0.206274   0.850417   0.900165\n",
       "75%     1.963103   3.067560   0.257667   0.866705   0.910917\n",
       "max     3.729434   5.363288   0.461465   0.912240   0.935183"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "3844152b-c853-4311-a0a4-f57275ea78fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: eva_list_frame (which also carries an 'r' column) is not defined in this notebook;\n",
    "# it appears to come from a separate per-sample evaluation run. Sort by MAPE, best first.\n",
    "rst = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape', ascending=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "id": "8f98d947-5a20-49f5-84a1-5ae5a0f4693e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mae</th>\n",
       "      <th>rmse</th>\n",
       "      <th>mape</th>\n",
       "      <th>r2</th>\n",
       "      <th>ioa</th>\n",
       "      <th>r</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3523</th>\n",
       "      <td>1.834664</td>\n",
       "      <td>2.579296</td>\n",
       "      <td>0.103911</td>\n",
       "      <td>0.855056</td>\n",
       "      <td>0.960841</td>\n",
       "      <td>0.931334</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3544</th>\n",
       "      <td>1.500194</td>\n",
       "      <td>1.962885</td>\n",
       "      <td>0.106816</td>\n",
       "      <td>0.849688</td>\n",
       "      <td>0.960970</td>\n",
       "      <td>0.924187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1952</th>\n",
       "      <td>1.786639</td>\n",
       "      <td>2.290560</td>\n",
       "      <td>0.109383</td>\n",
       "      <td>0.704122</td>\n",
       "      <td>0.928829</td>\n",
       "      <td>0.869446</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>602</th>\n",
       "      <td>2.222957</td>\n",
       "      <td>2.934734</td>\n",
       "      <td>0.112751</td>\n",
       "      <td>0.735178</td>\n",
       "      <td>0.933639</td>\n",
       "      <td>0.877028</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3531</th>\n",
       "      <td>2.093165</td>\n",
       "      <td>2.726698</td>\n",
       "      <td>0.115755</td>\n",
       "      <td>0.760530</td>\n",
       "      <td>0.937662</td>\n",
       "      <td>0.889606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1114</th>\n",
       "      <td>1.951748</td>\n",
       "      <td>2.591448</td>\n",
       "      <td>0.116578</td>\n",
       "      <td>0.696970</td>\n",
       "      <td>0.914501</td>\n",
       "      <td>0.843026</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1979</th>\n",
       "      <td>2.083001</td>\n",
       "      <td>2.686231</td>\n",
       "      <td>0.116762</td>\n",
       "      <td>0.597512</td>\n",
       "      <td>0.886877</td>\n",
       "      <td>0.791842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2568</th>\n",
       "      <td>2.630587</td>\n",
       "      <td>3.636890</td>\n",
       "      <td>0.117044</td>\n",
       "      <td>0.491952</td>\n",
       "      <td>0.893928</td>\n",
       "      <td>0.833221</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           mae      rmse      mape        r2       ioa         r\n",
       "3523  1.834664  2.579296  0.103911  0.855056  0.960841  0.931334\n",
       "3544  1.500194  1.962885  0.106816  0.849688  0.960970  0.924187\n",
       "1952  1.786639  2.290560  0.109383  0.704122  0.928829  0.869446\n",
       "602   2.222957  2.934734  0.112751  0.735178  0.933639  0.877028\n",
       "3531  2.093165  2.726698  0.115755  0.760530  0.937662  0.889606\n",
       "1114  1.951748  2.591448  0.116578  0.696970  0.914501  0.843026\n",
       "1979  2.083001  2.686231  0.116762  0.597512  0.886877  0.791842\n",
       "2568  2.630587  3.636890  0.117044  0.491952  0.893928  0.833221"
      ]
     },
     "execution_count": 140,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rst.head(8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "a72b4ec5-b72d-43d0-8178-954f8edcb0dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'1114', '1952', '2568', '3523', '602'}"
      ]
     },
     "execution_count": 141,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect the sample ids that have exported .npy files in ./test_img/\n",
    "find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])\n",
    "find_ex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "819950e1-d16d-42e9-a0f9-57878ebc8f89",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for j in find_ex:\n",
    "    ori = np.load(f'./test_img/{j}-real.npy')[0]\n",
    "    pred = np.load(f'./test_img/{j}-gan-recom.npy')[0]\n",
    "    mask = np.load(f'./test_img/{j}-mask.npy')\n",
    "\n",
    "    # Ground truth\n",
    "    plt.imshow(ori, cmap='RdYlGn_r')\n",
    "    plt.gca().axis('off')\n",
    "    plt.savefig(f'./test_img/out_fig/{j}-truth.png', bbox_inches='tight')\n",
    "    plt.clf()\n",
    "\n",
    "    # Mask\n",
    "    plt.imshow(mask, cmap='gray')\n",
    "    plt.gca().axis('off')\n",
    "    plt.savefig(f'./test_img/out_fig/{j}-mask.png', bbox_inches='tight')\n",
    "    plt.clf()\n",
    "\n",
    "    # Original with the mask applied (NaN pixels render blank)\n",
    "    mask_cp = np.where((1 - mask) == 0, np.nan, (1 - mask))\n",
    "    plt.imshow(ori * mask_cp, cmap='RdYlGn_r')\n",
    "    plt.gca().axis('off')\n",
    "    plt.savefig(f'./test_img/out_fig/{j}-masked_ori.png', bbox_inches='tight')\n",
    "    plt.clf()\n",
    "\n",
    "    # Composite: original where mask == 1, GAN reconstruction where mask == 0\n",
    "    out = ori * mask + pred * (1 - mask)\n",
    "    plt.imshow(out, cmap='RdYlGn_r')\n",
    "    plt.gca().axis('off')\n",
    "    plt.savefig(f'./test_img/out_fig/{j}-gan_out.png', bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2926f4b3-0a15-48ae-bdaa-4e290e11b49d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}