{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"from PIL import Image\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b8a8cedd-536d-4a48-a1af-7c40489ef0f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f8d8a9ef7f0>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(42)\n",
"torch.random.manual_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c28cc123-71be-47ff-b78f-3a4d5592df39",
"metadata": {},
"outputs": [],
"source": [
"# Maximum pixel value over the image data (precomputed), used to normalize inputs\n",
"max_pixel_value = 107.49169921875"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dbfe80ce-4394-449c-a9a4-22ed15b2b8f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint before Generator is OK\n"
]
}
],
"source": [
"class NO2Dataset(Dataset):\n",
"\n",
"    def __init__(self, image_dir, mask_dir):\n",
"\n",
"        self.image_dir = image_dir\n",
"        self.mask_dir = mask_dir\n",
"        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # load only .npy files\n",
"        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # load only .jpg files\n",
"\n",
"    def __len__(self):\n",
"\n",
"        return len(self.image_filenames)\n",
"\n",
"    def __getitem__(self, idx):\n",
"\n",
"        image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
"        mask_filename = np.random.choice(self.mask_filenames)  # pair the image with a randomly chosen mask\n",
"        mask_path = os.path.join(self.mask_dir, mask_filename)\n",
"\n",
"        # Load the image data (.npy file)\n",
"        image = np.load(image_path).astype(np.float32)[:, :, :1] / max_pixel_value  # shape (96, 96, 1)\n",
"\n",
"        # Load the mask data (.jpg file)\n",
"        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
"        # Set non-zero mask values to 1; zeros stay 0\n",
"        mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
"        # Keep the mask shape as (96, 96, 1)\n",
"        mask = mask[:, :, np.newaxis]\n",
"\n",
"        # Apply the mask\n",
"        masked_image = image.copy()\n",
"        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()  # mask out the NO2 data\n",
"\n",
"        # Input and target for the model\n",
"        X = masked_image[:, :, :1]  # shape (96, 96, 1)\n",
"        y = image[:, :, 0:1]  # target is the NO2 data, shape (96, 96, 1)\n",
"\n",
"        # Convert to (channels, height, width)\n",
"        X = np.transpose(X, (2, 0, 1))  # (1, 96, 96)\n",
"        y = np.transpose(y, (2, 0, 1))  # (1, 96, 96)\n",
"        mask = np.transpose(mask, (2, 0, 1))  # (1, 96, 96)\n",
"\n",
"        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# Dataset paths (the loaders are created in the next cell)\n",
"image_dir = './out_mat/96/train/'\n",
"mask_dir = './out_mat/96/mask/20/'\n",
"\n",
"print(\"checkpoint before Generator is OK\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "41da7319-9795-441d-bde8-8cf390365099",
"metadata": {},
"outputs": [],
"source": [
"train_set = NO2Dataset(image_dir, mask_dir)\n",
"train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)\n",
"val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n",
"val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
"test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
]
},
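{
"cell_type": "code",
"execution_count": null,
"id": "3f1c9a2e-71d4-4c1a-9b0e-5a8c2d4e6f10",
"metadata": {},
"outputs": [],
"source": [
"# Added sanity check (a sketch, not part of the original run): pull one sample\n",
"# and confirm the (channels, height, width) shapes the model expects.\n",
"X0, y0, m0 = train_set[0]\n",
"print(X0.shape, y0.shape, m0.shape)  # expected: torch.Size([1, 96, 96]) for each"
]
},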
{
"cell_type": "code",
"execution_count": 6,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# Visualize a feature: input, masked input, and recovered output\n",
"def visualize_feature(input_feature, masked_feature, output_feature, title):\n",
"    plt.figure(figsize=(12, 6))\n",
"    plt.subplot(1, 3, 1)\n",
"    plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
"    plt.title(title + \" Input\")\n",
"    plt.subplot(1, 3, 2)\n",
"    plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
"    plt.title(title + \" Masked\")\n",
"    plt.subplot(1, 3, 3)\n",
"    plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
"    plt.title(title + \" Recovery\")\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "645114e8-65a4-4867-b3fe-23395288e855",
"metadata": {},
"outputs": [],
"source": [
"class Conv(nn.Sequential):\n",
"    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
"        super(Conv, self).__init__(\n",
"            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
"                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
"        )"
]
},
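{
"cell_type": "code",
"execution_count": null,
"id": "7b2d4e6f-90a1-4c3b-8d5e-2f4a6c8e0b12",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: the padding formula above gives 'same' padding for stride=1 and\n",
"# halves the spatial size for stride=2 with an odd kernel.\n",
"with torch.no_grad():\n",
"    _t = torch.randn(1, 3, 96, 96)\n",
"    print(Conv(3, 8, kernel_size=3, stride=1)(_t).shape)  # torch.Size([1, 8, 96, 96])\n",
"    print(Conv(3, 8, kernel_size=3, stride=2)(_t).shape)  # torch.Size([1, 8, 48, 48])"
]
},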
{
"cell_type": "code",
"execution_count": 8,
"id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
"metadata": {},
"outputs": [],
"source": [
"class ConvBNReLU(nn.Sequential):\n",
"    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
"                 bias=False):\n",
"        super(ConvBNReLU, self).__init__(\n",
"            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
"                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
"            norm_layer(out_channels),\n",
"            nn.ReLU()\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "31ecf247-e98b-4977-a145-782914a042bd",
"metadata": {},
"outputs": [],
"source": [
"class SeparableBNReLU(nn.Sequential):\n",
"    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
"        super(SeparableBNReLU, self).__init__(\n",
"            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
"                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
"            # depthwise convolution: adjusts spatial information only\n",
"            norm_layer(in_channels),  # normalize over the input channels\n",
"            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),  # pointwise convolution raises the channel count\n",
"            nn.ReLU6()\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
"metadata": {},
"outputs": [],
"source": [
"class ResidualBlock(nn.Module):\n",
"    def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
"        super(ResidualBlock, self).__init__()\n",
"        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
"        self.bn1 = nn.BatchNorm2d(out_channels)\n",
"        self.relu = nn.ReLU(inplace=True)\n",
"        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
"        self.bn2 = nn.BatchNorm2d(out_channels)\n",
"\n",
"        # If the input and output channels (or the stride) differ, downsample the identity branch\n",
"        self.downsample = downsample\n",
"        if in_channels != out_channels or stride != 1:\n",
"            self.downsample = nn.Sequential(\n",
"                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
"                nn.BatchNorm2d(out_channels)\n",
"            )\n",
"\n",
"    def forward(self, x):\n",
"        identity = x\n",
"        if self.downsample is not None:\n",
"            identity = self.downsample(x)\n",
"\n",
"        out = self.conv1(x)\n",
"        out = self.bn1(out)\n",
"        out = self.relu(out)\n",
"\n",
"        out = self.conv2(out)\n",
"        out = self.bn2(out)\n",
"\n",
"        out += identity\n",
"        out = self.relu(out)\n",
"        return out\n"
]
},
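{
"cell_type": "code",
"execution_count": null,
"id": "9c4e6a8b-12d3-4e5f-a7b9-3c5d7e9f1a24",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: when the channel count or stride changes, the 1x1 downsample\n",
"# branch keeps the identity addition shape-compatible.\n",
"with torch.no_grad():\n",
"    print(ResidualBlock(64, 64)(torch.randn(2, 64, 24, 24)).shape)            # torch.Size([2, 64, 24, 24])\n",
"    print(ResidualBlock(64, 128, stride=2)(torch.randn(2, 64, 24, 24)).shape)  # torch.Size([2, 128, 12, 12])"
]
},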
{
"cell_type": "code",
"execution_count": 11,
"id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
"metadata": {},
"outputs": [],
"source": [
"class Mlp(nn.Module):\n",
"    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
"        super().__init__()\n",
"        out_features = out_features or in_features\n",
"        hidden_features = hidden_features or in_features\n",
"        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
"\n",
"        self.act = act_layer()\n",
"        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
"        self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.fc1(x)\n",
"        x = self.act(x)\n",
"        x = self.drop(x)\n",
"        x = self.fc2(x)\n",
"        x = self.drop(x)\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
"metadata": {},
"outputs": [],
"source": [
"class MultiHeadAttentionBlock(nn.Module):\n",
"    def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
"        super(MultiHeadAttentionBlock, self).__init__()\n",
"        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
"        self.norm = nn.LayerNorm(embed_dim)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, x):\n",
"        # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
"        B, C, H, W = x.shape\n",
"        x = x.view(B, C, H * W).permute(2, 0, 1)  # (B, C, H, W) -> (HW, B, C)\n",
"\n",
"        # Apply multihead attention\n",
"        attn_output, _ = self.attention(x, x, x)\n",
"\n",
"        # Apply normalization and dropout\n",
"        attn_output = self.norm(attn_output)\n",
"        attn_output = self.dropout(attn_output)\n",
"\n",
"        # Reshape back to (B, C, H, W)\n",
"        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
"\n",
"        return attn_output"
]
},
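{
"cell_type": "code",
"execution_count": null,
"id": "b5d7f9a1-34c5-4f6a-b8d0-4e6f8a0c2d36",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: the attention block flattens (B, C, H, W) into an (H*W, B, C)\n",
"# sequence, attends over the H*W positions, and restores the input shape.\n",
"with torch.no_grad():\n",
"    _attn = MultiHeadAttentionBlock(embed_dim=128, num_heads=4)\n",
"    print(_attn(torch.randn(2, 128, 12, 12)).shape)  # torch.Size([2, 128, 12, 12])"
]
},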
{
"cell_type": "code",
"execution_count": 13,
"id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
"metadata": {},
"outputs": [],
"source": [
"class SpatialAttentionBlock(nn.Module):\n",
"    def __init__(self):\n",
"        super(SpatialAttentionBlock, self).__init__()\n",
"        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
"\n",
"    def forward(self, x):  # (B, C, H, W)\n",
"        avg_out = torch.mean(x, dim=1, keepdim=True)    # (B, 1, H, W)\n",
"        max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B, 1, H, W)\n",
"        out = torch.cat([avg_out, max_out], dim=1)      # (B, 2, H, W)\n",
"        out = torch.sigmoid(self.conv(out))             # (B, 1, H, W)\n",
"        return x * out  # (B, C, H, W)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
"metadata": {},
"outputs": [],
"source": [
"class DecoderAttentionBlock(nn.Module):\n",
"    def __init__(self, in_channels):\n",
"        super(DecoderAttentionBlock, self).__init__()\n",
"        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
"        self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
"        self.spatial_attention = SpatialAttentionBlock()\n",
"\n",
"    def forward(self, x):\n",
"        # Channel attention\n",
"        b, c, h, w = x.size()\n",
"        avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
"        max_pool = F.adaptive_max_pool2d(x, 1)\n",
"\n",
"        avg_out = self.conv1(avg_pool)\n",
"        max_out = self.conv1(max_pool)\n",
"\n",
"        out = avg_out + max_out\n",
"        out = torch.sigmoid(self.conv2(out))\n",
"\n",
"        # Followed by spatial attention\n",
"        out = x * out\n",
"        out = self.spatial_attention(out)\n",
"        return out"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
"    def __init__(self, in_channels, reduced_dim):\n",
"        super(SEBlock, self).__init__()\n",
"        self.se = nn.Sequential(\n",
"            nn.AdaptiveAvgPool2d(1),  # global average pooling\n",
"            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
"            nn.Sigmoid()  # Sigmoid squashes the per-channel weights into (0, 1)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return x * self.se(x)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c9d176a8-bbf6-4043-ab82-1648a99d772a",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
"    loss = (preds - target) ** 2\n",
"    loss = loss.mean(dim=-1)  # average over the last spatial dimension\n",
"    loss = (loss * (1 - mask)).sum() / (1 - mask).sum()  # average only over the masked-out pixels\n",
"    return loss"
]
},
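{
"cell_type": "code",
"execution_count": null,
"id": "d7f9b1c3-56e7-4a8b-9c1d-5f7a9b1d3e48",
"metadata": {},
"outputs": [],
"source": [
"# Added toy check: with mask==1 marking visible pixels, the loss averages the\n",
"# squared error over the masked-out (mask==0) positions only.\n",
"_preds = torch.zeros(1, 1, 2, 2)\n",
"_target = torch.ones(1, 1, 2, 2)\n",
"_mask = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])\n",
"print(masked_mse_loss(_preds, _target, _mask))  # tensor(1.) -- only the single masked pixel contributes"
]
},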
{
"cell_type": "code",
"execution_count": 17,
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Masked Autoencoder model\n",
"class MaskedAutoencoder(nn.Module):\n",
"    def __init__(self):\n",
"        super(MaskedAutoencoder, self).__init__()\n",
"        self.encoder = nn.Sequential(\n",
"            Conv(1, 32, kernel_size=3, stride=2),\n",
"            nn.ReLU(),\n",
"            SEBlock(32, 32),\n",
"            ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
"            ResidualBlock(64, 64),\n",
"            SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
"            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
"            SEBlock(128, 128)\n",
"        )\n",
"        # self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
"        self.decoder = nn.Sequential(\n",
"            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
"            nn.ReLU(),\n",
"            DecoderAttentionBlock(32),\n",
"            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
"            nn.ReLU(),\n",
"            DecoderAttentionBlock(16),\n",
"            nn.ReLU(),\n",
"            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # output_padding=1 restores the even spatial size\n",
"            nn.Sigmoid()\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        encoded = self.encoder(x)\n",
"        decoded = self.decoder(encoded)\n",
"        return decoded\n",
"\n",
"# Instantiate the model, loss function, and optimizer\n",
"model = MaskedAutoencoder()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
]
},
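{
"cell_type": "code",
"execution_count": null,
"id": "f1a3c5d7-78e9-4b0c-8d2e-6a8c0d2f4e50",
"metadata": {},
"outputs": [],
"source": [
"# Added sanity check: three stride-2 encoder stages (96 -> 48 -> 24 -> 12) are\n",
"# undone by the three stride-2 transposed convolutions in the decoder.\n",
"with torch.no_grad():\n",
"    print(model(torch.randn(2, 1, 96, 96)).shape)  # torch.Size([2, 1, 96, 96])"
]
},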
{
"cell_type": "code",
"execution_count": 18,
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
"metadata": {},
"outputs": [],
"source": [
"# Train for one epoch\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
"    model.train()\n",
"    running_loss = 0.0\n",
"    for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        optimizer.zero_grad()\n",
"        reconstructed = model(X)\n",
"        loss = masked_mse_loss(reconstructed, y, mask)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item()\n",
"    return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
"metadata": {},
"outputs": [],
"source": [
"# Evaluation function\n",
"def evaluate(model, device, data_loader, criterion):\n",
"    model.eval()\n",
"    running_loss = 0.0\n",
"    with torch.no_grad():\n",
"        for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
"            X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"            reconstructed = model(X)\n",
"            if batch_idx == 8:\n",
"                rand_ind = np.random.randint(0, len(y))\n",
"                # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
"            loss = masked_mse_loss(reconstructed, y, mask)\n",
"            running_loss += loss.item()\n",
"    return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "296ba6bd-2239-4948-b278-7edcb29bfd14",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# Device setup\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "743d1000-561e-4444-8b49-88346c14f28b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n",
"  return F.conv2d(input, weight, bias, self.stride,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Train Loss: 1.828955806274876, Val Loss: 0.08777590596408986\n",
"Epoch 2, Train Loss: 0.06457909727781012, Val Loss: 0.05018303115198861\n",
"Epoch 3, Train Loss: 0.04399169035006368, Val Loss: 0.03933813378437242\n",
"Epoch 4, Train Loss: 0.03737294341049839, Val Loss: 0.04090026577017201\n",
"Epoch 5, Train Loss: 0.03340746862947513, Val Loss: 0.029788545930563515\n",
"Epoch 6, Train Loss: 0.03127880183240158, Val Loss: 0.02878953230136366\n",
"Epoch 7, Train Loss: 0.030086695853816837, Val Loss: 0.027378849156979305\n",
"Epoch 8, Train Loss: 0.02827827861470184, Val Loss: 0.026564865748384105\n",
"Epoch 9, Train Loss: 0.026973650764014447, Val Loss: 0.026876062349374615\n",
"Epoch 10, Train Loss: 0.026198443756149145, Val Loss: 0.025235873994542593\n",
"Epoch 11, Train Loss: 0.025248640154501754, Val Loss: 0.025164278752323407\n",
"Epoch 12, Train Loss: 0.0246738152373493, Val Loss: 0.02402887423870279\n",
"Epoch 13, Train Loss: 0.02429686849446673, Val Loss: 0.02467221769490349\n",
"Epoch 14, Train Loss: 0.023617587716242915, Val Loss: 0.024100169289245535\n",
"Epoch 15, Train Loss: 0.022902602209535796, Val Loss: 0.023378314227977797\n",
"Epoch 16, Train Loss: 0.022661644239067746, Val Loss: 0.02472560463556603\n",
"Epoch 17, Train Loss: 0.02193861959154526, Val Loss: 0.02273730694580434\n",
"Epoch 18, Train Loss: 0.021775715561075645, Val Loss: 0.022977211248518814\n",
"Epoch 19, Train Loss: 0.021564541914852325, Val Loss: 0.022313175500551268\n",
"Epoch 20, Train Loss: 0.0214472935851396, Val Loss: 0.022048505606935984\n",
"Epoch 21, Train Loss: 0.020810687219340835, Val Loss: 0.02184077285563768\n",
"Epoch 22, Train Loss: 0.020310772384592647, Val Loss: 0.021513454977478554\n",
"Epoch 23, Train Loss: 0.02010334756350118, Val Loss: 0.02177375905326943\n",
"Epoch 24, Train Loss: 0.02025744297795675, Val Loss: 0.02049418441506464\n",
"Epoch 25, Train Loss: 0.019826160295995657, Val Loss: 0.023377947564890134\n",
"Epoch 26, Train Loss: 0.019065276574806875, Val Loss: 0.020193443425110917\n",
"Epoch 27, Train Loss: 0.01881279432745071, Val Loss: 0.01942526154331307\n",
"Epoch 28, Train Loss: 0.01839842515413841, Val Loss: 0.01973166508572315\n",
"Epoch 29, Train Loss: 0.018092166516555555, Val Loss: 0.021518220902601286\n",
"Epoch 30, Train Loss: 0.01789530134942543, Val Loss: 0.0191833000741343\n",
"Epoch 31, Train Loss: 0.017643442852021546, Val Loss: 0.018857373494599292\n",
"Epoch 32, Train Loss: 0.017585936365604543, Val Loss: 0.018622038858150367\n",
"Epoch 33, Train Loss: 0.017121152348513382, Val Loss: 0.018597172726112516\n",
"Epoch 34, Train Loss: 0.016807572604223872, Val Loss: 0.01907729919054615\n",
"Epoch 35, Train Loss: 0.0167503119735983, Val Loss: 0.018055098590010137\n",
"Epoch 36, Train Loss: 0.01674377040839509, Val Loss: 0.017786314029858183\n",
"Epoch 37, Train Loss: 0.016270555827641888, Val Loss: 0.01821137344770467\n",
"Epoch 38, Train Loss: 0.016271821564090166, Val Loss: 0.017419732745681236\n",
"Epoch 39, Train Loss: 0.01634730132180823, Val Loss: 0.017153916838787385\n",
"Epoch 40, Train Loss: 0.016149515664855545, Val Loss: 0.01720947952968861\n",
"Epoch 41, Train Loss: 0.015722640304331573, Val Loss: 0.01671495117636314\n",
"Epoch 42, Train Loss: 0.015584125958882165, Val Loss: 0.016605446490445243\n",
"Epoch 43, Train Loss: 0.015607581132996168, Val Loss: 0.016551834531128407\n",
"Epoch 44, Train Loss: 0.015686789721375303, Val Loss: 0.017196020681355426\n",
"Epoch 45, Train Loss: 0.0152399734099302, Val Loss: 0.016840887422770706\n",
"Epoch 46, Train Loss: 0.015122933551651296, Val Loss: 0.018965846010998114\n",
"Epoch 47, Train Loss: 0.015065566115259554, Val Loss: 0.016344470375064594\n",
"Epoch 48, Train Loss: 0.014854169773766726, Val Loss: 0.016327281677122437\n",
"Epoch 49, Train Loss: 0.014882152102459845, Val Loss: 0.015837757153186336\n",
"Epoch 50, Train Loss: 0.014656414190957848, Val Loss: 0.016042638750774645\n",
"Epoch 51, Train Loss: 0.014637816764200418, Val Loss: 0.015558397091591536\n",
"Epoch 52, Train Loss: 0.01454300198784214, Val Loss: 0.015685647628756603\n",
"Epoch 53, Train Loss: 0.014566657712691994, Val Loss: 0.01571561763090874\n",
"Epoch 54, Train Loss: 0.01434676954522729, Val Loss: 0.015356795890117758\n",
"Epoch 55, Train Loss: 0.014364799384348557, Val Loss: 0.015472657116713808\n",
"Epoch 56, Train Loss: 0.014128341450930783, Val Loss: 0.015367844809235922\n",
"Epoch 57, Train Loss: 0.014267995692878677, Val Loss: 0.016404178910958234\n",
"Epoch 58, Train Loss: 0.01399662052882773, Val Loss: 0.014956932640008962\n",
"Epoch 59, Train Loss: 0.013984658806607056, Val Loss: 0.01512009026343698\n",
"Epoch 60, Train Loss: 0.013917681792278608, Val Loss: 0.01516334629103319\n",
"Epoch 61, Train Loss: 0.013808810461811614, Val Loss: 0.015075811351746765\n",
"Epoch 62, Train Loss: 0.014042920544387051, Val Loss: 0.015152243647112776\n",
"Epoch 63, Train Loss: 0.0136711714971971, Val Loss: 0.014804388201837219\n",
"Epoch 64, Train Loss: 0.013782783121797457, Val Loss: 0.015533475858618074\n",
"Epoch 65, Train Loss: 0.013631306383669661, Val Loss: 0.014752479089396213\n",
"Epoch 66, Train Loss: 0.013644688259186357, Val Loss: 0.01469478735338841\n",
"Epoch 67, Train Loss: 0.013522711930056793, Val Loss: 0.014726998854372928\n",
"Epoch 68, Train Loss: 0.01350348583159692, Val Loss: 0.014617940202466588\n",
"Epoch 69, Train Loss: 0.013397794087644684, Val Loss: 0.014498871904033334\n",
"Epoch 70, Train Loss: 0.013320690925504888, Val Loss: 0.014324163573224153\n",
"Epoch 71, Train Loss: 0.013295841332008108, Val Loss: 0.014810262790033177\n",
"Epoch 72, Train Loss: 0.013151036726943614, Val Loss: 0.014535954208182754\n",
"Epoch 73, Train Loss: 0.01315474125409597, Val Loss: 0.014322022976937578\n",
"Epoch 74, Train Loss: 0.013201014497473337, Val Loss: 0.014625799591972757\n",
"Epoch 75, Train Loss: 0.013166735187155065, Val Loss: 0.01410402478511209\n",
"Epoch 76, Train Loss: 0.013011173492199496, Val Loss: 0.014279130234647153\n",
"Epoch 77, Train Loss: 0.012954122741131833, Val Loss: 0.015670507896079947\n",
"Epoch 78, Train Loss: 0.012964830874202497, Val Loss: 0.013965579806201493\n",
"Epoch 79, Train Loss: 0.01284469154765874, Val Loss: 0.014020084167149529\n",
"Epoch 80, Train Loss: 0.01269332727230194, Val Loss: 0.014467649356420361\n",
"Epoch 81, Train Loss: 0.012900225120779287, Val Loss: 0.014321781124975255\n",
"Epoch 82, Train Loss: 0.012758908171705795, Val Loss: 0.013745425046602292\n",
"Epoch 83, Train Loss: 0.01266205709418683, Val Loss: 0.013802579048075784\n",
"Epoch 84, Train Loss: 0.012549680232128315, Val Loss: 0.013783436657777473\n",
"Epoch 85, Train Loss: 0.012634162601689545, Val Loss: 0.01444499020867828\n",
"Epoch 86, Train Loss: 0.012543465024190086, Val Loss: 0.014219797327558495\n",
"Epoch 87, Train Loss: 0.012490486795234195, Val Loss: 0.013482047425610806\n",
"Epoch 88, Train Loss: 0.012537837625619327, Val Loss: 0.014496686354057113\n",
"Epoch 89, Train Loss: 0.012536356080786891, Val Loss: 0.013949389360956292\n",
"Epoch 90, Train Loss: 0.012426643302601776, Val Loss: 0.013645224328806152\n",
"Epoch 91, Train Loss: 0.012394862496806531, Val Loss: 0.013617335818707943\n",
"Epoch 92, Train Loss: 0.012383774110075959, Val Loss: 0.013630805342499889\n",
"Epoch 93, Train Loss: 0.012307288521749267, Val Loss: 0.013647960637932393\n",
"Epoch 94, Train Loss: 0.012298794681625218, Val Loss: 0.013733426678870151\n",
"Epoch 95, Train Loss: 0.012473734824263165, Val Loss: 0.013764488983398942\n",
"Epoch 96, Train Loss: 0.012222074678515276, Val Loss: 0.013446863671180918\n",
"Epoch 97, Train Loss: 0.012306330008120344, Val Loss: 0.013694896279319899\n",
"Epoch 98, Train Loss: 0.012166704374263019, Val Loss: 0.013338639831809855\n",
"Epoch 99, Train Loss: 0.012187617220447965, Val Loss: 0.01352898025913025\n",
"Epoch 100, Train Loss: 0.012234464256565252, Val Loss: 0.013427354033980796\n",
"Epoch 101, Train Loss: 0.012252488267122273, Val Loss: 0.013189904238861887\n",
"Epoch 102, Train Loss: 0.01208857831692225, Val Loss: 0.013358786896760785\n",
"Epoch 103, Train Loss: 0.012067412587693718, Val Loss: 0.013412703287356826\n",
"Epoch 104, Train Loss: 0.011943526178348863, Val Loss: 0.013329273687480991\n",
"Epoch 105, Train Loss: 0.012186939030457911, Val Loss: 0.013039200052396576\n",
"Epoch 106, Train Loss: 0.012064487648833739, Val Loss: 0.013328265718448518\n",
"Epoch 107, Train Loss: 0.01196315302624942, Val Loss: 0.013011285284561898\n",
"Epoch 108, Train Loss: 0.011942964125175082, Val Loss: 0.013228343076892753\n",
"Epoch 109, Train Loss: 0.011851983095862363, Val Loss: 0.012941466032791494\n",
"Epoch 110, Train Loss: 0.011892807039401035, Val Loss: 0.013264400856708413\n",
"Epoch 111, Train Loss: 0.011915889784747192, Val Loss: 0.01319889353115612\n",
"Epoch 112, Train Loss: 0.011905829402123484, Val Loss: 0.014149442662610047\n",
"Epoch 113, Train Loss: 0.011818570989455903, Val Loss: 0.013042371636673586\n",
"Epoch 114, Train Loss: 0.011752497955140743, Val Loss: 0.01301327784226012\n",
"Epoch 115, Train Loss: 0.011813209191606375, Val Loss: 0.01286677592225484\n",
"Epoch 116, Train Loss: 0.011725439075113198, Val Loss: 0.013167357391941904\n",
"Epoch 117, Train Loss: 0.011835235226721141, Val Loss: 0.01286814648157625\n",
"Epoch 118, Train Loss: 0.011680879099873835, Val Loss: 0.012708428107313256\n",
"Epoch 119, Train Loss: 0.01173722647959322, Val Loss: 0.012885383775096331\n",
"Epoch 120, Train Loss: 0.011672099965343777, Val Loss: 0.012913884747940214\n",
"Epoch 121, Train Loss: 0.011704605972866693, Val Loss: 0.012728425813143823\n",
"Epoch 122, Train Loss: 0.011705320578015021, Val Loss: 0.012817327530860012\n",
"Epoch 123, Train Loss: 0.011644495068492288, Val Loss: 0.012942980015789396\n",
"Epoch 124, Train Loss: 0.011633442955439171, Val Loss: 0.012936850551015405\n",
"Epoch 125, Train Loss: 0.011616052921558396, Val Loss: 0.012702107387803384\n",
"Epoch 126, Train Loss: 0.011607619160652588, Val Loss: 0.012658866025062639\n",
"Epoch 127, Train Loss: 0.011635440495310788, Val Loss: 0.01304104494681554\n",
"Epoch 128, Train Loss: 0.01150463111074775, Val Loss: 0.013212839975508291\n",
"Epoch 129, Train Loss: 0.011585681133293078, Val Loss: 0.01278914052492647\n",
"Epoch 130, Train Loss: 0.011392400087565896, Val Loss: 0.012796499154794572\n",
"Epoch 131, Train Loss: 0.011433751801358598, Val Loss: 0.012598757076063264\n",
"Epoch 132, Train Loss: 0.011496097840921303, Val Loss: 0.01271620902941743\n",
"Epoch 133, Train Loss: 0.011477598884815804, Val Loss: 0.013398304248034065\n",
"Epoch 134, Train Loss: 0.011365674946314552, Val Loss: 0.012668505741922713\n",
"Epoch 135, Train Loss: 0.01142354957696995, Val Loss: 0.013356663286685944\n",
"Epoch 136, Train Loss: 0.011355750374139497, Val Loss: 0.012617305616167054\n",
"Epoch 137, Train Loss: 0.011350866257877013, Val Loss: 0.012997348792850971\n",
"Epoch 138, Train Loss: 0.011416472670617715, Val Loss: 0.012524361819473665\n",
"Epoch 139, Train Loss: 0.011427981736646458, Val Loss: 0.012654973694415235\n",
"Epoch 140, Train Loss: 0.011318818902213607, Val Loss: 0.012664613897787102\n",
"Epoch 141, Train Loss: 0.011320005095247446, Val Loss: 0.012727182441905363\n",
"Epoch 142, Train Loss: 0.011245375826651827, Val Loss: 0.012474427931010723\n",
"Epoch 143, Train Loss: 0.011338526420919091, Val Loss: 0.012642348824597117\n",
"Epoch 144, Train Loss: 0.011243535689207497, Val Loss: 0.012692421772030752\n",
"Epoch 145, Train Loss: 0.011166462189023289, Val Loss: 0.01263011310861182\n",
"Epoch 146, Train Loss: 0.011227301243942178, Val Loss: 0.012461379587427894\n",
"Epoch 147, Train Loss: 0.01119774208364019, Val Loss: 0.012749987918494353\n",
"Epoch 148, Train Loss: 0.011138954723441001, Val Loss: 0.012676928915194612\n",
"Epoch 149, Train Loss: 0.011145075226122398, Val Loss: 0.012806226499378681\n",
"Epoch 150, Train Loss: 0.011238663441737731, Val Loss: 0.012608930385157244\n",
"Epoch 151, Train Loss: 0.01112103075430724, Val Loss: 0.012799791727604261\n",
"Epoch 152, Train Loss: 0.01109027168958595, Val Loss: 0.01240885794273953\n",
"Epoch 153, Train Loss: 0.011098397055721026, Val Loss: 0.012326594039019364\n",
"Epoch 154, Train Loss: 0.011026590389676356, Val Loss: 0.012310143629672811\n",
"Epoch 155, Train Loss: 0.011067607804339682, Val Loss: 0.01242478439278567\n",
"Epoch 156, Train Loss: 0.01105262930215332, Val Loss: 0.01238662200465576\n",
"Epoch 157, Train Loss: 0.010977347388097117, Val Loss: 0.012163419262575569\n",
"Epoch 158, Train Loss: 0.010957017552071924, Val Loss: 0.012397716572480415\n",
"Epoch 159, Train Loss: 0.010956506543396192, Val Loss: 0.012370292931350309\n",
"Epoch 160, Train Loss: 0.01093887382980133, Val Loss: 0.012291266110294791\n",
"Test Loss: 0.006885056002065539\n"
]
}
],
"source": [
"model = model.to(device)\n",
"\n",
"num_epochs = 160\n",
"train_losses = list()\n",
"val_losses = list()\n",
"for epoch in range(num_epochs):\n",
"    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
"    train_losses.append(train_loss)\n",
"    val_loss = evaluate(model, device, val_loader, criterion)\n",
"    val_losses.append(val_loss)\n",
"    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# Test the model\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f8d2cf4ebe0>"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABiAElEQVR4nO3dd3zU9eHH8dfN7EVCEgKBALKJ7O0mGtw4KiIVoVardVP5KdZVreKo1gGVat0VZ4VaVBQRECXsJTJlJYwkrOxxyd3398c3uRgJZOeS8H4+HvdI7nuf+97nE2ny7mdaDMMwEBEREWnGrL6ugIiIiEh1FFhERESk2VNgERERkWZPgUVERESaPQUWERERafYUWERERKTZU2ARERGRZk+BRURERJo9u68r0BA8Hg8HDhwgJCQEi8Xi6+qIiIhIDRiGQW5uLnFxcVitJ+9DaRWB5cCBA8THx/u6GiIiIlIHaWlpdOjQ4aRlWkVgCQkJAcwGh4aG+rg2IiIiUhM5OTnEx8d7/46fTKsILOXDQKGhoQosIiIiLUxNpnNo0q2IiIg0ewosIiIi0uwpsIiIiEiz1yrmsIiISOtjGAalpaW43W5fV0XqwWazYbfb673tiAKLiIg0Oy6Xi4MHD1JQUODrqkgDCAwMpF27djidzjrfQ4FFRESaFY/Hw+7du7HZbMTFxeF0OrUpaAtlGAYul4tDhw6xe/duunXrVu0GcSeiwCIiIs2Ky+XC4/EQHx9PYGCgr6sj9RQQEIDD4WDv3r24XC78/f3rdB9NuhURkWaprv9PXJqfhvhvqX8NIiIi0uwpsIiIiEizp8AiIiLSDCUkJPDCCy80yL0WL16MxWIhKyurQe7nC5p0KyIi0kDOOecc+vfv3yBBY9WqVQQFBdW/Uq2EAstJuEo9PD1/KyVuD3++uBd+dpuvqyQiIi2YYRi43W7s9ur//LZt27YJatRyaEjoJAwMXv9+N++k7KW41OPr6oiInLIMw6DAVeqTh2EYNarjpEmTWLJkCS+++CIWiwWLxcJbb72FxWLhyy+/ZNCgQfj5+fH999+zc+dOLr/8cmJiYggODmbIkCF88803le736yEhi8XCv/71L6644goCAwPp1q0bn332WZ1/pv/5z3/o06cPfn5+JCQk8Nxzz1V6/R//+AfdunXD39+fmJgYrr76au9rn3zyCYmJiQQEBBAZGUlSUhL5+fl1rktNqIflJBy/WIZV6q7ZP1gREWl4hSVuej/8lU8+e/NjyQQ6q/9z+eKLL7J9+3b69u3LY489BsBPP/0EwP3338/f/vY3unTpQkREBGlpaVx00UU88cQT+Pn58c4773DppZeybds2OnbseMLP+Mtf/sIzzzzDs88+y8svv8yECRPYu3cvbdq0qVWb1qxZwzXXXMOjjz7KuHHjWLZsGX/84x+JjIxk0qRJrF69mjvvvJN3332XkSNHcvToUZYuXQrAwYMHGT9+PM888wxXXHEFubm5LF26tMbBrq4UWE7CarVgt1oo9RiUuNXDIiIiJxYWFobT6SQwMJDY2FgAtm7dCsBjjz3G+eef7y3bpk0b+vXr533++OOPM2fOHD777DNuv/32E37GpEmTGD9+PABPPvkkL730EitXrmTMmDG1quvzzz/P6NGjeeihhwDo3r07mzdv5tlnn2XSpEmkpqYSFBTEJZdcQkhICJ06dWLAgAGAGVhKS0u58sor6dSpEwCJiYm1+vy6UGCpht1mBhaXhoRERHwmwGFj82PJPvvs+ho8eHCl53l5eTz66KN8/vnn3gBQWFhIamrqSe9z+umne78PCgoiNDSUzMzMWtdny5YtXH755ZWujRo1ihdeeAG32835559Pp06d6NKlC2PGjGHMmDHeoah+/foxevRoEhMTSU5O5oILLuDqq68mIiKi1vWoDc1hqYbDZv6I1MMiIuI7FouFQKfdJ4+GOMfo16t97r33XubMmcOTTz7J0qVLWb9+PYmJibhcrpPex+FwHPdz8Xga/u9TSEgIa9eu5f3336ddu3Y8/PDD9OvXj6ysLGw2GwsWLODLL7+kd+/evPzyy/To0YPdu3c3eD1+SYGlGs6ywFLq0RwWERE5OafTidvtrrbcDz/8wKRJk7jiiitITEwkNjaWPXv2NH4Fy/Tq1YsffvjhuDp1794dm83sUbLb7SQlJfHMM8+wceNG9uzZw7fffguYQWnUqFH85S9/Yd26dTidTubMmdOoddaQUDXsNjNZa0hIRESqk5CQwIoVK9izZw/BwcEn7P3o1q0bn376KZdeeikWi4WHHnqoUXpKTuRPf/oTQ4YM4fHHH2fcuHGkpKQwY8YM/vGPfwAwb948du3axVlnnUVERARffPEFHo+HHj16sGLFChYuXMgFF1xAdHQ0K1as4NChQ/Tq1atR66welmpoSEhERGrq3nvvxWaz0bt3b9q2bXvCOSnPP/88ERERjBw5kksvvZTk5GQGDhzYZPUcOHAgH330ER988AF9+/bl4Ycf5rHHHmPSpEkAhIeH8+mnn3LeeefRq1cvZs2axfvvv0+fPn0IDQ3lu+++46KLLqJ79+48+OCDPPfcc1x44YWNWmeL0djrkJpATk4OYWFhZGdnExoa2qD3Pu9vi9l1OJ+PbxnBkITaLRsTEZHaKyoqYvfu3XTu3Bl/f39fV0cawIn+m9bm77d6WKpRPiRUoiEhERERn1FgqUb5kJBLQ0IiItJM3XLLLQQHB1f5uOWWW3xdvQahSbfVKA8s2ulWRESaq8cee4x77723ytcaeqqEryiwVMNRPiSkHhYREWmmoqOjiY6O9nU1GpWGhKqhISERERHfU2CphoaEREREfE+BpRoaEhIREfE9BZZqaOM4ERER31NgqUZFYNGQkIiIiK8osFTDriEhERFpIgkJCbzwwgs1KmuxWJg7d26j1qc5UWCphlNDQiIiIj6nwFINDQmJiIj4ngJLNTQkJCLSDBgGuPJ986jhGcGvvvoqcXFxeDyV/15cfvnl/O53v2Pnzp1cfvnlxMTEEBwczJAhQ/jmm28a7Ef0448/ct555xEQEEBkZCQ333wzeXl53tcXL17M0KFDCQoKIjw8nFGjRrF3714ANmzYwLnnnktISAihoaEMGjSI1atXN1jdGoJ2uq2GhoRERJqBkgJ4Ms43n/3AAXAGVVvsN7/5DXfccQeLFi1i9OjRABw9epT58+fzxRdfkJeXx0UXXcQTTzyBn58f77zzDpdeeinbtm2jY8eO9apifn4+ycnJjBgxglWrVpGZmcnvf/97br/9dt566y1KS0sZO3YsN910E++//z4ul4uVK1disZj/p3zChAkMGDCAV155BZvNxvr163E4HPWqU0NTYKmGhoRERKQmIiIiuPDCC5k9e7Y3sHzyySdERUVx7rnnYrVa6devn7f8448/zpw5c/jss8+4/fbb6/XZs2fPpqioiHfeeYegIDNczZgxg0svvZSnn34ah8NBdnY2l1xyCV27dgWgV69e3venpqYydepUevbsCUC3bt3qVZ/GoMBSD
Q0JiYg0A45As6fDV59dQxMmTOCmm27iH//4B35+frz33ntce+21WK1W8vLyePTRR/n88885ePAgpaWlFBYWkpqaWu8qbtmyhX79+nnDCsCoUaPweDxs27aNs846i0mTJpGcnMz5559PUlIS11xzDe3atQNgypQp/P73v+fdd98lKSmJ3/zmN95g01zUaQ7LzJkzSUhIwN/fn2HDhrFy5cqTlv/444/p2bMn/v7+JCYm8sUXXxxXZsuWLVx22WWEhYURFBTEkCFDGuQ/Yn1p4zgRkWbAYjGHZXzxKBs2qYlLL70UwzD4/PPPSUtLY+nSpUyYMAGAe++9lzlz5vDkk0+ydOlS1q9fT2JiIi6Xq7F+apW8+eabpKSkMHLkSD788EO6d+/O8uXLAXj00Uf56aefuPjii/n222/p3bs3c+bMaZJ61VStA8uHH37IlClTeOSRR1i7di39+vUjOTmZzMzMKssvW7aM8ePHc+ONN7Ju3TrGjh3L2LFj2bRpk7fMzp07OeOMM+jZsyeLFy9m48a
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(train_losses[1:], label='train_loss')\n",
"plt.plot(val_losses[1:], label='val_loss')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "a8467686-0655-4056-8e01-56299eb89d7c",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "efc96935-bbe0-4ca9-b11a-931cdcfc3bed",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
"    # Means of the observed and predicted values\n",
"    mean_observed = np.mean(y_true)\n",
"    mean_predicted = np.mean(y_pred)\n",
"\n",
"    # Index of Agreement (IoA)\n",
"    numerator = np.sum((y_true - y_pred) ** 2)\n",
"    denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
"    IoA = 1 - (numerator / denominator)\n",
"\n",
"    return IoA"
]
},
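{
"cell_type": "code",
"execution_count": null,
"id": "2c4e6a8b-9a0b-4c1d-8e2f-7b9d1f3a5c62",
"metadata": {},
"outputs": [],
"source": [
"# Added example: the index of agreement is 1 for a perfect prediction and\n",
"# decreases as predictions deviate from the observations.\n",
"_obs = np.array([1.0, 2.0, 3.0])\n",
"print(cal_ioa(_obs, _obs))        # 1.0\n",
"print(cal_ioa(_obs, _obs + 0.5))  # 0.90625"
]
},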
{
"cell_type": "code",
"execution_count": 26,
"id": "dae7427e-548e-4276-a4ea-bc9b279d44e8",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
"    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1  # invert the mask to get the inpainted region\n",
"        reconstructed = model(X)\n",
"        rev_data = y * max_pixel_value\n",
"        rev_recon = reconstructed * max_pixel_value\n",
"        # TODO: evaluate only the inpainted region here\n",
"        data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
"        data_label = data_label[mask_rev == 1]\n",
"        recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
"        recon_no2 = recon_no2[mask_rev == 1]\n",
"        mae = mean_absolute_error(data_label, recon_no2)\n",
"        rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
"        mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
"        r2 = r2_score(data_label, recon_no2)\n",
"        ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
"        r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
"        eva_list.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "73a0002b-35d6-4e20-a620-5c8f5cd49296",
"metadata": {},
"outputs": [],
"source": [
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"best_mape = 1\n",
"best_img = None\n",
"best_mask = None\n",
"best_recov = None\n",
"with torch.no_grad():\n",
"    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1  # invert the mask to get the inpainted region\n",
"        reconstructed = model(X)\n",
"        rev_data = y * max_pixel_value\n",
"        rev_recon = reconstructed * max_pixel_value\n",
"        # TODO: evaluate only the inpainted region, per frame\n",
"        for i, sample in enumerate(rev_data):\n",
"            used_mask = mask_rev[i]\n",
"            data_label = sample[0] * used_mask\n",
"            recon_no2 = rev_recon[i][0] * used_mask\n",
"            data_label = data_label[used_mask == 1]\n",
"            recon_no2 = recon_no2[used_mask == 1]\n",
"            mae = mean_absolute_error(data_label, recon_no2)\n",
"            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
"            mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
"            r2 = r2_score(data_label, recon_no2)\n",
"            ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
"            r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
"            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])\n",
"            if mape < best_mape:\n",
"                best_recov = rev_recon[i][0].numpy()\n",
"                best_mask = used_mask.numpy()\n",
"                best_img = sample[0].numpy()\n",
"                best_mape = mape"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "b7265cd0-0660-4707-be3d-0773a38228e8",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>mae</th>\n",
"      <th>rmse</th>\n",
"      <th>mape</th>\n",
"      <th>r2</th>\n",
"      <th>ioa</th>\n",
"      <th>r</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>count</th>\n",
"      <td>4739.000000</td>\n",
"      <td>4739.000000</td>\n",
"      <td>4739.000000</td>\n",
"      <td>4739.000000</td>\n",
"      <td>4739.000000</td>\n",
"      <td>4739.000000</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>mean</th>\n",
"      <td>1.261634</td>\n",
"      <td>1.801726</td>\n",
"      <td>0.153962</td>\n",
"      <td>0.681159</td>\n",
"      <td>0.891040</td>\n",
"      <td>0.840609</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>std</th>\n",
"      <td>0.572205</td>\n",
"      <td>0.861009</td>\n",
"      <td>0.065723</td>\n",
"      <td>0.249771</td>\n",
"      <td>0.110411</td>\n",
"      <td>0.124012</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>min</th>\n",
"      <td>0.361480</td>\n",
"      <td>0.468918</td>\n",
"      <td>0.047540</td>\n",
"      <td>-2.107971</td>\n",
"      <td>-0.424296</td>\n",
"      <td>-0.070884</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>25%</th>\n",
"      <td>0.828453</td>\n",
"      <td>1.149391</td>\n",
"      <td>0.111256</td>\n",
"      <td>0.600440</td>\n",
"      <td>0.868937</td>\n",
"      <td>0.797875</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>50%</th>\n",
"      <td>1.135805</td>\n",
"      <td>1.621294</td>\n",
"      <td>0.143929</td>\n",
"      <td>0.740937</td>\n",
"      <td>0.922953</td>\n",
"      <td>0.872734</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>75%</th>\n",
"      <td>1.557381</td>\n",
"      <td>2.250718</td>\n",
"      <td>0.179544</td>\n",
"      <td>0.835907</td>\n",
"      <td>0.953556</td>\n",
"      <td>0.921983</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>max</th>\n",
"      <td>5.733449</td>\n",
"      <td>8.356097</td>\n",
"      <td>1.116946</td>\n",
"      <td>0.985570</td>\n",
"      <td>0.996237</td>\n",
"      <td>0.993398</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"               mae         rmse         mape           r2          ioa  \\\n",
"count  4739.000000  4739.000000  4739.000000  4739.000000  4739.000000   \n",
"mean      1.261634     1.801726     0.153962     0.681159     0.891040   \n",
"std       0.572205     0.861009     0.065723     0.249771     0.110411   \n",
"min       0.361480     0.468918     0.047540    -2.107971    -0.424296   \n",
"25%       0.828453     1.149391     0.111256     0.600440     0.868937   \n",
"50%       1.135805     1.621294     0.143929     0.740937     0.922953   \n",
"75%       1.557381     2.250718     0.179544     0.835907     0.953556   \n",
"max       5.733449     8.356097     1.116946     0.985570     0.996237   \n",
"\n",
"                 r  \n",
"count  4739.000000  \n",
"mean      0.840609  \n",
"std       0.124012  \n",
"min      -0.070884  \n",
"25%       0.797875  \n",
"50%       0.872734  \n",
"75%       0.921983  \n",
"max       0.993398  "
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "589e6d80-228d-4e8a-968a-e7477c5e0e24",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>mae</th>\n",
"      <th>rmse</th>\n",
"      <th>mape</th>\n",
"      <th>r2</th>\n",
"      <th>ioa</th>\n",
"      <th>r</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>count</th>\n",
"      <td>75.000000</td>\n",
"      <td>75.000000</td>\n",
"      <td>75.000000</td>\n",
"      <td>75.000000</td>\n",
"      <td>75.000000</td>\n",
"      <td>75.000000</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>mean</th>\n",
"      <td>1.263991</td>\n",
"      <td>1.987788</td>\n",
"      <td>0.153931</td>\n",
"      <td>0.907729</td>\n",
"      <td>0.974785</td>\n",
"      <td>0.953238</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>std</th>\n",
"      <td>0.108035</td>\n",
"      <td>0.209185</td>\n",
"      <td>0.007592</td>\n",
"      <td>0.017280</td>\n",
"      <td>0.005782</td>\n",
"      <td>0.007909</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>min</th>\n",
"      <td>1.077143</td>\n",
"      <td>1.658797</td>\n",
"      <td>0.135271</td>\n",
"      <td>0.791607</td>\n",
"      <td>0.933031</td>\n",
"      <td>0.905484</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>25%</th>\n",
"      <td>1.208991</td>\n",
"      <td>1.892923</td>\n",
"      <td>0.149006</td>\n",
"      <td>0.901544</td>\n",
"      <td>0.972912</td>\n",
"      <td>0.950092</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>50%</th>\n",
"      <td>1.255151</td>\n",
"      <td>1.967265</td>\n",
"      <td>0.153929</td>\n",
"      <td>0.908183</td>\n",
"      <td>0.974939</td>\n",
"      <td>0.953771</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>75%</th>\n",
"      <td>1.307615</td>\n",
"      <td>2.079039</td>\n",
"      <td>0.158463</td>\n",
"      <td>0.915666</td>\n",
"      <td>0.977269</td>\n",
"      <td>0.957454</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>max</th>\n",
"      <td>1.956845</td>\n",
"      <td>3.320712</td>\n",
"      <td>0.175028</td>\n",
"      <td>0.931715</td>\n",
"      <td>0.981832</td>\n",
"      <td>0.965467</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"             mae       rmse       mape         r2        ioa          r\n",
"count  75.000000  75.000000  75.000000  75.000000  75.000000  75.000000\n",
"mean    1.263991   1.987788   0.153931   0.907729   0.974785   0.953238\n",
"std     0.108035   0.209185   0.007592   0.017280   0.005782   0.007909\n",
"min     1.077143   1.658797   0.135271   0.791607   0.933031   0.905484\n",
"25%     1.208991   1.892923   0.149006   0.901544   0.972912   0.950092\n",
"50%     1.255151   1.967265   0.153929   0.908183   0.974939   0.953771\n",
"75%     1.307615   2.079039   0.158463   0.915666   0.977269   0.957454\n",
"max     1.956845   3.320712   0.175028   0.931715   0.981832   0.965467"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "755abc3e-f4d2-4056-b01b-3fb085f95f19",
"metadata": {},
"outputs": [],
"source": [
"# torch.save(model, './models/MAE/final_20.pt')"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "782ba792-af34-479d-8b79-f6c544137539",
"metadata": {},
"outputs": [],
"source": [
"model_20 = torch.load('./models/MAE/final_20.pt')"
]
},
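{
"cell_type": "code",
"execution_count": null,
"id": "4e6a8c0d-b1c2-4d3e-9f40-8c0e2f4a6b74",
"metadata": {},
"outputs": [],
"source": [
"# Added note: torch.load on a whole-model checkpoint unpickles the module, so\n",
"# the class definitions above must exist in this session. Switching to eval\n",
"# mode is the usual next step before inference.\n",
"model_20.eval();"
]
},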
{
"cell_type": "code",
"execution_count": 38,
"id": "76449691-74b2-43ef-b092-f71cd8116448",
"metadata": {},
"outputs": [],
"source": [
"# Visualize the input, mask, recovered region, and composited result\n",
"def visualize_rst(input_feature, masked_feature, recov_region, output_feature, title):\n",
"    plt.figure(figsize=(12, 6))\n",
"    plt.subplot(1, 4, 1)\n",
"    plt.imshow(input_feature, cmap='RdYlGn_r')\n",
"    plt.gca().axis('off')  # hide the axes\n",
"\n",
"    plt.subplot(1, 4, 2)\n",
"    plt.imshow(masked_feature, cmap='gray')\n",
"    plt.gca().axis('off')\n",
"    plt.subplot(1, 4, 3)\n",
"    plt.imshow(recov_region, cmap='RdYlGn_r')\n",
"    plt.gca().axis('off')\n",
"    plt.subplot(1, 4, 4)\n",
"    plt.imshow(output_feature, cmap='RdYlGn_r')\n",
"    plt.gca().axis('off')\n",
"    plt.savefig('./figures/result/20_samples.png', bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "82467932-3b38-4d2d-83d9-8d76c4f98a06",
"metadata": {},
"outputs": [],
"source": [
"best_mask_cp = np.where(best_mask == 0, np.nan, best_mask)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "6bb568d1-07bd-49c4-9056-9ad2f2dd36a8",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7YAAADeCAYAAAAJtZwyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9d5wkV3U2jj/33qrqnt7ZoF2FVU5IAgWMMCLIAkwwOX2x4AMvYGQbgUx6sZABY4zhBVsYi/BifkiAsQFbOMBrDAgQyQgJyULIsiRASEKrjNIqbZrprnDv749zzw1V1TM9uzM7YfvsZz/TXV3hVjh1T3jOc4QxxmAsYxnLWMYylrGMZSxjGctYxjKWZSpysQcwlrGMZSxjGctYxjKWsYxlLGMZy67I2LEdy1jGMpaxjGUsYxnLWMYylrEsaxk7tmMZy1jGMpaxjGUsYxnLWMYylmUtY8d2LGMZy1jGMpaxjGUsYxnLWMayrGXs2I5lLGMZy1jGMpaxjGUsYxnLWJa1jB3bsYxlLGMZy1jGMpaxjGUsYxnLspaxYzuWsYxlLGMZy1jGMpaxjGUsY1nWMnZsxzKWsYxlLGMZy1jGMpaxjGUsy1rGju1YxjKWsYxlLGMZy1jGMpaxjGVZSzLqitN//nxACqj9epC9FGJ1BpEoYLIHZCmQKPqfpUBvApACkBJCSkBIQGXAxBra2WA7UJV+50bTn/400B8AZQXkBf2f6gN5AbNjAKO120RI65MrAVTG/cbLRUcBSeLGgUQB3Q6gNaANUJa0f/vZaAOzLYcpNExRwRQaKDSM1jA7SphBCdOvUA1KVIMK+fYCptIw2kCmCtlkCiEFZCohpIDqpUAqIXspkEmIVEGkEqKTQKxKICdSGk+3A/S6NL5el8ZqxyuyHqASunZSAjKh79L+D6+fLu1f7bdRwXr9rfS7kHwB6a8u6V7w3+B+ONElrZ92aX9C0r67a4AkQ7+aQq77bvVKl6hMiVz3MV1up10Y2nepK5S6wtZ8GqXWKE0FbQy0MQCAvCpRmgoSAlIMj7tIIdzfTNE5SvAy2m5dp4dV6QREsB8JCQ37vBmNypTQxmCqGCDXFe6f3oZclxhUJbQBCg1UxqCo4MZY2eWDytjjxWP708d9eui4l4IIIWZfaZFlYmIC55xzDl70ohdh3bp1mJycbKzzrW99C2eeeSZOOOEEnHvuudh7772j3x966CG86lWvwre//e3dNewFk+OOOw6f+cxncMQRR2DDhg1I03SxhzSWscxZqv/3+wAAEczL8vnx+1J/8w1kA2gNU2qol31+EUa6e8QMLqD5lP9be0lDB/MT/TVGN5YbO1dXpoSGRqVLGPuX1zXQ0HY9nsNoztVu7tVGI9cVtNHoVwW0Mcir0q0vgzmDP9f3BQC5pm0KbVAZYFDa7e0UnEoBJXg//jpoQ/OqNkC/pDFNlcbNtVqboddQ2h1pbfx8bU0YPtaHT/nsnO/N7pTlYl/LiZTs6kRZm5RvprWxAbqJgP/N6jLb2tW926C35cvfvpaSvvP504Xzf4fotA50knUZgLNLK03f6zrPf+kS23ta2yZ+P3gdZr1s02OWuj6X2us0QLY770MKGdngbHtrhPswbhs6vkahNbRp2sx1Cd8HRRW/Q04/7tyZN8YcHFuxOoVIFSldN4HoZqRkkz3/4CQJRJYCSdc/KOwECQkkmd3ZGnpIij49MOyUZimMFKQIWUqKJwXQVxBlCQEFE77g+OwRKGIoPHlK6ZVRC694ZQVA07i1hskkhAquuBJAIWBU5Y7jdp1IGCmgiyq+TlJApgpIJYSyY6oMIA2MNBDawPQrGCkhEvuC4RdCWQHSANIqy0RNwdx5WyVjRTM6/l5JQJV03TuTPLDYsQX8+o3rZo+pErtOQteQlVllfj+BE2yCz0okkPYfK6wOfk8knXNeltHEKIVAJhKnZG5ILc5YqIheyWKFE4LGIISEFBIC0r1UKpSAewdLSKGRSInSCCghGmPgRyB8b7Pzq5a+r7isZDAY4GMf+xi+9KUv4W1vextOPfXUxjpPfOIT8cUvfhGTk5NYu3btIoxy98ktt9yCN73pTTj88MNx9tln45hjjlnsIY1lLHMW59ACzmZoSF7AlJoM6tr8uuKkKp3hC5kAGoCSzgBmp5aNXDaC2XEFEBm97NRWpnTzXOi8Amh8LwNjudQ6Mkj9Nj5gXJ8X3TpoOp+pEqiMgQrm78qMNl/SPG5QhJcrOATvgw1fMgcNtPGfw22XsiwX+9qUGtAF6XFmg6vs4CbKPiigY5ZWd93fEigrCCkgUrky7OusZ8dQevs69EXYuQWcTrc5rOHvpqbXbv3gN3aKw20BtOi812HabniAyNvTsU8grS1M+q9bE05tuj+TzBCniuzsyhj3vShHP8bIjq3spfTwZhJIraOYpf5/oiCSIFPIL2qVBArHUSB72DK364EeGJlAJPCXSGvad1nR8bSG4BvJbr+NJkWOrRLxw+VOgrOUlXd6+TtIefnBEUrC8Ho1EcEyoaT7LqSAVBKNkGRlgBSANjCVhkgljVlbReH/fM4IIkF10ZrGzUpmoylRtrXSTccz/MtObnxSsdPLk62QgNDxCzSIQrGY+v6AKFNajxhHpxRsGzqwbc7tTIrZpnDGnm/bb7KGxHcOscv88i0UKBBOns2xjTpZL2WZmJhAlmWN5UVRYGpqareORWuNm266CbfccgtuvvlmbNmyxf2WZRkmJiawevVqHHPMMVBKIQkMZGMMduzYgS1btqAolotpM7NMTU3h6quvxsMPP4wdO3Ys9nDGMpadEvn//f2s6xj7klWv+MJuGNEii7Hzvfsc/IQ4aMyOLoBWpxaguVRDtzq1DQc3MEZD1FQ4t2kYn41pmePD9eqibOBXDUEIDTNuebu61HyfhgHsP8ffl4PM2b4OndrdaV/XbwIvd39rNy+ybY13sleKfc3Cjq+w+1ZNe7MuzhmtZWH5swtW1bK8vE2o96HDS3+Ne19oNB1cKeSMtnRdYud2/oQfJxXcSn58KmOCR2duijyyY5scvJZueDcjJeCIUpb6rCgQO7Tus1U8VjgHg7UPTtEHUAL1Z42hF3RmPgqkNUReADp4ndorIxLZ3LYOmZD24eYMLujw9EaFVQwDVBoC9B3dBEAJURAcIptMo+gWQyRERwFKQKTK3y0lSNHsBMaKasqKXhjaeFhIEAkz01shOCpkgmtnrFMZwohZAfnaFjq+DxxRArwCOgfVZmb5BcnKHWZmwwihXS+cXIWQjagSAGSqi0LnyM0gmmAzmUAKgbwqEfqX7FyWWteixrFC+X357K+20aQQ1pzIAkZqVJWPYFN21sKXbRQqkRK64t981reoTdj+nStcNGlQmkZUerlJkiR4/etfj2c+85mN3/7zP/8Tn/zkJxfFSdRa44tf/CIuueQSt+y5z30u3vCGN+CnP/0pPvaxj+ERj3gE3vnOd2LdunUAgPvvvx9nn302fvGLX+Caa67Z7WMey1jGsmtihhmeK03CeVs1zTE2bMPvkSMbZnBq2
R3KvtacWxgHMXRDaDOGhzjAfpvZDc3KtPg5Lain8LcwQ9MmodEbb9e+/+Uis9nXzqlNu/SfZbHsaxkcwyEwEuswsr1WBeMwMHnh4W2pXFn2Nf/u9lHTMbaVgyxtCE1mCSHKbr3Aqa3rPyM46BaGOmF1GKZVV0fP4vp3Qz2x1IaYHGWfNG5eL/4bOrO83lydWmAOji3WWEhrZmGp3Y7H/fPbiR2ltsxe6Bjxw5Fk5FCGeHSXyrdKkSR01lniIQ1aWCWEXw+Is7Chs83fo8+BQqIWOVKCYpQcVWJFLHzkSNQiTVJZ51kJ91fw90BEGHHilwlHkcrSw6a1JkhWoiHYwdQ69u9CpatsHSxHioxGo242hFO4SbR27Y0GUNK9CiOBQcaW63tY6pMrfw4hwG7doGYngUIiVVQ/w7W1WjQVqS4cTSZnVbX+XtpsvBSk6OTsKnRqEbW6Q0sO8mgwuEI3I93LTaSUeMxjHoMXvOAFjd8efvhhTExMuNpcY8xuc3KNMfjFL36BX/ziF27Z/vvvj+n
"text/plain": [
"<Figure size 1200x600 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"visualize_rst(best_img, best_mask, best_recov * best_mask_cp, best_img * (1 - best_mask) + best_recov * best_mask, '')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "e711bcef-0263-4948-924e-1beb6d38fbf7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'1114', '1952', '2568', '3523', '602'}"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])\n",
"find_ex"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "addd6ce4-a62d-43b6-a435-d7853ccea91e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for j in find_ex:\n",
"    ori = np.load(f'./test_img/{j}-real.npy')[0]\n",
"    mask = np.load(f'./test_img/{j}-mask.npy')\n",
"    mask_rev = 1 - mask\n",
"    img_in = ori * mask_rev / max_pixel_value\n",
"    img_out = model(torch.tensor(img_in.reshape(1, 1, 96, 96), dtype=torch.float32)).detach().cpu().numpy()[0][0] * max_pixel_value\n",
"    out = ori * mask_rev + img_out * mask\n",
"    plt.imshow(out, cmap='RdYlGn_r')\n",
"    plt.gca().axis('off')\n",
"    plt.savefig(f'./test_img/out_fig/{j}-mae_my_out.png', bbox_inches='tight')\n",
"    plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d51cfc0-3afd-499e-ae97-76f07b0105e7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}