{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"from PIL import Image\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b8a8cedd-536d-4a48-a1af-7c40489ef0f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fc99a573830>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(0)\n",
"torch.random.manual_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c28cc123-71be-47ff-b78f-3a4d5592df39",
"metadata": {},
"outputs": [],
"source": [
"# 计算图像数据中的最大像素值\n",
"max_pixel_value = 107.49169921875"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "342d21ee-7f31-4c37-a73b-f47cac181763",
"metadata": {},
"outputs": [],
"source": [
"class GrayScaleDataset(Dataset):\n",
" def __init__(self, data_dir):\n",
" self.data_dir = data_dir\n",
" self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]\n",
"\n",
" def __len__(self):\n",
" return len(self.file_list)\n",
"\n",
" def __getitem__(self, idx):\n",
" file_path = os.path.join(self.data_dir, self.file_list[idx])\n",
" data = np.load(file_path)[:,:,0] / max_pixel_value\n",
" return torch.tensor(data, dtype=torch.float32).unsqueeze(0)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dbfe80ce-4394-449c-a9a4-22ed15b2b8f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint before Generator is OK\n"
]
}
],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = np.random.choice(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# 实例化数据集和数据加载器\n",
"image_dir = './out_mat/96/train/'\n",
"mask_dir = './out_mat/96/mask/20/'\n",
"\n",
"print(f\"checkpoint before Generator is OK\")"
]
},
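{
"cell_type": "code",
"execution_count": null,
"id": "no2dataset-mask-demo",
"metadata": {},
"outputs": [],
"source": [
"# Tiny worked example of the masking steps above (illustrative sketch on a synthetic\n",
"# 3x3 'image', no files needed): non-zero mask values become 1, zeros stay 0, and\n",
"# pixels where the binarized mask is 0 are hidden from the model.\n",
"_img = np.arange(9, dtype=np.float32).reshape(3, 3)\n",
"_raw_mask = np.array([[0, 255, 255], [0, 0, 255], [255, 0, 0]], dtype=np.float32)\n",
"_bin_mask = np.where(_raw_mask != 0, 1.0, 0.0)\n",
"print(_img * _bin_mask) # masked-out pixels are zeroed in the network input"
]
},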
{
"cell_type": "code",
"execution_count": 8,
"id": "d3a25f29-b16e-4485-9f06-5378b910be6e",
"metadata": {},
"outputs": [],
"source": [
"class PatchMasking:\n",
" def __init__(self, patch_size, mask_ratio):\n",
" self.patch_size = patch_size\n",
" self.mask_ratio = mask_ratio\n",
"\n",
" def __call__(self, x):\n",
" batch_size, C, H, W = x.shape\n",
" num_patches = (H // self.patch_size) * (W // self.patch_size)\n",
" num_masked = int(num_patches * self.mask_ratio)\n",
" \n",
" # 为每个样本生成独立的mask\n",
" masks = []\n",
" for _ in range(batch_size):\n",
" mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)\n",
" mask[:num_masked] = 1\n",
" mask = mask[torch.randperm(num_patches)]\n",
" mask = mask.view(H // self.patch_size, W // self.patch_size)\n",
" mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)\n",
" masks.append(mask)\n",
" \n",
" # 将所有mask堆叠成一个批量张量\n",
" masks = torch.stack(masks, dim=0)\n",
" masks = torch.unsqueeze(masks, dim=1)\n",
" \n",
" # 应用mask到输入x上\n",
" masked_x = x * (1- masks.float())\n",
" return masked_x, masks"
]
},
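{
"cell_type": "code",
"execution_count": null,
"id": "patch-masking-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check of PatchMasking on a random dummy batch (illustrative sketch,\n",
"# not part of the training pipeline): roughly mask_ratio of the pixels should be\n",
"# masked, and masked positions should be exactly zero in the output.\n",
"_dummy = torch.rand(2, 1, 96, 96)\n",
"_masked, _masks = PatchMasking(patch_size=8, mask_ratio=0.2)(_dummy)\n",
"print(_masked.shape, _masks.shape) # both torch.Size([2, 1, 96, 96])\n",
"print(_masks.float().mean().item()) # ~0.194 = int(144 * 0.2) / 144 masked pixels\n",
"assert torch.all(_masked[_masks] == 0) # masked patches are zeroed"
]
},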
{
"cell_type": "code",
"execution_count": 9,
"id": "41da7319-9795-441d-bde8-8cf390365099",
"metadata": {},
"outputs": [],
"source": [
"train_dir = './out_mat/96/train/'\n",
"train_dataset = GrayScaleDataset(train_dir)\n",
"val_dir = './out_mat/96/valid/'\n",
"val_dataset = GrayScaleDataset(val_dir)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)\n",
"\n",
"test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Input\")\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Masked\")\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Recovery\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "645114e8-65a4-4867-b3fe-23395288e855",
"metadata": {},
"outputs": [],
"source": [
"class Conv(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
" super(Conv, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
"metadata": {},
"outputs": [],
"source": [
"class ConvBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
" bias=False):\n",
" super(ConvBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
" norm_layer(out_channels),\n",
" nn.ReLU()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "31ecf247-e98b-4977-a145-782914a042bd",
"metadata": {},
"outputs": [],
"source": [
"class SeparableBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
" super(SeparableBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
" padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
" # 分离卷积,仅调整空间信息\n",
" norm_layer(in_channels), # 对输入通道进行归一化\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), # 这里进行升维操作\n",
" nn.ReLU6()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
"metadata": {},
"outputs": [],
"source": [
"class ResidualBlock(nn.Module):\n",
" def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
" super(ResidualBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(out_channels)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(out_channels)\n",
"\n",
" # 如果输入和输出通道不一致,进行降采样操作\n",
" self.downsample = downsample\n",
" if in_channels != out_channels or stride != 1:\n",
" self.downsample = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
" nn.BatchNorm2d(out_channels)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" identity = x\n",
" if self.downsample is not None:\n",
" identity = self.downsample(x)\n",
"\n",
" out = self.conv1(x)\n",
" out = self.bn1(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv2(out)\n",
" out = self.bn2(out)\n",
"\n",
" out += identity\n",
" out = self.relu(out)\n",
" return out\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
"metadata": {},
"outputs": [],
"source": [
"class Mlp(nn.Module):\n",
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
" super().__init__()\n",
" out_features = out_features or in_features\n",
" hidden_features = hidden_features or in_features\n",
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
"\n",
" self.act = act_layer()\n",
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
" self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = self.act(x)\n",
" x = self.drop(x)\n",
" x = self.fc2(x)\n",
" x = self.drop(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
"metadata": {},
"outputs": [],
"source": [
"class MultiHeadAttentionBlock(nn.Module):\n",
" def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
" super(MultiHeadAttentionBlock, self).__init__()\n",
" self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
" self.norm = nn.LayerNorm(embed_dim)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x):\n",
" # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
" B, C, H, W = x.shape\n",
" x = x.view(B, C, H * W).permute(2, 0, 1) # (B, C, H, W) -> (HW, B, C)\n",
"\n",
" # Apply multihead attention\n",
" attn_output, _ = self.attention(x, x, x)\n",
"\n",
" # Apply normalization and dropout\n",
" attn_output = self.norm(attn_output)\n",
" attn_output = self.dropout(attn_output)\n",
"\n",
" # Reshape back to (B, C, H, W)\n",
" attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
"\n",
" return attn_output"
]
},
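{
"cell_type": "code",
"execution_count": null,
"id": "mha-shape-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative shape check on a dummy tensor: the block flattens (B, C, H, W) to\n",
"# (HW, B, C) for nn.MultiheadAttention and restores the layout afterwards, so the\n",
"# output shape must match the input shape.\n",
"_blk = MultiHeadAttentionBlock(embed_dim=128, num_heads=4)\n",
"print(_blk(torch.rand(2, 128, 12, 12)).shape) # torch.Size([2, 128, 12, 12])"
]
},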
{
"cell_type": "code",
"execution_count": 17,
"id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
"metadata": {},
"outputs": [],
"source": [
"class SpatialAttentionBlock(nn.Module):\n",
" def __init__(self):\n",
" super(SpatialAttentionBlock, self).__init__()\n",
" self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
"\n",
" def forward(self, x): #(B, 64, H, W)\n",
" avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)\n",
" max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)\n",
" out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)\n",
" out = torch.sigmoid(self.conv(out))#(B, 1, H, W)\n",
" return x * out #(B, C, H, W)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
"metadata": {},
"outputs": [],
"source": [
"class DecoderAttentionBlock(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super(DecoderAttentionBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
" self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
" self.spatial_attention = SpatialAttentionBlock()\n",
"\n",
" def forward(self, x):\n",
" # 通道注意力\n",
" b, c, h, w = x.size()\n",
" avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
" max_pool = F.adaptive_max_pool2d(x, 1)\n",
"\n",
" avg_out = self.conv1(avg_pool)\n",
" max_out = self.conv1(max_pool)\n",
"\n",
" out = avg_out + max_out\n",
" out = torch.sigmoid(self.conv2(out))\n",
"\n",
" # 添加空间注意力\n",
" out = x * out\n",
" out = self.spatial_attention(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
" def __init__(self, in_channels, reduced_dim):\n",
" super(SEBlock, self).__init__()\n",
" self.se = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return x * self.se(x)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "c9d176a8-bbf6-4043-ab82-1648a99d772a",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
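{
"cell_type": "code",
"execution_count": null,
"id": "masked-loss-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Small worked example (illustrative, dummy tensors): the loss only counts pixels\n",
"# where mask == 1, so errors outside the masked region contribute nothing.\n",
"_p = torch.zeros(1, 1, 4, 4)\n",
"_t = torch.ones(1, 1, 4, 4)\n",
"_m = torch.zeros(1, 1, 4, 4)\n",
"_m[..., :2] = 1.0 # mask the first two columns\n",
"print(masked_mse_loss(_p, _t, _m)) # tensor(1.) -> unit error inside the mask\n",
"print(masked_mse_loss(_t, _t, _m)) # tensor(0.) -> perfect reconstruction inside the mask"
]
},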
{
"cell_type": "code",
"execution_count": 25,
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 定义Masked Autoencoder模型\n",
"class MaskedAutoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(MaskedAutoencoder, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" Conv(1, 32, kernel_size=3, stride=2),\n",
" \n",
" nn.ReLU(),\n",
" \n",
" SEBlock(32,32),\n",
" \n",
" ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
" \n",
" ResidualBlock(64,64),\n",
" \n",
" SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
" \n",
" MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
" \n",
" SEBlock(128, 128),\n",
" \n",
" Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
" )\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(), \n",
" nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded\n",
"\n",
"# 实例化模型、损失函数和优化器\n",
"model = MaskedAutoencoder()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
]
},
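{
"cell_type": "code",
"execution_count": null,
"id": "mae-shape-roundtrip",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative shape round-trip on a dummy batch (uses a fresh instance so the\n",
"# model above is untouched): the three stride-2 encoder stages reduce 96x96 to\n",
"# 12x12, and the three ConvTranspose2d stages recover the original resolution.\n",
"_probe = MaskedAutoencoder()\n",
"with torch.no_grad():\n",
"    _enc = _probe.encoder(torch.rand(2, 1, 96, 96))\n",
"    _dec = _probe.decoder(_enc)\n",
"print(_enc.shape) # torch.Size([2, 128, 12, 12])\n",
"print(_dec.shape) # torch.Size([2, 1, 96, 96])"
]
},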
{
"cell_type": "code",
"execution_count": 22,
"id": "084f6b1e-ed3a-490b-9020-5479863e803b",
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):\n",
" model.to(device)\n",
" for epoch in range(epochs):\n",
" model.train()\n",
" train_loss = 0\n",
" for data in train_loader:\n",
" data = data.to(device)\n",
" optimizer.zero_grad()\n",
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" loss = masked_mse_loss(output, data, mask)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss.item()\n",
" train_loss /= len(train_loader)\n",
"\n",
" model.eval()\n",
" val_loss = 0\n",
" with torch.no_grad():\n",
" for data in val_loader:\n",
" data = data.to(device)\n",
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" loss = masked_mse_loss(output, data, mask)\n",
" val_loss += loss.item()\n",
" val_loss /= len(val_loader)\n",
"\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "296ba6bd-2239-4948-b278-7edcb29bfd14",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# 数据准备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "16673a37-02e9-4883-8288-aa0e240d6824",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Train Loss: 0.0185, Val Loss: 0.0199\n",
"Epoch 2, Train Loss: 0.0178, Val Loss: 0.0187\n",
"Epoch 3, Train Loss: 0.0174, Val Loss: 0.0217\n",
"Epoch 4, Train Loss: 0.0172, Val Loss: 0.0227\n",
"Epoch 5, Train Loss: 0.0167, Val Loss: 0.0180\n",
"Epoch 6, Train Loss: 0.0166, Val Loss: 0.0225\n",
"Epoch 7, Train Loss: 0.0163, Val Loss: 0.0183\n",
"Epoch 8, Train Loss: 0.0162, Val Loss: 0.0220\n",
"Epoch 9, Train Loss: 0.0161, Val Loss: 0.0181\n",
"Epoch 10, Train Loss: 0.0159, Val Loss: 0.0196\n",
"Epoch 11, Train Loss: 0.0159, Val Loss: 0.0210\n",
"Epoch 12, Train Loss: 0.0155, Val Loss: 0.0198\n",
"Epoch 13, Train Loss: 0.0154, Val Loss: 0.0212\n",
"Epoch 14, Train Loss: 0.0153, Val Loss: 0.0207\n",
"Epoch 15, Train Loss: 0.0153, Val Loss: 0.0216\n",
"Epoch 16, Train Loss: 0.0152, Val Loss: 0.0222\n",
"Epoch 17, Train Loss: 0.0152, Val Loss: 0.0225\n",
"Epoch 18, Train Loss: 0.0150, Val Loss: 0.0183\n",
"Epoch 19, Train Loss: 0.0151, Val Loss: 0.0242\n",
"Epoch 20, Train Loss: 0.0148, Val Loss: 0.0203\n",
"Epoch 21, Train Loss: 0.0148, Val Loss: 0.0211\n",
"Epoch 22, Train Loss: 0.0148, Val Loss: 0.0200\n",
"Epoch 23, Train Loss: 0.0146, Val Loss: 0.0191\n",
"Epoch 24, Train Loss: 0.0145, Val Loss: 0.0215\n",
"Epoch 25, Train Loss: 0.0145, Val Loss: 0.0196\n",
"Epoch 26, Train Loss: 0.0146, Val Loss: 0.0215\n",
"Epoch 27, Train Loss: 0.0144, Val Loss: 0.0195\n",
"Epoch 28, Train Loss: 0.0144, Val Loss: 0.0196\n",
"Epoch 29, Train Loss: 0.0143, Val Loss: 0.0182\n",
"Epoch 30, Train Loss: 0.0143, Val Loss: 0.0213\n",
"Epoch 31, Train Loss: 0.0142, Val Loss: 0.0178\n",
"Epoch 32, Train Loss: 0.0139, Val Loss: 0.0215\n",
"Epoch 33, Train Loss: 0.0135, Val Loss: 0.0171\n",
"Epoch 34, Train Loss: 0.0131, Val Loss: 0.0187\n",
"Epoch 35, Train Loss: 0.0128, Val Loss: 0.0171\n",
"Epoch 36, Train Loss: 0.0128, Val Loss: 0.0159\n",
"Epoch 37, Train Loss: 0.0127, Val Loss: 0.0170\n",
"Epoch 38, Train Loss: 0.0125, Val Loss: 0.0182\n",
"Epoch 39, Train Loss: 0.0124, Val Loss: 0.0155\n",
"Epoch 40, Train Loss: 0.0123, Val Loss: 0.0169\n",
"Epoch 41, Train Loss: 0.0122, Val Loss: 0.0160\n",
"Epoch 42, Train Loss: 0.0123, Val Loss: 0.0164\n",
"Epoch 43, Train Loss: 0.0120, Val Loss: 0.0154\n",
"Epoch 44, Train Loss: 0.0121, Val Loss: 0.0159\n",
"Epoch 45, Train Loss: 0.0119, Val Loss: 0.0152\n",
"Epoch 46, Train Loss: 0.0118, Val Loss: 0.0151\n",
"Epoch 47, Train Loss: 0.0119, Val Loss: 0.0135\n",
"Epoch 48, Train Loss: 0.0121, Val Loss: 0.0135\n",
"Epoch 49, Train Loss: 0.0118, Val Loss: 0.0162\n",
"Epoch 50, Train Loss: 0.0117, Val Loss: 0.0195\n",
"Epoch 51, Train Loss: 0.0116, Val Loss: 0.0160\n",
"Epoch 52, Train Loss: 0.0116, Val Loss: 0.0167\n",
"Epoch 53, Train Loss: 0.0116, Val Loss: 0.0149\n",
"Epoch 54, Train Loss: 0.0114, Val Loss: 0.0143\n",
"Epoch 55, Train Loss: 0.0115, Val Loss: 0.0136\n",
"Epoch 56, Train Loss: 0.0115, Val Loss: 0.0144\n",
"Epoch 57, Train Loss: 0.0115, Val Loss: 0.0158\n",
"Epoch 58, Train Loss: 0.0113, Val Loss: 0.0147\n",
"Epoch 59, Train Loss: 0.0112, Val Loss: 0.0142\n",
"Epoch 60, Train Loss: 0.0113, Val Loss: 0.0159\n",
"Epoch 61, Train Loss: 0.0112, Val Loss: 0.0153\n",
"Epoch 62, Train Loss: 0.0113, Val Loss: 0.0140\n",
"Epoch 63, Train Loss: 0.0112, Val Loss: 0.0156\n",
"Epoch 64, Train Loss: 0.0111, Val Loss: 0.0149\n",
"Epoch 65, Train Loss: 0.0112, Val Loss: 0.0154\n",
"Epoch 66, Train Loss: 0.0112, Val Loss: 0.0158\n",
"Epoch 67, Train Loss: 0.0111, Val Loss: 0.0136\n",
"Epoch 68, Train Loss: 0.0110, Val Loss: 0.0139\n",
"Epoch 69, Train Loss: 0.0110, Val Loss: 0.0142\n",
"Epoch 70, Train Loss: 0.0112, Val Loss: 0.0152\n",
"Epoch 71, Train Loss: 0.0109, Val Loss: 0.0151\n",
"Epoch 72, Train Loss: 0.0110, Val Loss: 0.0162\n",
"Epoch 73, Train Loss: 0.0110, Val Loss: 0.0162\n",
"Epoch 74, Train Loss: 0.0109, Val Loss: 0.0176\n",
"Epoch 75, Train Loss: 0.0109, Val Loss: 0.0143\n",
"Epoch 76, Train Loss: 0.0109, Val Loss: 0.0147\n",
"Epoch 77, Train Loss: 0.0108, Val Loss: 0.0141\n",
"Epoch 78, Train Loss: 0.0109, Val Loss: 0.0145\n",
"Epoch 79, Train Loss: 0.0108, Val Loss: 0.0140\n",
"Epoch 80, Train Loss: 0.0109, Val Loss: 0.0135\n",
"Epoch 81, Train Loss: 0.0108, Val Loss: 0.0145\n",
"Epoch 82, Train Loss: 0.0108, Val Loss: 0.0126\n",
"Epoch 83, Train Loss: 0.0108, Val Loss: 0.0145\n",
"Epoch 84, Train Loss: 0.0107, Val Loss: 0.0135\n",
"Epoch 85, Train Loss: 0.0108, Val Loss: 0.0140\n",
"Epoch 86, Train Loss: 0.0107, Val Loss: 0.0143\n",
"Epoch 87, Train Loss: 0.0107, Val Loss: 0.0146\n",
"Epoch 88, Train Loss: 0.0107, Val Loss: 0.0136\n",
"Epoch 111, Train Loss: 0.0094, Val Loss: 0.0120\n",
"Epoch 112, Train Loss: 0.0094, Val Loss: 0.0114\n",
"Epoch 113, Train Loss: 0.0095, Val Loss: 0.0128\n",
"Epoch 114, Train Loss: 0.0093, Val Loss: 0.0125\n",
"Epoch 115, Train Loss: 0.0093, Val Loss: 0.0124\n",
"Epoch 116, Train Loss: 0.0093, Val Loss: 0.0114\n",
"Epoch 117, Train Loss: 0.0093, Val Loss: 0.0127\n",
"Epoch 118, Train Loss: 0.0093, Val Loss: 0.0122\n",
"Epoch 119, Train Loss: 0.0093, Val Loss: 0.0116\n",
"Epoch 120, Train Loss: 0.0092, Val Loss: 0.0114\n",
"Epoch 121, Train Loss: 0.0092, Val Loss: 0.0130\n",
"Epoch 122, Train Loss: 0.0092, Val Loss: 0.0114\n",
"Epoch 123, Train Loss: 0.0093, Val Loss: 0.0113\n",
"Epoch 124, Train Loss: 0.0092, Val Loss: 0.0120\n",
"Epoch 125, Train Loss: 0.0091, Val Loss: 0.0110\n",
"Epoch 126, Train Loss: 0.0091, Val Loss: 0.0128\n",
"Epoch 127, Train Loss: 0.0091, Val Loss: 0.0129\n",
"Epoch 128, Train Loss: 0.0092, Val Loss: 0.0126\n",
"Epoch 129, Train Loss: 0.0092, Val Loss: 0.0113\n",
"Epoch 130, Train Loss: 0.0091, Val Loss: 0.0109\n"
]
}
],
"source": [
"train_model(model, train_loader, val_loader, epochs=130, criterion=criterion, optimizer=optimizer, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "a8467686-0655-4056-8e01-56299eb89d7c",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "efc96935-bbe0-4ca9-b11a-931cdcfc3bed",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
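{
"cell_type": "code",
"execution_count": null,
"id": "ioa-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Worked example (illustrative): IoA equals 1 for a perfect prediction and falls to 0\n",
"# when the squared error matches the variability term in the denominator.\n",
"_obs = np.array([1.0, 2.0, 3.0, 4.0])\n",
"print(cal_ioa(_obs, _obs)) # 1.0\n",
"print(cal_ioa(_obs, _obs[::-1])) # 0.0 for this fully anti-correlated prediction"
]
},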
{
"cell_type": "code",
"execution_count": 32,
"id": "73a0002b-35d6-4e20-a620-5c8f5cd49296",
"metadata": {},
"outputs": [],
"source": [
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"best_mape = 1\n",
"best_img = None\n",
"best_mask = None\n",
"best_recov = None\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" for i, sample in enumerate(rev_data):\n",
" used_mask = mask_rev[i]\n",
" data_label = sample[0] * used_mask\n",
" recon_no2 = rev_recon[i][0] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])\n",
" if mape < best_mape:\n",
" best_recov = rev_recon[i][0].numpy()\n",
" best_mask = used_mask.numpy()\n",
" best_img = sample[0].numpy()\n",
" best_mape = mape"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "589e6d80-228d-4e8a-968a-e7477c5e0e24",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>5.297680</td>\n",
" <td>6.225729</td>\n",
" <td>0.489679</td>\n",
" <td>-1.978159</td>\n",
" <td>-0.362509</td>\n",
" <td>0.352984</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>3.930302</td>\n",
" <td>4.176386</td>\n",
" <td>0.191670</td>\n",
" <td>2.447883</td>\n",
" <td>1.074637</td>\n",
" <td>0.201559</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.996953</td>\n",
" <td>1.279405</td>\n",
" <td>0.202344</td>\n",
" <td>-28.276637</td>\n",
" <td>-9.562830</td>\n",
" <td>-0.500861</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>2.103293</td>\n",
" <td>2.741658</td>\n",
" <td>0.353414</td>\n",
" <td>-2.891019</td>\n",
" <td>-0.796581</td>\n",
" <td>0.225314</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>3.190869</td>\n",
" <td>4.148710</td>\n",
" <td>0.457116</td>\n",
" <td>-1.093823</td>\n",
" <td>0.044020</td>\n",
" <td>0.365110</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>8.378542</td>\n",
" <td>9.440538</td>\n",
" <td>0.586501</td>\n",
" <td>-0.406974</td>\n",
" <td>0.355992</td>\n",
" <td>0.498017</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>21.329165</td>\n",
" <td>23.047779</td>\n",
" <td>2.242282</td>\n",
" <td>0.592645</td>\n",
" <td>0.829324</td>\n",
" <td>0.839954</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
"mean 5.297680 6.225729 0.489679 -1.978159 -0.362509 \n",
"std 3.930302 4.176386 0.191670 2.447883 1.074637 \n",
"min 0.996953 1.279405 0.202344 -28.276637 -9.562830 \n",
"25% 2.103293 2.741658 0.353414 -2.891019 -0.796581 \n",
"50% 3.190869 4.148710 0.457116 -1.093823 0.044020 \n",
"75% 8.378542 9.440538 0.586501 -0.406974 0.355992 \n",
"max 21.329165 23.047779 2.242282 0.592645 0.829324 \n",
"\n",
" r \n",
"count 4739.000000 \n",
"mean 0.352984 \n",
"std 0.201559 \n",
"min -0.500861 \n",
"25% 0.225314 \n",
"50% 0.365110 \n",
"75% 0.498017 \n",
"max 0.839954 "
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "5f8b2dc4-5ac4-4b52-9dea-de8d29cba6b5",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" eva_list.append([mae, rmse, mape, r2, ioa])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "755abc3e-f4d2-4056-b01b-3fb085f95f19",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model, './models/MAE/final_patch_20.pt')"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "76449691-74b2-43ef-b092-f71cd8116448",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_rst(input_feature,masked_feature, recov_region, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 4, 1)\n",
" plt.imshow(input_feature, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 2)\n",
" plt.imshow(masked_feature, cmap='gray')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 3)\n",
" plt.imshow(recov_region, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 4)\n",
" plt.imshow(output_feature, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" # plt.savefig('./figures/result/20_samples.png')\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "82467932-3b38-4d2d-83d9-8d76c4f98a06",
"metadata": {},
"outputs": [],
"source": [
"best_mask_cp = np.where(best_mask == 0, np.nan, best_mask)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "6bb568d1-07bd-49c4-9056-9ad2f2dd36a8",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7YAAADeCAYAAAAJtZwyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9f6xsW3YWin1zzLlWrapd58c9fbv7dtsYzH1uCwuSSEiOTYxEFAi8CEJiZAmCsYXtGGNZCT8eFgLlHxACYgtF8Gw//IxBIQ/y5ATjYGPrAQ7YlsEtiPXc3bLsxn5p4/TP2+eefU6dqlVrzjlm/hhzzDnX2nXu3bd9u+893WscHdWu2lWr1u89vvF94xsmpZSwxhprrLHGGmusscYaa6yxxhrPadBbvQJrrLHGGmusscYaa6yxxhprrPEbiRXYrrHGGmusscYaa6yxxhprrPFcxwps11hjjTXWWGONNdZYY4011niuYwW2a6yxxhprrLHGGmusscYaazzXsQLbNdZYY4011lhjjTXWWGONNZ7rWIHtGmusscYaa6yxxhprrLHGGs91rMB2jTXWWGONNdZYY4011lhjjec6VmC7xhprrLHGGmusscYaa6yxxnMdK7BdY4011lhjjTXWWGONNdZY47kOd9s3fu8v/GkAgOcETsD1mTHGhHOU18gA1hh0BHRkQAborIGPCdcTl+XsnMH7XuiwcwaD7eDIlt9dTyccpgTPCZ6BjoCNM7BGluc54egTdp3By3fvo7cOZASbj2ECI6EnBzIGjizIGBymEY+nEyYOOAaGNbJ+vXW422/hjMXgOpAxmGIAp5SXa7BzG/TWwZkevR3gqEdPAxIYnBgd9di5u7DkQJGBxACH/Fi3GYkB6wBDiF2Pp+ExzuGIUzyAU/O+JlJiMBgEQm8HWOOwsTt01COmAE6MmAJiCuW9KTE8T4gp4OiPmGLAMUwIHDFGj8CMYzhjjLIfPcu+HKxBR7LPHk8BnzhGAIA1AOVjag2w6ygfZ4AM8nvkc84QeuvQW4f7/Q4PhvvYuB3udg9k3d0OBAL8CACIzs22vaMeCBNwPsh7wgTEab5TDMm+TAzEgHR6Kq/ryjx+CoQAEAFEMO94ARjuAt0A9DvA9fKYY4xHPDp/suyzlBiW5JI4xyMCRzyeThiDx8PzAY/OHr/yKOAXPz3i4Rjw8cOEfW/x0lUPSwY2r8Y/+UP/19e7nN42cX19ja//+q/Hj/7oj77Vq7LGGm+7SCm91avwzPjlR98NAOhowJfe/Y63eG3WWGONzzT++f/3vwAZQk8WZAiOCGQIZAzIGHC+D5ExcMaWHBcAODHIEDa2hzEEAsEYQmryK80PNdr3kSEYECw59DQAAKxx+fuovFeXA6B8Vn/2aULgCZwYged5W0xh9qjhqIc18p29HTDYHSxDcsDDK5Lnhank1YkZCLEuQJ+HAHC+Tw890Hcw/SLnIyeP1iGCSw6t20qGYEFADJJ3hvzfj5K7d4Msw/UAObBzSIkxxbHskzamOCKmgIlHBJ4QWXJ1zx5TDAgcMXHM+4vLsSUjuIgMYdf1cNTjyt3Fxu7Q2wE9DTVXjhMwHWWdwyTYY3yMFAJwHOu+4oxHuPlbpjlzzpVLhABMXj57HOVzk0fihHT0QEygd1wBu6Hsa/QdMGzq8kIAxqkeH2eBu3uYzQZ48CXAZo9rfoxTOOATx4/h4XjAo+mIx5PHqyPj4cg4hoTrM+MUGA9PHlNMuD7L+fPv//h/e2N/L+PWwJZTQkzAOcqjnBAGHSWQMfm5vK5A13OCNQYPNrLjOpsBakwYDbBzeuHm3xNh4xgbGHBCA5LlIuaUcKeLM0CrJ2dvXXkkmHJT2HWywyeOcHSebc/D8QAyhH03wBGVmwcn2aaeHBwRIgKmOMLzhMmMZRnRDtjYHSjlEyOxXAicT7SYQS4gF1U3wJLDvruPrd3jTnpQwOkUR1xPr4BTwBTlAOqJHlOAMQTP00Vg297I9DUFshMLWB9sB1jkCyeiownnaLCxchwH2xVAT0auAR/rhUAkx4MMsMnv+/Q44RwYuy5hYwFHFn0GhhOP2GKPfXc/36BCBf6GYI2DMXKzNYaA02O5SMf8SK4UAwqg5Xxj1OePHsuFM2zkZvbOd8qNR2+GgNwk9UalyyC5UQ3dgJeufgtO4YCPHn8VYxjxcDwgcCyFkRe3d7DZ7UBPPo7Aj/BFe4uN2+LoGdfTDh0BdzrC9cT47z/xFKdwuVCxxhprrPFmxsZKkc6swqs11niuY+ckT9VcmlMCp4jAEYyUn1dw2wJczaEHG2aAtw0yRpaVc1vNnysJ5EBMiBQE5Bk3A7QKYhUcm/wejZZk8RnYKgGkr+tz3Y6OOhhD6KiHiz3Odoet3WO7uy85nubRGeCaxJXsyLl1Gk+SlIYo/zllMJU/q7kiUPLJKRwzyJyQwGU7dJu7rofthUwqy7AODM4gdcI0PZ6BeC0QlP3BAQxG5Pw9ieHZZ0ArObkej4mbfB+mOUYGjoTc0v0bbUDgXogizY9jAOxYwLihgDT5CjLL3weuINZZ+dlZwLkKdDkJsGUW8JqfG2aYcZLl7Xcl54azFdxqtKQeMxAg4JgMjB8Bcthu9gCAq26LwALAnTnDZnxlz4xjMAAIW0ey7gAi367QfGtgC1TgqkENc0fN7zjV4oA1FdAqAxiTvofRqqHlgpINEJawMq8CVg1AAGG+Iu0JRagXtp4cvXVgJDhjCxBmJJwjoyM9wbiA4ZAiHGy5AFOu8Jh8gsn3SIUrpgBKBEs9oACXhVEs7C1QmUYOoORA1KNDXy+wtrpWflbw7vM+oHJzaEGsNQ4WcuHJ79NiOSj7RMFrSAwgZDbc5CqhhWPGYANiSgBMfrwcPgJjBDpK6PIJ0N6Yy3Yv90V+nRS0AnLD8iPgRyQ/wfSvkazp5zg11SiW4kG/E2BcigsTEAkwudAQRsD2zUaM6N1Q9tfEUljoM/vvqMfW7bFzPXpyuLeRc/boDDoSBcDGGXg2mGLCyT8fwDalhKdPn+L6+hre+7d6ddZYY403GNY4fPH+W9/q1VhjjTV+g+FIc71KrnBihAwEAwvInQNSvghyeZGzzcGy5rkVJItKEUggmERImpOnyuxikQYui2mpyUnbHJXB4EzEKDhXIMMpNdstAM4aB8sOfTcUEgSJn/lo+iCrxgmgnHsxIzELEF4oIrn5V9SOGXNEIzl+SgJ2LTk418t28SggNU0FvKfE8EmArYVDSrUIUJafAZnm5Tf/cwV2ZMFIgj8Nldc1zw88ybEjQuAJjnqQ7WsRwEHAOyBgVdWLYMWF+YSgGag1CkrJSbFAmS0FxvocAIKtgLb5b6wr+zpRZsBYjkUB05wKBrJmB2skv1alZ88RG0vY2ITOigKSyaC3hCneDtBq3BrYvrgVYDTGKkX2nBCTsKsKXGMC2tx+CYY5tfLlCR1NZZs1yAh7K+8N8Aycg7xB5LOEXbcBmQ6D6+BMBbLtCaOVLkcWfUoIlgv13xuC66TKMrhuBogDR4QUcQxnAbn5pqE3D
Eciu+WQLwDjsHV7dNSjv3oggOr0qFZ7DAmT6PrK6uYblO0GjOkIToyt2yMlxp1OQOopHMrNAhAWFAy01TsAYOORqO70wjYbW5huR1SKA4G5VoUyu82JcfTnDHjrcQAMjp4RE3D0egzkmHUk0vJ9b7B1XfmuMXqExNi6Y6l2zSKxMLQKPtvHlC8EBcT6fmVwbQ8MeyAGmN/cV7BqCNjdz7KTLHm2eX8rY2tIpMn9Drh6UNbNGOBu9wBbt8ed/i5SYpHJkMPWynF9z/69uLfZ46k/4fEk/z92POIcEp54USg82Dpsu+eDPXn48CH+xt/4G/jABz6An//5n3+rV2eNNdZ4g2HM83GvWWONNV47NAdThnYMld07Ry5kUH1/JX+kXc4ipFjyuZvLn8uZlcwQUEwZNBv0NhRpMoCZ3LgNSw4Gc5aSG1ZWgRgn2ZYWpCvhoqRVTx6OrCgRMyDcd/dhyMH2g3x3w9yWnI4lLzRuAvpRWMo2WmCb8/DIIpn2PGHKYBUQ4Knbo8yt5vYtG61qSWV
"text/plain": [
"<Figure size 1200x600 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"visualize_rst(best_img, best_mask, best_recov*best_mask_cp, best_recov, '')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb78ecea-809d-40c6-940f-c72cd956ff84",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}