MAE_ATMO/torch_MAE_1d_final_30.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 25,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"from PIL import Image\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c28cc123-71be-47ff-b78f-3a4d5592df39",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum pixel value in the dataset: 107.49169921875\n"
]
}
],
"source": [
"max_pixel_value = 107.49169921875\n",
"\n",
"print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dbfe80ce-4394-449c-a9a4-22ed15b2b8f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint before Generator is OK\n"
]
}
],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" print(len(self.mask_filenames))\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = np.random.choice(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# 实例化数据集和数据加载器\n",
"image_dir = './out_mat/96/train/'\n",
"mask_dir = './out_mat/96/mask/30/'\n",
"\n",
"print(f\"checkpoint before Generator is OK\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "41da7319-9795-441d-bde8-8cf390365099",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3849\n",
"3849\n",
"3849\n"
]
}
],
"source": [
"dataset = NO2Dataset(image_dir, mask_dir)\n",
"dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
"val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n",
"val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
"test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
]
},
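{
"cell_type": "code",
"execution_count": null,
"id": "3f2a9c1e-5b7d-4e8f-9a0b-c1d2e3f4a5b6",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (not part of the original run; assumes the\n",
"# directories above exist): pull one batch and confirm the (B, C, H, W) shapes.\n",
"Xb, yb, mb = next(iter(dataloader))\n",
"print(Xb.shape, yb.shape, mb.shape) # expected: torch.Size([64, 1, 96, 96]) each\n",
"print(Xb.min().item(), Xb.max().item()) # inputs should lie in [0, 1] after normalization"
]
},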
{
"cell_type": "code",
"execution_count": 5,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Input\")\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Masked\")\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Recovery\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "645114e8-65a4-4867-b3fe-23395288e855",
"metadata": {},
"outputs": [],
"source": [
"class Conv(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
" super(Conv, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
" )"
]
},
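{
"cell_type": "code",
"execution_count": null,
"id": "8c4d5e6f-1a2b-4c3d-8e9f-0a1b2c3d4e5f",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the padding formula ((stride - 1) + dilation * (kernel_size - 1)) // 2:\n",
"# it keeps the spatial size at stride 1 and halves it at stride 2.\n",
"print(Conv(1, 8, kernel_size=3, stride=1)(torch.zeros(1, 1, 96, 96)).shape) # (1, 8, 96, 96)\n",
"print(Conv(1, 8, kernel_size=3, stride=2)(torch.zeros(1, 1, 96, 96)).shape) # (1, 8, 48, 48)"
]
},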
{
"cell_type": "code",
"execution_count": 7,
"id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
"metadata": {},
"outputs": [],
"source": [
"class ConvBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
" bias=False):\n",
" super(ConvBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
" norm_layer(out_channels),\n",
" nn.ReLU()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "31ecf247-e98b-4977-a145-782914a042bd",
"metadata": {},
"outputs": [],
"source": [
"class SeparableBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
" super(SeparableBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
" padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
" # 分离卷积,仅调整空间信息\n",
" norm_layer(in_channels), # 对输入通道进行归一化\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), # 这里进行升维操作\n",
" nn.ReLU6()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
"metadata": {},
"outputs": [],
"source": [
"class ResidualBlock(nn.Module):\n",
" def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
" super(ResidualBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(out_channels)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(out_channels)\n",
"\n",
" # 如果输入和输出通道不一致,进行降采样操作\n",
" self.downsample = downsample\n",
" if in_channels != out_channels or stride != 1:\n",
" self.downsample = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
" nn.BatchNorm2d(out_channels)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" identity = x\n",
" if self.downsample is not None:\n",
" identity = self.downsample(x)\n",
"\n",
" out = self.conv1(x)\n",
" out = self.bn1(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv2(out)\n",
" out = self.bn2(out)\n",
"\n",
" out += identity\n",
" out = self.relu(out)\n",
" return out\n"
]
},
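{
"cell_type": "code",
"execution_count": null,
"id": "7a8b9c0d-2e3f-4a5b-8c6d-e7f8a9b0c1d2",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: when channels or stride change, the 1x1 projection built in\n",
"# __init__ keeps the residual addition shape-compatible.\n",
"rb = ResidualBlock(64, 128, stride=2)\n",
"print(rb(torch.randn(2, 64, 24, 24)).shape) # torch.Size([2, 128, 12, 12])"
]
},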
{
"cell_type": "code",
"execution_count": 10,
"id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
"metadata": {},
"outputs": [],
"source": [
"class Mlp(nn.Module):\n",
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
" super().__init__()\n",
" out_features = out_features or in_features\n",
" hidden_features = hidden_features or in_features\n",
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
"\n",
" self.act = act_layer()\n",
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
" self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = self.act(x)\n",
" x = self.drop(x)\n",
" x = self.fc2(x)\n",
" x = self.drop(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
"metadata": {},
"outputs": [],
"source": [
"class MultiHeadAttentionBlock(nn.Module):\n",
" def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
" super(MultiHeadAttentionBlock, self).__init__()\n",
" self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
" self.norm = nn.LayerNorm(embed_dim)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x):\n",
" # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
" B, C, H, W = x.shape\n",
" x = x.view(B, C, H * W).permute(2, 0, 1) # (B, C, H, W) -> (HW, B, C)\n",
"\n",
" # Apply multihead attention\n",
" attn_output, _ = self.attention(x, x, x)\n",
"\n",
" # Apply normalization and dropout\n",
" attn_output = self.norm(attn_output)\n",
" attn_output = self.dropout(attn_output)\n",
"\n",
" # Reshape back to (B, C, H, W)\n",
" attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
"\n",
" return attn_output"
]
},
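{
"cell_type": "code",
"execution_count": null,
"id": "5e6f7a8b-3c4d-4e5f-9a0b-1c2d3e4f5a6b",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: the block flattens the H*W positions into a sequence for attention,\n",
"# then restores the feature-map layout, so (B, C, H, W) is preserved.\n",
"attn_demo = MultiHeadAttentionBlock(embed_dim=128, num_heads=4)\n",
"print(attn_demo(torch.randn(2, 128, 12, 12)).shape) # torch.Size([2, 128, 12, 12])"
]
},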
{
"cell_type": "code",
"execution_count": 12,
"id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
"metadata": {},
"outputs": [],
"source": [
"class SpatialAttentionBlock(nn.Module):\n",
" def __init__(self):\n",
" super(SpatialAttentionBlock, self).__init__()\n",
" self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
"\n",
" def forward(self, x): #(B, 64, H, W)\n",
" avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)\n",
" max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)\n",
" out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)\n",
" out = torch.sigmoid(self.conv(out))#(B, 1, H, W)\n",
" return x * out #(B, C, H, W)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
"metadata": {},
"outputs": [],
"source": [
"class DecoderAttentionBlock(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super(DecoderAttentionBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
" self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
" self.spatial_attention = SpatialAttentionBlock()\n",
"\n",
" def forward(self, x):\n",
" # 通道注意力\n",
" b, c, h, w = x.size()\n",
" avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
" max_pool = F.adaptive_max_pool2d(x, 1)\n",
"\n",
" avg_out = self.conv1(avg_pool)\n",
" max_out = self.conv1(max_pool)\n",
"\n",
" out = avg_out + max_out\n",
" out = torch.sigmoid(self.conv2(out))\n",
"\n",
" # 添加空间注意力\n",
" out = x * out\n",
" out = self.spatial_attention(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
" def __init__(self, in_channels, reduced_dim):\n",
" super(SEBlock, self).__init__()\n",
" self.se = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return x * self.se(x)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 定义Masked Autoencoder模型\n",
"class MaskedAutoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(MaskedAutoencoder, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" Conv(1, 32, kernel_size=3, stride=2),\n",
" \n",
" nn.ReLU(),\n",
" \n",
" SEBlock(32,32),\n",
" \n",
" ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
" \n",
" ResidualBlock(64,64),\n",
" \n",
" SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
" \n",
" MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
" \n",
" SEBlock(128, 128)\n",
" )\n",
" self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" DecoderAttentionBlock(32),\n",
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" DecoderAttentionBlock(16),\n",
" nn.ReLU(),\n",
" \n",
" nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded\n",
"\n",
"# 实例化模型、损失函数和优化器\n",
"model = MaskedAutoencoder()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
]
},
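{
"cell_type": "code",
"execution_count": null,
"id": "9b0c1d2e-4f5a-4b6c-8d7e-2f3a4b5c6d7e",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative smoke test: three stride-2 encoder stages take 96 -> 12 and three\n",
"# transposed convolutions take 12 -> 96, so output shape matches input shape.\n",
"with torch.no_grad(): demo_out = model(torch.randn(2, 1, 96, 96))\n",
"print(demo_out.shape) # torch.Size([2, 1, 96, 96]); Sigmoid keeps values in (0, 1)"
]
},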
{
"cell_type": "code",
"execution_count": 16,
"id": "e9c804e0-6f5c-40a7-aba7-a03a496cf427",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
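{
"cell_type": "code",
"execution_count": null,
"id": "1d2e3f4a-5b6c-4d7e-8f9a-3b4c5d6e7f8a",
"metadata": {},
"outputs": [],
"source": [
"# Tiny worked example: entries where mask == 0 contribute nothing, so only the\n",
"# first column's (zero) errors are averaged here.\n",
"p = torch.tensor([[1.0, 2.0], [3.0, 4.0]])\n",
"t = torch.tensor([[1.0, 0.0], [3.0, 0.0]])\n",
"m = torch.tensor([[1.0, 0.0], [1.0, 0.0]])\n",
"print(masked_mse_loss(p, t, m)) # tensor(0.)"
]
},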
{
"cell_type": "code",
"execution_count": 17,
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
"metadata": {},
"outputs": [],
"source": [
"# 训练函数\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
" model.train()\n",
" running_loss = 0.0\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" optimizer.zero_grad()\n",
" reconstructed = model(X)\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" # loss = criterion(reconstructed, y)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
"metadata": {},
"outputs": [],
"source": [
"# 评估函数\n",
"def evaluate(model, device, data_loader, criterion):\n",
" model.eval()\n",
" running_loss = 0.0\n",
" with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" reconstructed = model(X)\n",
" if batch_idx == 8:\n",
" rand_ind = np.random.randint(0, len(y))\n",
" # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "296ba6bd-2239-4948-b278-7edcb29bfd14",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# 数据准备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "743d1000-561e-4444-8b49-88346c14f28b",
"metadata": {},
"outputs": [],
"source": [
"model = model.to(device)\n",
"\n",
"num_epochs = 150\n",
"train_losses = list()\n",
"val_losses = list()\n",
"for epoch in range(num_epochs):\n",
" train_loss = train_epoch(model, device, dataloader, criterion, optimizer)\n",
" train_losses.append(train_loss)\n",
" val_loss = evaluate(model, device, val_loader, criterion)\n",
" val_losses.append(val_loss)\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# 测试模型\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f43384faca0>"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5kUlEQVR4nO3deXhU9d3//9c5M8kkgSwsJgEMSysgCAZkM+H+KWoUASloSy3lLmAr/dqCBam20laqeGlsvVGsG+X2Vm5bEZcCeuNWBFHBiIDEilasioBKAmpJSIAkM+fz+2OWTJRgJsscYJ6P6zoXZOYs7/kQJq98zjnvsYwxRgAAAC6x3S4AAAAkNsIIAABwFWEEAAC4ijACAABcRRgBAACuIowAAABXEUYAAICrCCMAAMBVXrcLaArHcfTZZ58pPT1dlmW5XQ4AAGgCY4wOHjyorl27yrYbn/84IcLIZ599pry8PLfLAAAAzbBnzx6deuqpjT5/QoSR9PR0ScEXk5GR4XI1AACgKSorK5WXlxf5Od6YEyKMhE/NZGRkEEYAADjBfNMlFlzACgAAXEUYAQAAriKMAAAAV50Q14wAAE4+gUBAdXV1bpeBFvB4PPJ6vS1uu0EYAQDEXVVVlT755BMZY9wuBS2UlpamLl26KDk5udn7IIwAAOIqEAjok08+UVpamk455RSaWZ6gjDGqra3V/v37tXPnTvXu3fuYjc2OhTACAIiruro6GWN0yimnKDU11e1y0AKpqalKSkrSrl27VFtbq5SUlGbthwtYAQCuYEbk5NDc2ZAG+2iFOgAAAJqNMAIAAFxFGAEAIM569uypRYsWtcq+1q9fL8uydODAgVbZnxu4gBUAgCYYNWqUBg0a1CohYvPmzWrXrl3LizpJJHQYeeDVj/TJvw9r8vDu6pt77E8UBADgWIwxCgQC8nq/+UfrKaecEoeKThwJfZrmmbf3aulrH2vXF9VulwIACcsYo0O1fleWpjZdmz59ul5++WXdddddsixLlmVp6dKlsixLzz33nIYMGSKfz6cNGzboww8/1IQJE5STk6P27dtr2LBhevHFFxvs76unaSzL0gMPPKBLL71UaWlp6t27t55++ulmj+nf/vY3nXHGGfL5fOrZs6cWLlzY4Pn77rtPvXv3VkpKinJycvS9730v8tyTTz6pgQMHKjU1VZ06dVJRUZGqq9v252RCz4x47eBtZQ4dAAHANYfrAuo//wVXjv3ugtFKS/7mH4V33XWX3n//fQ0YMEALFiyQJL3zzjuSpOuvv17/9V//pW9961vq0KGD9uzZo7Fjx+qWW26Rz+fTww8/rPHjx2vHjh3q3r17o8e46aab9Mc//lG333677r77bk2ZMkW7du1Sx44dY3pNW7du1fe//33deOONuvzyy/Xaa6/p5z//uTp16qTp06dry5Yt+sUvfqG//OUvKiws1JdffqlXX31VkrR3715NnjxZf/zjH3XppZfq4MGDevXVV9u8U25ChxE7dI+73yGMAAAal5mZqeTkZKWlpSk3N1eS9N5770mSFixYoAsvvDCybseOHZWfnx/5+uabb9bKlSv19NNPa9asWY0eY/r06Zo8ebIk6dZbb9Wf/vQnvfHGG7r44otjqvWOO+7QBRdcoBtuuEGS1KdPH7377ru6/fbbNX36dO3evVvt2rXTJZdcovT0dPXo0UODBw+WFAwjfr9fl112mXr06CFJGjhwYEzHb46EDiNeTzCMBAgjAOCa1CSP3l0w2rVjt9TQoUMbfF1VVaUbb7xRzzzzTOSH++HDh7V79+5j7ufMM8+M/L1du3bKyMjQvn37Yq7nn//8pyZMmNDgsZEjR2rRokUKBAK68MIL1aNHD33rW9/SxRdfrIsvvjhyeig/P18XXHCBBg4cqNGjR+uiiy7S9773PXXo0CHmOmKR0NeMhGdGCCMA4B7LspSW7HVlaY0usF+9K+baa6/VypUrdeutt+rVV19VaWmpBg4cqNra2mPuJykp6Wvj4jhOi+v7qvT0dL355pt69NFH1aVLF82fP1/5+fk6cOCAPB6P1qxZo+eee079+/fX3Xffrb59+2rnzp2tXke0hA4j4WtGCCMAgG+SnJysQCDwjett3LhR06dP16WXXqqBAwcqNzdXH3/8cdsXGNKvXz9t3LjxazX16dNHHk9wJsjr9aqoqEh//OMf9Y9//EMff/yx1q1bJykYgkaOHKmbbrpJ27ZtU3JyslauXNmmNSf0aRoPYQQA0EQ9e/bUpk2b9PHHH6t9+/aNzlr07t1bK1as0Pjx42VZlm644YY2meFozC9/+UsNGzZMN998sy6//HKVlJTonnvu0X333SdJWr16tT766COdc8456tChg5599lk5jqO+fftq06ZNWrt2rS666CJlZ2dr06ZN2r9/v/r169emNSf0zEgkjHA3DQDgG1x77bXyeDzq37+/TjnllEavAbnjjjvUoUMHFRYWavz48Ro9erTOOuusuNV51lln6fHHH9fy5cs1YMAAzZ8/XwsWLND06dMlSVlZWVqxYoXOP/989evXT4sXL9ajjz6qM844QxkZGXrllVc0duxY9enTR7/73e+0cOFCjRkzpk1rtkxb36/TCiorK5WZmamKigplZGS02n5//shWPft2mRZMOENTC3q22n4BAI07cuSIdu7cqV69ejX7I+dx/DjWv2dTf34n+MxI8OX7A8d9HgMA4KSV0GGEpmcAgOPdVVddpfbt2x91ueqqq9wur1Uk9AWsND0DABzvFixYoGuvvfaoz7XmpQtuSugwwq29AIDjXXZ2trKzs90uo00l9GkamzACAIDrEjqMMDMCAID7EjqM0PQMAAD3xRRG7r//fp155pnKyMhQRkaGCgoK9Nxzzx1zmyeeeEKnn366UlJSNHDgQD377LMtKrg10fQMAAD3xRRGTj31VN12223aunWrtmzZovPPP18TJkzQO++8c9T1X3vtNU2ePFk/+clPtG3bNk2cOFETJ07U9u3bW6X4luI0DQAA7ospjIwfP15jx45V79691adPH91yyy1q3769Xn/99aOuf9ddd+niiy/Wddddp379+unmm2/WWWedpXvuuadVim+p8AWsND0DAMRDz549tWjRoiata1mWVq1a1ab1HC+afc1IIBDQ8uXLVV1drYKCgqOuU1JSoqKiogaPjR49WiUlJc09bKui6RkAAO6Luc/I22+/rYKCAh05ckTt27fXypUr1b9//6OuW1ZWppycnAaP5eTkqKys7JjHqKmpUU1NTeTrysrKWMtskvqmZ/H7NEUAANBQzDMjffv2VWlpqTZt2qSf/exnmjZtmt59991WLaq4uFiZmZmRJS8vr1X3H1Z/zUib7B4A0BTGSLXV7iwxzIwvWbJEXbt2lfOVX2AnTJigH//4x/rwww81YcIE5eTkqH379ho2bJhefPHFVhumt99+W+eff75SU1PVqVMn/fSnP1VVVVXk+fXr12v48OFq166dsrKyNHLkSO3atUuS9NZbb+m8885Tenq6MjIyNGTIE
G3ZsqXVamupmGdGkpOTddppp0mShgwZos2bN+uuu+7Sn//856+tm5ubq/Ly8gaPlZeXKzc395jHmDdvnubOnRv5urKysk0CSX3TM9IIALim7pB0a1d3jv2bz6Tkdk1addKkSbr66qv10ksv6YILLpAkffnll3r++ef17LPPqqqqSmPHjtUtt9win8+nhx9+WOPHj9eOHTvUvXv3FpVZXV2t0aNHq6CgQJs3b9a+fft05ZVXatasWVq6dKn8fr8mTpyoGTNm6NFHH1Vtba3eeOMNWaEzAFOmTNHgwYN1//33y+PxqLS0VElJSS2qqTW1uB284zgNTqlEKygo0Nq1azVnzpzIY2vWrGn0GpMwn88nn8/X0tK+ETMjAICm6tChg8aMGaNly5ZFwsiTTz6pzp0767zzzpNt28rPz4+sf/PNN2vlypV6+umnNWvWrBYde9myZTpy5IgefvhhtWsXDE/33HOPxo8frz/84Q9KSkpSRUWFLrnkEn3729+WJPXr1y+y/e7du3Xdddfp9NNPlyT
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tr_ind = list(range(len(train_losses)))\n",
"val_ind = list(range(len(val_losses)))\n",
"plt.plot(train_losses, label='train_loss')\n",
"plt.plot(val_losses, label='val_loss')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "290edd23-b3ce-474d-b654-2e1096be9866",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model, './models/MAE/final_30.pt')"
]
},
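{
"cell_type": "code",
"execution_count": null,
"id": "4f5a6b7c-6d7e-4f8a-9b0c-5d6e7f8a9b0c",
"metadata": {},
"outputs": [],
"source": [
"# Optional alternative (sketch; the path name is illustrative): saving the state_dict\n",
"# is more robust to code changes than pickling the whole module, at the cost of\n",
"# needing the MaskedAutoencoder class definition when loading.\n",
"# torch.save(model.state_dict(), './models/MAE/final_30_state.pt')\n",
"# model.load_state_dict(torch.load('./models/MAE/final_30_state.pt'))"
]
},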
{
"cell_type": "code",
"execution_count": 20,
"id": "9d0f3b92-58c2-4794-ae98-7e10546dfb0f",
"metadata": {},
"outputs": [],
"source": [
"model = torch.load('./models/MAE/final_30.pt')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a8467686-0655-4056-8e01-56299eb89d7c",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "59997827-2df9-4593-92b1-4fdc7b6307b4",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
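{
"cell_type": "code",
"execution_count": null,
"id": "6b7c8d9e-7f8a-4b9c-8d0e-9f0a1b2c3d4e",
"metadata": {},
"outputs": [],
"source": [
"# Quick worked example (illustrative): IoA is 1 for a perfect prediction and\n",
"# shrinks as errors grow relative to the spread around the two means.\n",
"yt = np.array([1.0, 2.0, 3.0, 4.0])\n",
"print(cal_ioa(yt, yt)) # 1.0\n",
"print(cal_ioa(yt, yt + 1.0)) # 0.8"
]
},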
{
"cell_type": "code",
"execution_count": null,
"id": "dae7427e-548e-4276-a4ea-bc9b279d44e8",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5b01834-ca18-4ec3-bc9d-64382d0fab34",
"metadata": {},
"outputs": [],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "984650d0-880c-476f-9b7d-e47e8d0fea23",
"metadata": {},
"outputs": [],
"source": [
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"best_mape = 1\n",
"best_img = None\n",
"best_mask = None\n",
"best_recov = None\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" for i, sample in enumerate(rev_data):\n",
" used_mask = mask_rev[i]\n",
" data_label = sample[0] * used_mask\n",
" recon_no2 = rev_recon[i][0] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])\n",
" if mape < best_mape:\n",
" best_recov = rev_recon[i][0].numpy()\n",
" best_mask = used_mask.numpy()\n",
" best_img = sample[0].numpy()\n",
" best_mape = mape"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "ff73a2d5-56b6-4636-8729-a71b69ed5503",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.399060</td>\n",
" <td>1.979731</td>\n",
" <td>0.182714</td>\n",
" <td>0.642942</td>\n",
" <td>0.872402</td>\n",
" <td>0.816624</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.638751</td>\n",
" <td>0.875072</td>\n",
" <td>0.104177</td>\n",
" <td>0.219567</td>\n",
" <td>0.101818</td>\n",
" <td>0.117307</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.492277</td>\n",
" <td>0.624609</td>\n",
" <td>0.060600</td>\n",
" <td>-1.963828</td>\n",
" <td>0.092951</td>\n",
" <td>0.060861</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.907458</td>\n",
" <td>1.280065</td>\n",
" <td>0.126618</td>\n",
" <td>0.535003</td>\n",
" <td>0.835822</td>\n",
" <td>0.758542</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.266889</td>\n",
" <td>1.857223</td>\n",
" <td>0.158806</td>\n",
" <td>0.694635</td>\n",
" <td>0.902473</td>\n",
" <td>0.846662</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.737844</td>\n",
" <td>2.467742</td>\n",
" <td>0.204469</td>\n",
" <td>0.799663</td>\n",
" <td>0.941083</td>\n",
" <td>0.901914</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>5.556955</td>\n",
" <td>8.550211</td>\n",
" <td>1.397686</td>\n",
" <td>0.983735</td>\n",
" <td>0.995918</td>\n",
" <td>0.992068</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
"mean 1.399060 1.979731 0.182714 0.642942 0.872402 \n",
"std 0.638751 0.875072 0.104177 0.219567 0.101818 \n",
"min 0.492277 0.624609 0.060600 -1.963828 0.092951 \n",
"25% 0.907458 1.280065 0.126618 0.535003 0.835822 \n",
"50% 1.266889 1.857223 0.158806 0.694635 0.902473 \n",
"75% 1.737844 2.467742 0.204469 0.799663 0.941083 \n",
"max 5.556955 8.550211 1.397686 0.983735 0.995918 \n",
"\n",
" r \n",
"count 4739.000000 \n",
"mean 0.816624 \n",
"std 0.117307 \n",
"min 0.060861 \n",
"25% 0.758542 \n",
"50% 0.846662 \n",
"75% 0.901914 \n",
"max 0.992068 "
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "e3861bd5-cfef-458c-a3f0-97635f99b981",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_rst(input_feature,masked_feature, recov_region, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 4, 1)\n",
" plt.imshow(input_feature, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 2)\n",
" plt.imshow(masked_feature, cmap='gray')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 3)\n",
" plt.imshow(recov_region, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 4)\n",
" plt.imshow(output_feature, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.savefig('./figures/result/30_samples.png', bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "dd306e5c-7251-4385-b096-b189d0146e0a",
"metadata": {},
"outputs": [],
"source": [
"best_mask_cp = np.where(best_mask == 0, np.nan, best_mask)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "40da0e1c-04de-4523-9caf-ab85b5b474e7",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7YAAADeCAYAAAAJtZwyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9eZwtV1kujj9rrarau/fpM+ZkIDmZDglJmBLCEAmBK4kiIMpVuYr4FeSqBJnMRa6KDKLARQH9RYwIXBS4zF41KFGBKxKFYEBmAkQwMRBIQuah0727ag2/P971rqF27T67z+k+pzun3nxOdu/aNawa3lrv8LzPK5xzDr300ksvvfTSSy+99NJLL730sklFHuoB9NJLL7300ksvvfTSSy+99NLLgUjv2PbSSy+99NJLL7300ksvvfSyqaV3bHvppZdeeumll1566aWXXnrZ1NI7tr300ksvvfTSSy+99NJLL71saukd21566aWXXnrppZdeeumll142tfSObS+99NJLL7300ksvvfTSSy+bWnrHtpdeeumll1566aWXXnrppZdNLb1j20svvfTSSy+99NJLL7300sumlt6x7aWXXnrppZdeeumll1566WVTSzHritc+7qEAAD3WcNZB1xbW5utICZRDhWq+gvKf5VyBwe45iIGCOnoLRCkhRyVQScitc0BVAoWijeuG/lWl/1cARQFIET+lBKwFxjWgNVBr+r44BjSNDQDQWDhrIaT33ZWAKCT9zusYR8vn52gMfIzFMY3DWjjrYO9ahr2npm2sg7u3gV2o4RoLNBamMRjfsQwAEFIAAJx1YSxCClTzJcpRCWvooslSQYwKiFJBDBXEsIDaOaTrMlfSeRYqHBMAMD+iazKaA4YVLbc2rCvm5oC5HUBRAXPbACEBVdHvwl8HZ5HdOFPTsnATC/pnaqBezNfn9VRB++Pv1tLfQtKxZEFjUBVsUcA5C+M0rLNYNotobI2F5k4Yp8Nhta1hncNCM8ZYN7DOwToLCwdtLbQ1WNTL/nI4SCEg/Tlpa/LnUIiwHgBUqkAli+x3KSQKKTEqBtk2LLwty931EhaaMZaNRmPjLblr2eK7CwaNdVj2w3jr+W/HRhbROtdeelkP2bFjB+bn5/Ebv/EbeMELXnBQjtk0DW677TZ8+9vfxoUXXogvf/nLB7Q/13oP9NLL4S4fuuYiaD/389xrnIMSIsytUghI0DwjhYR2JszpLNY5mES/SinDtj95/z86iGd06KW3r3v7+mDZ14tNjdpqaMs66fxloONZOEhE+5pt49SuTu3jTKdBy3l76d8JhVSQQmCoShRSrbt9PbNjyw+UKiWsdVDWQVoHWUgIKaBKCVkqDLZVmNs9h2KugNpaQc5XKPZshRgWEPxAsbJt20IPOyuePwasI6VKHz4WKQFtgMLEda1IlDJxcoyjnLSKygDraDkQlDEoXeUvBx+zbgBrIIYKEhUpc2OAYQFhHYT/rhqJOSngrINZNnDGQfu7wNfNLBtSRn/sYuSglIBQEqJUgBRwy4amgi38AMt4DaQkZStUvE5SAFIBwwFENaKHnf+xAqVKB0wqiSsA6xUgVUBWWqsBM87vgbUALP3GiumsP25FfxtNxwc9+MZqWFgIIaFEASkknJOwsHDOeiWyKIRCIS1qo2GRK1AlC2hnYf0kCbTe/C2RQiRK61q/kfJVUgUFTEVbE5SUjq0wKipYZ7GoTbgMUgDbBxKNcWhsbwT30gsAFEWB5z73uXjqU5+KPXv2HLTjfvvb38Zv/dZv4Vvf+hb+4z/+46Adt5deDheRQkK25maVzJ9k6Mo4fzpLQeqWI5tOl6Uk45fm/8MPSNjb1719ne1jnezrNLjEy2gfURk5IGX9dm0bOnWAu0QiOrKk0xLz1RCFUMHOXm/7embHVpYSzjgIJemV5aMzxUBBKAk1UJClRDlfodo2oId1+wByawW500dsqkTxpACGg7hMCoiyAoSE0zU99ABtl54Qn62UgPR/W9B3RGV0adRECq+gyX5UoqzZifIDH9cVUgKlo9tgLClsKcN3UUpIFEBjYZtuZ8ua3BFjBUzFGQvYjpd6eDEl/5KxClVEZeNoj2jth5XKJYoXTzBXuvRaOBl/5+3SfdlEAV1B31WyPp+/P3cJCQNAQEJMOS5Hefi5TyNCMptAXVwXk87rLMKZW9m6Xta5bGyFVCjCC8JkQblS0sSuTHyn99LLZpS5uTlUVXXA+6mqCqeffjp+4Ad+YA1GNSl1XWNpaWli+U033YTPf/7zuPbaa1e9z6IoMBqNekRDL7205B+v/82AnALIeJ1m2AKpURwNZxMMY1rHOOAXznjz+g58k0hvX/f29VT7Ot3PFLGtJI9Yocp0Jb2dWDfYvN37mEwYxewuJ4wKqVDJAoWMQZT1tq9ndmyHO4fZd2cIRqAGBVRJiqcGCsWuOcjdc5DDAmJrSWn/+ZGP2vjIzWiYRZYERz2qEVAOIZoxoBbjA8Q33Vm4xaUIEbCOoBJaA3UDp01QLiFlFk0CkEWSIAUEvwB4f+GqqADJoP1ourrLFHkRSsJKQRGlUhJkYqwBKaAMKZ/wESaGS1DEjaJvQgoUwwKiTKJoyRhdrSnSxZE3/hwOMkiHKAmOgKICymG8XqlyGIYz+KiR0cmd90qqCkDX9M9aOl8gXhMp6Z2R3Ifs97a0FNuBoBLOR24dLEpZQToJYzQcYkRXCoFCSNSIkAdWHpooY8SwnWWNSiWTZRxBkhPrslObKlzcZ4yGWedQqYI+ZYFSalhB8KvSCQyVD1729nAvm1jKssSFF16ICy644ID3JaXEgx/84DUYVbf84z/+I972trfBmLwM4e6778ZNN920X/t8+MMfjhe/+MUYjUZrMcReernPSCEU2Fa2zkJbQMJ5WKGdyOoQWipmaRMEa3RwVwZcHVbS29e9fd1pX0sJbNkFDOYnnhkL6yHG0b5ubE1/h0wsQY7HuoF2BtoaFEJBc5ACMWsLxEBUdpwOqHEoFUwRlb7kj23rSipsq+bIsVVkTx8s+3pmx7aco1WFEuGhAgBZ0AOlRiXEsIDcPoDaPiBoxJZBEkkS/gEqosIV/vD8sCj/AFkLKJ1He6wlXDr/nQorThrNYYXruhr+4c8iSqzI8A8TbAbTEEHRLZyMEaWw94YibFLJWIcAHyWiAUFIEZQPrHRqcnyuMTTGMF4ZISX+U6RYe5koUIr1T8UkcAjOrjo7uY2zMYoUTl7G578dRUrXWUGMV3xWOCFkiCrxw53Wzrad1nAp0HZm43HbStnl6Kbf0zrd9vEKyRAMOnGKPNG4+Jb5EhLwowT0nm0vh06UUlBKwRgz4fBNk7IsQ4ZyOBzi4Q9/OJ7ylKes5zDXRK677jr83d/9HbTW+155HyKlRFEU2LNnD570pCdh69atazDCXno5tOIW/tJDHTV9sqHMdkM5BFQFoySM01jSC2hsjSW9AG1r3F3f42vxaP6TTniHlfa/UtaWDV52atnnMc6hme3VdMDy65/6ZYz9gd/0XzY270VvX/f2dbCv+
XqxDOax0NwJJQoIQehCQcUAmVPLf3O9LRBta3JqoyMqIcLFTRPtKcqiLe0sbbu0gLbPM7XDospQlSzrbV/P7NgOds8BAISHRoQH2isiFP2TWytSuoEizPpwkGP9CwUxGOUPCqf4+SEajOg7v5CdpfN3Be3P2hxCURWA1p2nnRWztx/yADnwCphEkmAF1RpIR5AM+HOuvGvl98nKZ/mlwLtWAoCkZ9XYUAtAykdKKErlx2ghGhqfawxELeGgIYZeWdh4q0oamwactBDFEAHnD/iJCxGDD+RRoBSv7yxtWw7pH0sXzKIt4br5xycosY9u8b2VBWQCiYhRJEvLxeTjl2ZnCyFRh3par6Q+Stw5LK9QEzUB8PcvWSffTkD5sQh/7lSXYH2Bfdx/IRVKH4X0YB4oQUq4P1DoXnpZCxFC4Kd/+qfxhCc8AX/7t3+LSy+9dJ/b7Nq1C8973vNw//vfHwA5eOsFHd7Ict555+Hnf/7nsXfvXgyHw31v0Esvm0nYqVGIjq0sgGIYnQwg1OcpP6+
"text/plain": [
"<Figure size 1200x600 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"visualize_rst(best_img, best_mask, best_recov*best_mask_cp, best_img * (1-best_mask) + best_recov*best_mask, '')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97163937-4d78-40fc-b385-1d27b01a0647",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}