{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, Dataset, random_split\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c28cc123-71be-47ff-b78f-3a4d5592df39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum pixel value in the dataset: 107.49169921875\n"
     ]
    }
   ],
   "source": [
    "# Helper to find the maximum pixel value in the dataset\n",
    "def find_max_pixel_value(image_dir):\n",
    "    max_pixel_value = 0.0\n",
    "    for filename in os.listdir(image_dir):\n",
    "        if filename.endswith('.npy'):\n",
    "            image_path = os.path.join(image_dir, filename)\n",
    "            image = np.load(image_path).astype(np.float32)\n",
    "            max_pixel_value = max(max_pixel_value, image[:, :, 0].max())\n",
    "    return max_pixel_value\n",
    "\n",
    "# Maximum pixel value over the training images (used later for normalization)\n",
    "image_dir = './out_mat/96/train/'\n",
    "max_pixel_value = find_max_pixel_value(image_dir)\n",
    "\n",
    "print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dbfe80ce-4394-449c-a9a4-22ed15b2b8f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "checkpoint before Generator is OK\n"
     ]
    }
   ],
   "source": [
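    "# Masking convention in this dataset: mask == 1 keeps a pixel visible, mask == 0 hides it.\n",
    "# __getitem__ returns (X, y, mask): the masked NO2 map, the full NO2 map, and the mask itself.\n",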
    "class NO2Dataset(Dataset):\n",
    "\n",
    "    def __init__(self, image_dir, mask_dir):\n",
    "\n",
    "        self.image_dir = image_dir\n",
    "        self.mask_dir = mask_dir\n",
    "        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # only load .npy files\n",
    "        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # only load .jpg files\n",
    "\n",
    "    def __len__(self):\n",
    "\n",
    "        return len(self.image_filenames)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "\n",
    "        image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
    "        mask_idx = np.random.choice(self.mask_filenames)\n",
    "        mask_path = os.path.join(self.mask_dir, mask_idx)\n",
    "\n",
    "        # Load the image data (.npy file)\n",
    "        image = np.load(image_path).astype(np.float32)[:, :, :1] / max_pixel_value  # shape (96, 96, 1)\n",
    "\n",
    "        # Load the mask data (.jpg file)\n",
    "        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
    "\n",
    "        # Set every non-zero mask value to 1; zeros stay 0\n",
    "        mask = np.where(mask != 0, 1.0, 0.0)\n",
    "\n",
    "        # Keep the mask shape as (96, 96, 1)\n",
    "        mask = mask[:, :, np.newaxis]\n",
    "\n",
    "        # Apply the mask\n",
    "        masked_image = image.copy()\n",
    "        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()  # hide the NO2 values where mask == 0\n",
    "\n",
    "        # Input and target for the cGAN\n",
    "        X = masked_image[:, :, :1]  # shape (96, 96, 1)\n",
    "        y = image[:, :, 0:1]  # target is the NO2 data, shape (96, 96, 1)\n",
    "\n",
    "        # Convert to (channels, height, width)\n",
    "        X = np.transpose(X, (2, 0, 1))  # (1, 96, 96)\n",
    "        y = np.transpose(y, (2, 0, 1))  # (1, 96, 96)\n",
    "        mask = np.transpose(mask, (2, 0, 1))  # (1, 96, 96)\n",
    "\n",
    "        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
    "\n",
    "# Directories for the dataset and data loaders\n",
    "image_dir = './out_mat/96/train/'\n",
    "mask_dir = './out_mat/96/mask/10/'\n",
    "\n",
    "print(f\"checkpoint before Generator is OK\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "41da7319-9795-441d-bde8-8cf390365099",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = NO2Dataset(image_dir, mask_dir)\n",
    "dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
    "val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n",
    "val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
    "test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
    "test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "70797703-1619-4be7-b965-5506b3d1e775",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the input, masked input and recovered output for one sample\n",
    "def visualize_feature(input_feature, masked_feature, output_feature, title):\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    plt.subplot(1, 3, 1)\n",
    "    plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
    "    plt.title(title + \" Input\")\n",
    "    plt.subplot(1, 3, 2)\n",
    "    plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
    "    plt.title(title + \" Masked\")\n",
    "    plt.subplot(1, 3, 3)\n",
    "    plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
    "    plt.title(title + \" Recovery\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "645114e8-65a4-4867-b3fe-23395288e855",
   "metadata": {},
   "outputs": [],
   "source": [
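    "# With kernel_size=3 and dilation=1 the padding below works out to 1, so a\n",
    "# stride-1 Conv keeps H and W unchanged and a stride-2 Conv halves them (96 -> 48).\n",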
    "class Conv(nn.Sequential):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
    "        super(Conv, self).__init__(\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
    "                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvBNReLU(nn.Sequential):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
    "                 bias=False):\n",
    "        super(ConvBNReLU, self).__init__(\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
    "                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
    "            norm_layer(out_channels),\n",
    "            nn.ReLU()\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "31ecf247-e98b-4977-a145-782914a042bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SeparableBNReLU(nn.Sequential):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
    "        super(SeparableBNReLU, self).__init__(\n",
    "            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
    "                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
    "            # depthwise convolution: only mixes spatial information\n",
    "            norm_layer(in_channels),  # normalize the input channels\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),  # pointwise convolution expands the channels\n",
    "            nn.ReLU6()\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualBlock(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "\n",
    "        # If the input and output channels (or the stride) differ, downsample the identity branch\n",
    "        self.downsample = downsample\n",
    "        if in_channels != out_channels or stride != 1:\n",
    "            self.downsample = nn.Sequential(\n",
    "                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_channels)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        identity = x\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mlp(nn.Module):\n",
    "    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
    "        super().__init__()\n",
    "        out_features = out_features or in_features\n",
    "        hidden_features = hidden_features or in_features\n",
    "        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
    "\n",
    "        self.act = act_layer()\n",
    "        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
    "        self.drop = nn.Dropout(drop, inplace=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.act(x)\n",
    "        x = self.drop(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.drop(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadAttentionBlock(nn.Module):\n",
    "    def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
    "        super(MultiHeadAttentionBlock, self).__init__()\n",
    "        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
    "        self.norm = nn.LayerNorm(embed_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
    "        B, C, H, W = x.shape\n",
    "        x = x.view(B, C, H * W).permute(2, 0, 1)\n",
    "\n",
    "        # Apply multihead attention\n",
    "        attn_output, _ = self.attention(x, x, x)\n",
    "\n",
    "        # Apply normalization and dropout\n",
    "        attn_output = self.norm(attn_output)\n",
    "        attn_output = self.dropout(attn_output)\n",
    "\n",
    "        # Reshape back to (B, C, H, W)\n",
    "        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
    "\n",
    "        return attn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpatialAttentionBlock(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SpatialAttentionBlock, self).__init__()\n",
    "        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
    "\n",
    "    def forward(self, x):  # (B, C, H, W)\n",
    "        avg_out = torch.mean(x, dim=1, keepdim=True)    # (B, 1, H, W)\n",
    "        max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B, 1, H, W)\n",
    "        out = torch.cat([avg_out, max_out], dim=1)      # (B, 2, H, W)\n",
    "        out = torch.sigmoid(self.conv(out))             # (B, 1, H, W)\n",
    "        return x * out                                  # (B, C, H, W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderAttentionBlock(nn.Module):\n",
    "    def __init__(self, in_channels):\n",
    "        super(DecoderAttentionBlock, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
    "        self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
    "        self.spatial_attention = SpatialAttentionBlock()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Channel attention\n",
    "        b, c, h, w = x.size()\n",
    "        avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
    "        max_pool = F.adaptive_max_pool2d(x, 1)\n",
    "\n",
    "        avg_out = self.conv1(avg_pool)\n",
    "        max_out = self.conv1(max_pool)\n",
    "\n",
    "        out = avg_out + max_out\n",
    "        out = torch.sigmoid(self.conv2(out))\n",
    "\n",
    "        # Followed by spatial attention\n",
    "        out = x * out\n",
    "        out = self.spatial_attention(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SEBlock(nn.Module):\n",
    "    def __init__(self, in_channels, reduced_dim):\n",
    "        super(SEBlock, self).__init__()\n",
    "        self.se = nn.Sequential(\n",
    "            nn.AdaptiveAvgPool2d(1),  # global average pooling\n",
    "            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
    "            nn.Sigmoid()  # Sigmoid turns the channel responses into weights in [0, 1]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x * self.se(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
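    "# Spatial sizes for a 96x96 input: the three stride-2 encoder stages give\n",
    "# 96 -> 48 -> 24 -> 12, and the three ConvTranspose2d stages bring it back\n",
    "# 12 -> 24 -> 48 -> 96. Note that self.mlp is defined but not used in forward().\n",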
    "# Masked Autoencoder model\n",
    "class MaskedAutoencoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MaskedAutoencoder, self).__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            Conv(1, 32, kernel_size=3, stride=2),\n",
    "            nn.ReLU(),\n",
    "            SEBlock(32, 32),\n",
    "            ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
    "            ResidualBlock(64, 64),\n",
    "            SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
    "            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
    "            SEBlock(128, 128)\n",
    "        )\n",
    "        self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            DecoderAttentionBlock(32),\n",
    "            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            DecoderAttentionBlock(16),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # output_padding=1 restores 96x96\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded\n",
    "\n",
    "# Instantiate the model, loss function and optimizer\n",
    "model = MaskedAutoencoder()\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e9c804e0-6f5c-40a7-aba7-a03a496cf427",
   "metadata": {},
   "outputs": [],
   "source": [
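    "# Shapes: preds, target and mask are all (B, 1, 96, 96). Note that after\n",
    "# .mean(dim=-1) the squared error becomes (B, 1, 96), so multiplying by the\n",
    "# (B, 1, 96, 96) mask relies on broadcasting; dropping the mean(dim=-1) step\n",
    "# would give a plain per-pixel masked MSE.\n",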
    "def masked_mse_loss(preds, target, mask):\n",
    "    loss = (preds - target) ** 2\n",
    "    loss = loss.mean(dim=-1)  # average over the last dimension\n",
    "    loss = (loss * mask).sum() / mask.sum()  # keep only the pixels selected by the mask\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training function (one epoch)\n",
    "def train_epoch(model, device, data_loader, criterion, optimizer):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
    "        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstructed = model(X)\n",
    "        loss = masked_mse_loss(reconstructed, y, mask)\n",
    "        # loss = criterion(reconstructed, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "    return running_loss / (batch_idx + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation function\n",
    "def evaluate(model, device, data_loader, criterion):\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
    "            X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
    "            reconstructed = model(X)\n",
    "            if batch_idx == 8:\n",
    "                rand_ind = np.random.randint(0, len(y))\n",
    "                # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
    "            loss = masked_mse_loss(reconstructed, y, mask)\n",
    "            running_loss += loss.item()\n",
    "    return running_loss / (batch_idx + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "296ba6bd-2239-4948-b278-7edcb29bfd14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# Device setup\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "743d1000-561e-4444-8b49-88346c14f28b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "\n",
    "num_epochs = 150\n",
    "train_losses = list()\n",
    "val_losses = list()\n",
    "for epoch in range(num_epochs):\n",
    "    train_loss = train_epoch(model, device, dataloader, criterion, optimizer)\n",
    "    train_losses.append(train_loss)\n",
    "    val_loss = evaluate(model, device, val_loader, criterion)\n",
    "    val_losses.append(val_loss)\n",
    "    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
    "\n",
    "# Evaluate on the test set\n",
    "test_loss = evaluate(model, device, test_loader, criterion)\n",
    "print(f'Test Loss: {test_loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'train_losses' is not defined",
     "output_type": "error",
     "traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[39], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m tr_ind \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(\u001b[43mtrain_losses\u001b[49m)))\n\u001b[1;32m 2\u001b[0m val_ind \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(val_losses)))\n\u001b[1;32m 3\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(train_losses[\u001b[38;5;241m1\u001b[39m:], label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_loss\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
"\u001b[0;31mNameError\u001b[0m: name 'train_losses' is not defined"
     ]
    }
   ],
   "source": [
    "tr_ind = list(range(len(train_losses)))\n",
    "val_ind = list(range(len(val_losses)))\n",
    "plt.plot(train_losses[1:], label='train_loss')\n",
    "plt.plot(val_losses[1:], label='val_loss')\n",
    "plt.legend(loc='best')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bb9e09a-d317-49a6-b413-f0159539ac86",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model, './models/MAE/final_10.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "858b0940-fa98-4863-a1e4-2f5603b5c19d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.load('./models/MAE/final_10.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a8467686-0655-4056-8e01-56299eb89d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "6d8fddd7-8728-43ec-8c72-bd068f0002d4",
   "metadata": {},
   "outputs": [],
   "source": [
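    "# As implemented below, the score is IoA = 1 - sum((y - y_hat)^2) /\n",
    "# sum((|y - mean(y)| + |y_hat - mean(y_hat)|)^2), an index-of-agreement style\n",
    "# metric where 1 means perfect agreement.\n",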
    "def cal_ioa(y_true, y_pred):\n",
    "    # Means of the observed and predicted values\n",
    "    mean_observed = np.mean(y_true)\n",
    "    mean_predicted = np.mean(y_pred)\n",
    "\n",
    "    # Index of Agreement (IoA)\n",
    "    numerator = np.sum((y_true - y_pred) ** 2)\n",
    "    denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
    "    IoA = 1 - (numerator / denominator)\n",
    "\n",
    "    return IoA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "dae7427e-548e-4276-a4ea-bc9b279d44e8",
   "metadata": {},
   "outputs": [],
   "source": [
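    "# Batch-level evaluation: mask_rev is 1 exactly where the input was hidden\n",
    "# (mask == 0), predictions and targets are rescaled by max_pixel_value, and the\n",
    "# metrics are computed only over those reconstructed pixels.\n",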
    "eva_list = list()\n",
    "device = 'cpu'\n",
    "model = model.to(device)\n",
    "with torch.no_grad():\n",
    "    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
    "        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
    "        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1  # invert the mask to select the inpainted region\n",
    "        reconstructed = model(X)\n",
    "        rev_data = y * max_pixel_value\n",
    "        rev_recon = reconstructed * max_pixel_value\n",
    "        # todo: evaluate only the inpainted region here\n",
    "        data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
    "        data_label = data_label[mask_rev==1]\n",
    "        recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
    "        recon_no2 = recon_no2[mask_rev==1]\n",
    "        mae = mean_absolute_error(data_label, recon_no2)\n",
    "        rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
    "        mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
    "        r2 = r2_score(data_label, recon_no2)\n",
    "        ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
    "        r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
    "        eva_list.append([mae, rmse, mape, r2, ioa, r])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d5b01834-ca18-4ec3-bc9d-64382d0fab34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mae</th>\n",
       "      <th>rmse</th>\n",
       "      <th>mape</th>\n",
       "      <th>r2</th>\n",
       "      <th>ioa</th>\n",
       "      <th>r</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "      <td>75.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>1.092253</td>\n",
       "      <td>1.720153</td>\n",
       "      <td>0.134480</td>\n",
       "      <td>0.932102</td>\n",
       "      <td>0.981798</td>\n",
       "      <td>0.966130</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.078788</td>\n",
       "      <td>0.170601</td>\n",
       "      <td>0.009332</td>\n",
       "      <td>0.012611</td>\n",
       "      <td>0.003674</td>\n",
       "      <td>0.006333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.961178</td>\n",
       "      <td>1.460930</td>\n",
       "      <td>0.118522</td>\n",
       "      <td>0.891069</td>\n",
       "      <td>0.969307</td>\n",
       "      <td>0.945610</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>1.039787</td>\n",
       "      <td>1.628992</td>\n",
       "      <td>0.127926</td>\n",
       "      <td>0.928341</td>\n",
       "      <td>0.980775</td>\n",
       "      <td>0.964158</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>1.094167</td>\n",
       "      <td>1.696594</td>\n",
       "      <td>0.134129</td>\n",
       "      <td>0.934846</td>\n",
       "      <td>0.982724</td>\n",
       "      <td>0.967207</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1.121535</td>\n",
       "      <td>1.772920</td>\n",
       "      <td>0.139542</td>\n",
       "      <td>0.940266</td>\n",
       "      <td>0.984235</td>\n",
       "      <td>0.970309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>1.388302</td>\n",
       "      <td>2.337924</td>\n",
       "      <td>0.161247</td>\n",
       "      <td>0.950963</td>\n",
       "      <td>0.987079</td>\n",
       "      <td>0.975622</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             mae       rmse       mape         r2        ioa          r\n",
       "count  75.000000  75.000000  75.000000  75.000000  75.000000  75.000000\n",
       "mean    1.092253   1.720153   0.134480   0.932102   0.981798   0.966130\n",
       "std     0.078788   0.170601   0.009332   0.012611   0.003674   0.006333\n",
       "min     0.961178   1.460930   0.118522   0.891069   0.969307   0.945610\n",
       "25%     1.039787   1.628992   0.127926   0.928341   0.980775   0.964158\n",
       "50%     1.094167   1.696594   0.134129   0.934846   0.982724   0.967207\n",
       "75%     1.121535   1.772920   0.139542   0.940266   0.984235   0.970309\n",
       "max     1.388302   2.337924   0.161247   0.950963   0.987079   0.975622"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d0a8f2f8-6e44-4b01-a390-1b80c4059d5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "eva_list_frame = list()\n",
    "device = 'cpu'\n",
    "model = model.to(device)\n",
    "best_mape = 1\n",
    "best_img = None\n",
    "best_mask = None\n",
    "best_recov = None\n",
    "with torch.no_grad():\n",
    "    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
    "        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
    "        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1  # invert the mask to select the inpainted region\n",
    "        reconstructed = model(X)\n",
    "        rev_data = y * max_pixel_value\n",
    "        rev_recon = reconstructed * max_pixel_value\n",
    "        # todo: evaluate only the inpainted region here, per sample\n",
    "        for i, sample in enumerate(rev_data):\n",
    "            used_mask = mask_rev[i]\n",
    "            data_label = sample[0] * used_mask\n",
    "            recon_no2 = rev_recon[i][0] * used_mask\n",
    "            data_label = data_label[used_mask==1]\n",
    "            recon_no2 = recon_no2[used_mask==1]\n",
    "            mae = mean_absolute_error(data_label, recon_no2)\n",
    "            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
    "            mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
    "            r2 = r2_score(data_label, recon_no2)\n",
    "            ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
    "            r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
    "            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])\n",
    "            if mape < best_mape:\n",
    "                best_recov = rev_recon[i][0].numpy()\n",
    "                best_mask = used_mask.numpy()\n",
    "                best_img = sample[0].numpy()\n",
    "                best_mape = mape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "65ffcff5-4b1f-4d52-878f-c7323ce895c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mae</th>\n",
       "      <th>rmse</th>\n",
       "      <th>mape</th>\n",
       "      <th>r2</th>\n",
       "      <th>ioa</th>\n",
       "      <th>r</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>4739.000000</td>\n",
       "      <td>4739.000000</td>\n",
       "      <td>4739.000000</td>\n",
       "      <td>4739.000000</td>\n",
       "      <td>4739.000000</td>\n",
       "      <td>4739.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>1.081396</td>\n",
       "      <td>1.521876</td>\n",
       "      <td>0.134078</td>\n",
       "      <td>0.740742</td>\n",
       "      <td>0.915243</td>\n",
       "      <td>0.874727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.503237</td>\n",
       "      <td>0.752737</td>\n",
       "      <td>0.082727</td>\n",
       "      <td>0.203443</td>\n",
       "      <td>0.083132</td>\n",
       "      <td>0.099583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.354667</td>\n",
       "      <td>0.447099</td>\n",
       "      <td>0.043627</td>\n",
       "      <td>-1.035759</td>\n",
       "      <td>-0.034988</td>\n",
       "      <td>0.159654</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.709801</td>\n",
       "      <td>0.979624</td>\n",
       "      <td>0.093325</td>\n",
       "      <td>0.672060</td>\n",
       "      <td>0.898100</td>\n",
       "      <td>0.842776</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.983843</td>\n",
       "      <td>1.372613</td>\n",
       "      <td>0.118378</td>\n",
       "      <td>0.793777</td>\n",
       "      <td>0.939622</td>\n",
       "      <td>0.901160</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1.329353</td>\n",
       "      <td>1.873745</td>\n",
       "      <td>0.152530</td>\n",
       "      <td>0.871634</td>\n",
       "      <td>0.964235</td>\n",
       "      <td>0.939917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>6.657323</td>\n",
       "      <td>12.205771</td>\n",
       "      <td>1.874481</td>\n",
       "      <td>0.991835</td>\n",
       "      <td>0.997919</td>\n",
       "      <td>0.996090</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               mae         rmse         mape           r2          ioa  \\\n",
       "count  4739.000000  4739.000000  4739.000000  4739.000000  4739.000000   \n",
       "mean      1.081396     1.521876     0.134078     0.740742     0.915243   \n",
       "std       0.503237     0.752737     0.082727     0.203443     0.083132   \n",
       "min       0.354667     0.447099     0.043627    -1.035759    -0.034988   \n",
       "25%       0.709801     0.979624     0.093325     0.672060     0.898100   \n",
       "50%       0.983843     1.372613     0.118378     0.793777     0.939622   \n",
       "75%       1.329353     1.873745     0.152530     0.871634     0.964235   \n",
       "max       6.657323    12.205771     1.874481     0.991835     0.997919   \n",
       "\n",
       "                 r  \n",
       "count  4739.000000  \n",
       "mean      0.874727  \n",
       "std       0.099583  \n",
       "min       0.159654  \n",
       "25%       0.842776  \n",
       "50%       0.901160  \n",
       "75%       0.939917  \n",
       "max       0.996090  "
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "67fbca5e-faec-48db-901c-c3105bf60492",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_mask_cp = np.where(best_mask == 0, np.nan, best_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "098ceaa3-e072-431d-8e42-5d5b988e2628",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f700ec465b0>"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(best_img*best_mask_cp, cmap='RdYlGn_r')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "0d26de20-dc8f-4324-8a38-a368c66e5cca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the input, mask, recovered region and final result for one sample\n",
    "def visualize_rst(input_feature, masked_feature, recov_region, output_feature, title):\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    plt.subplot(1, 4, 1)\n",
    "    plt.imshow(input_feature, cmap='RdYlGn_r')\n",
    "    plt.gca().axis('off')  # hide the axes of the current subplot\n",
    "    # plt.title(title + \" Input\")\n",
    "    plt.subplot(1, 4, 2)\n",
    "    plt.imshow(masked_feature, cmap='gray')\n",
    "    plt.gca().axis('off')\n",
    "    # plt.title(title + \" Mask\")\n",
    "    plt.subplot(1, 4, 3)\n",
    "    plt.imshow(recov_region, cmap='RdYlGn_r')\n",
    "    plt.gca().axis('off')\n",
    "    # plt.title(title + \" Recovery Region\")\n",
    "    plt.subplot(1, 4, 4)\n",
    "    plt.imshow(output_feature, cmap='RdYlGn_r')\n",
    "    # plt.title(title + \" Recovery Result\")\n",
    "    plt.gca().axis('off')\n",
    "    plt.savefig('./figures/result/10_samples.png', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "072a4712-c490-4037-94d5-e345f1fc190c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 1200x600 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "visualize_rst(best_img, best_mask, best_recov*best_mask_cp, best_img * (1-best_mask) + best_recov*best_mask, '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05458c88-9907-4b25-b32e-8d5acfb3224f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}