1202 lines
77 KiB
Plaintext
1202 lines
77 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"import torch\n",
|
|||
|
"import torch.nn as nn\n",
|
|||
|
"import torch.nn.functional as F\n",
|
|||
|
"import torch.optim as optim\n",
|
|||
|
"from torch.utils.data import DataLoader, Dataset, random_split\n",
|
|||
|
"from PIL import Image\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import cv2"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"id": "c28cc123-71be-47ff-b78f-3a4d5592df39",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Maximum pixel value in the dataset: 107.49169921875\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Normalisation constant: NO2Dataset divides the raw sample values by it.\n",
"max_pixel_value = float(107.49169921875)\n",
"print(\"Maximum pixel value in the dataset: {}\".format(max_pixel_value))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"id": "dbfe80ce-4394-449c-a9a4-22ed15b2b8f2",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"checkpoint before Generator is OK\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class NO2Dataset(Dataset):\n",
"    \"\"\"NO2 samples paired with a randomly chosen binary mask.\n",
"\n",
"    Samples are the ``.npy`` files in ``image_dir``; masks are the ``.jpg``\n",
"    files in ``<mask_dir>/<rate>`` for every ``rate`` in ``mask_rates``.\n",
"\n",
"    Each item is ``(X, y, mask, rate)``: ``X`` is channel 0 with the pixels\n",
"    where the mask is 0 zeroed out, ``y`` is the untouched channel 0, ``mask``\n",
"    is the binarised mask (all three are ``(1, H, W)`` float32 tensors) and\n",
"    ``rate`` is the mask sub-directory name as a string (e.g. ``'10'``).\n",
"\n",
"    Args:\n",
"        image_dir: directory containing the ``.npy`` samples.\n",
"        mask_dir: root directory of the per-rate mask sub-directories\n",
"            (a trailing slash is optional).\n",
"        mask_rates: names of the mask sub-directories to draw masks from.\n",
"        max_value: divisor used to normalise the samples; defaults to the\n",
"            notebook-level ``max_pixel_value`` (read once, at construction).\n",
"        seed: if given, the mask for sample ``idx`` is derived from\n",
"            ``(seed, idx)`` so repeated passes (e.g. validation) see the same\n",
"            masks; if ``None`` a fresh random mask is drawn on every access.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, image_dir, mask_dir, mask_rates=(10, 20, 30, 40), max_value=None, seed=None):\n",
"        self.image_dir = image_dir\n",
"        self.mask_dir = mask_dir\n",
"        # Only .npy files are samples; sort so index -> file is stable\n",
"        # (os.listdir order is arbitrary).\n",
"        self.image_filenames = sorted(f for f in os.listdir(image_dir) if f.endswith('.npy'))\n",
"        self.max_value = max_pixel_value if max_value is None else max_value\n",
"        self.seed = seed\n",
"        # Parallel lists: mask path and the rate it belongs to. Recording the\n",
"        # rate here avoids parsing it back out of the path by position, which\n",
"        # only works for one particular directory depth.\n",
"        self.mask_filenames = []\n",
"        self.mask_file_rates = []\n",
"        for rate in mask_rates:\n",
"            rate_dir = os.path.join(mask_dir, str(rate))\n",
"            for name in sorted(os.listdir(rate_dir)):\n",
"                if name.endswith('.jpg'):\n",
"                    self.mask_filenames.append(os.path.join(rate_dir, name))\n",
"                    self.mask_file_rates.append(str(rate))\n",
"\n",
"    def __len__(self):\n",
"        return len(self.image_filenames)\n",
"\n",
"    def _pick_mask(self, idx):\n",
"        \"\"\"Index into ``self.mask_filenames`` of the mask to use for sample ``idx``.\"\"\"\n",
"        if self.seed is None:\n",
"            return int(np.random.randint(len(self.mask_filenames)))\n",
"        # Normalise negative indices so the seed sequence stays non-negative.\n",
"        key = int(idx) % len(self.image_filenames)\n",
"        return int(np.random.default_rng([self.seed, key]).integers(len(self.mask_filenames)))\n",
"\n",
"    def __getitem__(self, idx):\n",
"        image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
"        k = self._pick_mask(idx)\n",
"        mask_path = self.mask_filenames[k]\n",
"        select_rate = self.mask_file_rates[k]\n",
"\n",
"        # Keep only channel 0 (NO2) and normalise -> (H, W, 1).\n",
"        image = np.load(image_path).astype(np.float32)[:, :, :1] / self.max_value\n",
"\n",
"        # Load the mask as grayscale and binarise: any non-zero pixel -> 1.0.\n",
"        # NOTE(review): JPEG compression can leave small non-zero values in\n",
"        # regions meant to be 0; a threshold (e.g. > 127) may be more faithful.\n",
"        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"        mask = np.where(mask != 0, 1.0, 0.0)[:, :, np.newaxis]  # (H, W, 1)\n",
"\n",
"        # Model input: the image with the pixels where mask == 0 zeroed out.\n",
"        masked_image = image.copy()\n",
"        masked_image[:, :, 0] = image[:, :, 0] * mask[:, :, 0]\n",
"\n",
"        X = masked_image[:, :, :1]  # (H, W, 1)\n",
"        y = image[:, :, 0:1]        # (H, W, 1), reconstruction target\n",
"\n",
"        # (H, W, C) -> (C, H, W)\n",
"        X = np.transpose(X, (2, 0, 1))\n",
"        y = np.transpose(y, (2, 0, 1))\n",
"        mask = np.transpose(mask, (2, 0, 1))\n",
"\n",
"        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32), select_rate\n",
|
|||
|
"\n",
|
|||
|
"# Data locations; the datasets and loaders are built in the next cell.\n",
"mask_dir = './out_mat/96/mask/'\n",
"image_dir = './out_mat/96/train/'\n",
"\n",
"print('checkpoint before Generator is OK')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"id": "41da7319-9795-441d-bde8-8cf390365099",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Training data is reshuffled every epoch; validation and test keep index order.\n",
"dataset = NO2Dataset(image_dir, mask_dir)\n",
"dataloader = DataLoader(dataset, batch_size=64, num_workers=8, shuffle=True)\n",
"\n",
"val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n",
"val_loader = DataLoader(val_set, batch_size=64, num_workers=4, shuffle=False)\n",
"\n",
"test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"test_loader = DataLoader(test_set, batch_size=64, num_workers=4, shuffle=False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"id": "70797703-1619-4be7-b965-5506b3d1e775",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Plot ground truth, masked input and reconstruction side by side.\n",
"def visualize_feature(input_feature, masked_feature, output_feature, title, cmap='RdYlGn_r', vmin=None, vmax=None):\n",
"    \"\"\"Show three 2-D maps in one row on a shared colour scale.\n",
"\n",
"    Each tensor argument is indexed with ``[0]`` to get the map to draw\n",
"    (e.g. a ``(1, H, W)`` sample). All panels share ``vmin``/``vmax`` so that\n",
"    a colour means the same value in the input and in the recovery; with\n",
"    per-panel auto-scaling the panels are not visually comparable.\n",
"\n",
"    Args:\n",
"        input_feature: ground-truth tensor.\n",
"        masked_feature: masked tensor given to the model.\n",
"        output_feature: model reconstruction (may still require grad).\n",
"        title: prefix for the three panel titles.\n",
"        cmap: matplotlib colormap name.\n",
"        vmin, vmax: colour limits; default to the min / max of ``input_feature[0]``.\n",
"    \"\"\"\n",
"    # detach() so tensors that track gradients can be converted to NumPy.\n",
"    panels = [t[0].detach().cpu().numpy() for t in (input_feature, masked_feature, output_feature)]\n",
"    if vmin is None:\n",
"        vmin = float(panels[0].min())\n",
"    if vmax is None:\n",
"        vmax = float(panels[0].max())\n",
"\n",
"    fig, axes = plt.subplots(1, 3, figsize=(12, 6))\n",
"    for ax, img, suffix in zip(axes, panels, (' Input', ' Masked', ' Recovery')):\n",
"        ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)\n",
"        ax.set_title(title + suffix)\n",
"    plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"id": "645114e8-65a4-4867-b3fe-23395288e855",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class Conv(nn.Sequential):\n",
"    \"\"\"A single Conv2d (bias-free by default) wrapped in nn.Sequential.\n",
"\n",
"    Padding is ``(dilation * (kernel_size - 1) + stride - 1) // 2``: with\n",
"    stride 1 and an odd kernel the spatial size is preserved; with\n",
"    kernel_size=3, stride=2 an even size is halved (e.g. 96 -> 48).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
"        pad = (dilation * (kernel_size - 1) + stride - 1) // 2\n",
"        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,\n",
"                         dilation=dilation, padding=pad, bias=bias)\n",
"        super().__init__(conv)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class ConvBNReLU(nn.Sequential):\n",
"    \"\"\"Conv2d -> norm_layer(out_channels) -> ReLU.\n",
"\n",
"    Uses the padding ``(dilation * (kernel_size - 1) + stride - 1) // 2``,\n",
"    which preserves H and W for stride 1 with an odd kernel.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
"                 bias=False):\n",
"        pad = (dilation * (kernel_size - 1) + stride - 1) // 2\n",
"        layers = [\n",
"            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,\n",
"                      dilation=dilation, padding=pad, bias=bias),\n",
"            norm_layer(out_channels),\n",
"            nn.ReLU(),\n",
"        ]\n",
"        super().__init__(*layers)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"id": "31ecf247-e98b-4977-a145-782914a042bd",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class SeparableBNReLU(nn.Sequential):\n",
"    \"\"\"Depthwise conv -> norm -> 1x1 pointwise conv -> ReLU6 (no conv biases).\n",
"\n",
"    The depthwise conv (``groups=in_channels``) mixes spatial information only\n",
"    and is normalised over ``in_channels``; the pointwise conv then maps\n",
"    ``in_channels`` to ``out_channels``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
"        pad = (dilation * (kernel_size - 1) + stride - 1) // 2\n",
"        depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
"                              padding=pad, groups=in_channels, bias=False)\n",
"        norm = norm_layer(in_channels)\n",
"        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)\n",
"        super().__init__(depthwise, norm, pointwise, nn.ReLU6())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class ResidualBlock(nn.Module):\n",
"    \"\"\"Basic ResNet block: two 3x3 conv + BatchNorm layers plus a skip connection.\n",
"\n",
"    ``stride`` is applied by the first conv. If the skip path has to change\n",
"    shape (``in_channels != out_channels`` or ``stride != 1``) and no\n",
"    ``downsample`` module is given, a 1x1 conv + BatchNorm projection is\n",
"    created; a caller-supplied ``downsample`` is always used as given.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
"        super(ResidualBlock, self).__init__()\n",
"        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
"        self.bn1 = nn.BatchNorm2d(out_channels)\n",
"        self.relu = nn.ReLU(inplace=True)\n",
"        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
"        self.bn2 = nn.BatchNorm2d(out_channels)\n",
"\n",
"        # Only build the default projection when the caller did not supply one;\n",
"        # otherwise a user-provided downsample would be silently discarded.\n",
"        self.downsample = downsample\n",
"        if downsample is None and (in_channels != out_channels or stride != 1):\n",
"            self.downsample = nn.Sequential(\n",
"                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
"                nn.BatchNorm2d(out_channels)\n",
"            )\n",
"\n",
"    def forward(self, x):\n",
"        identity = x if self.downsample is None else self.downsample(x)\n",
"        out = self.relu(self.bn1(self.conv1(x)))\n",
"        out = self.bn2(self.conv2(out))\n",
"        out += identity\n",
"        return self.relu(out)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class Mlp(nn.Module):\n",
"    \"\"\"Per-pixel two-layer MLP implemented with 1x1 convolutions.\n",
"\n",
"    fc1 (in -> hidden) -> act -> dropout -> fc2 (hidden -> out) -> dropout.\n",
"    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when\n",
"    not given (or falsy).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
"        super().__init__()\n",
"        hidden_dim = hidden_features if hidden_features else in_features\n",
"        out_dim = out_features if out_features else in_features\n",
"        self.fc1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, stride=1, padding=0, bias=True)\n",
"        self.act = act_layer()\n",
"        self.fc2 = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1, padding=0, bias=True)\n",
"        # A single in-place Dropout module is reused after both layers.\n",
"        self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
"    def forward(self, x):\n",
"        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):\n",
"            x = stage(x)\n",
"        return x"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class MultiHeadAttentionBlock(nn.Module):\n",
"    \"\"\"Self-attention across all spatial positions of a (B, C, H, W) feature map.\n",
"\n",
"    Each of the H*W positions is a token of size ``embed_dim`` (must equal C).\n",
"    The attention output is layer-normalised over channels, passed through\n",
"    dropout and reshaped back to (B, C, H, W).\n",
"\n",
"    Args:\n",
"        embed_dim: channel count C of the input.\n",
"        num_heads: number of attention heads (must divide ``embed_dim``).\n",
"        dropout: dropout probability inside the attention and on its output.\n",
"        residual: if True, add the input tokens to the attention output before\n",
"            the LayerNorm (post-norm transformer style). The default False keeps\n",
"            the attention-only behaviour.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, embed_dim, num_heads, dropout=0.1, residual=False):\n",
"        super(MultiHeadAttentionBlock, self).__init__()\n",
"        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
"        self.norm = nn.LayerNorm(embed_dim)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"        self.residual = residual\n",
"\n",
"    def forward(self, x):\n",
"        B, C, H, W = x.shape\n",
"        # (B, C, H, W) -> (HW, B, C): the sequence-first layout nn.MultiheadAttention\n",
"        # expects by default (batch_first=False). flatten/reshape, unlike view,\n",
"        # also accept non-contiguous inputs.\n",
"        tokens = x.flatten(2).permute(2, 0, 1)\n",
"        attn_output, _ = self.attention(tokens, tokens, tokens)\n",
"        if self.residual:\n",
"            attn_output = attn_output + tokens\n",
"        attn_output = self.norm(attn_output)\n",
"        attn_output = self.dropout(attn_output)\n",
"        # (HW, B, C) -> (B, C, HW) -> (B, C, H, W)\n",
"        return attn_output.permute(1, 2, 0).reshape(B, C, H, W)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class SpatialAttentionBlock(nn.Module):\n",
"    \"\"\"CBAM-style spatial attention: rescales every pixel of ``x``.\n",
"\n",
"    The channel-wise mean and max of ``x`` form a 2-channel map that goes\n",
"    through a 7x7 conv (padding 3, no bias) and a sigmoid; the resulting\n",
"    (B, 1, H, W) gate in (0, 1) multiplies ``x``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(SpatialAttentionBlock, self).__init__()\n",
"        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
"\n",
"    def forward(self, x):\n",
"        channel_mean = x.mean(dim=1, keepdim=True)       # (B, 1, H, W)\n",
"        channel_max = x.max(dim=1, keepdim=True).values  # (B, 1, H, W)\n",
"        gate = self.conv(torch.cat((channel_mean, channel_max), dim=1)).sigmoid()  # (B, 1, H, W)\n",
"        return x * gate"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class DecoderAttentionBlock(nn.Module):\n",
"    \"\"\"Channel attention followed by SpatialAttentionBlock; output shape == input shape.\n",
"\n",
"    Channel gate: the globally average- and max-pooled descriptors each pass\n",
"    through a shared 1x1 conv (``in_channels -> hidden``), are summed, mapped\n",
"    back to ``in_channels`` by a second 1x1 conv and squashed by a sigmoid;\n",
"    the gate rescales ``x`` per channel before spatial attention is applied.\n",
"\n",
"    Args:\n",
"        in_channels: channel count of the input.\n",
"        reduction: bottleneck factor, ``hidden = in_channels // reduction``\n",
"            (default 2). ``hidden`` is clamped to at least 1 so that very small\n",
"            channel counts (e.g. 1) do not produce a zero-channel conv.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels, reduction=2):\n",
"        super(DecoderAttentionBlock, self).__init__()\n",
"        hidden = max(1, in_channels // reduction)\n",
"        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1)\n",
"        self.conv2 = nn.Conv2d(hidden, in_channels, kernel_size=1)\n",
"        self.spatial_attention = SpatialAttentionBlock()\n",
"\n",
"    def forward(self, x):\n",
"        avg_out = self.conv1(F.adaptive_avg_pool2d(x, 1))  # (B, hidden, 1, 1)\n",
"        max_out = self.conv1(F.adaptive_max_pool2d(x, 1))  # (B, hidden, 1, 1)\n",
"        # NOTE(review): CBAM puts a ReLU between the two 1x1 convs of the shared\n",
"        # MLP; here they are purely linear. Left unchanged to preserve behaviour.\n",
"        channel_gate = torch.sigmoid(self.conv2(avg_out + max_out))  # (B, C, 1, 1)\n",
"        return self.spatial_attention(x * channel_gate)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class SEBlock(nn.Module):\n",
"    \"\"\"Squeeze-and-Excitation: rescales each channel of ``x`` by a learned gate.\n",
"\n",
"    Global average pool -> 1x1 conv (in_channels -> reduced_dim) -> ReLU ->\n",
"    1x1 conv (reduced_dim -> in_channels) -> sigmoid, then multiply with ``x``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels, reduced_dim):\n",
"        super(SEBlock, self).__init__()\n",
"        squeeze = nn.AdaptiveAvgPool2d(1)  # global average pooling -> (B, C, 1, 1)\n",
"        excite = [\n",
"            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
"            nn.Sigmoid(),  # per-channel weights in (0, 1)\n",
"        ]\n",
"        self.se = nn.Sequential(squeeze, *excite)\n",
"\n",
"    def forward(self, x):\n",
"        channel_weights = self.se(x)  # (B, C, 1, 1)\n",
"        return x * channel_weights"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Masked autoencoder model definition.\n",
"class MaskedAutoencoder(nn.Module):\n",
"    \"\"\"Convolutional/attention encoder with a transposed-conv decoder.\n",
"\n",
"    The encoder has three stride-2 stages (1 -> 32 -> 64 -> 128 channels) and\n",
"    the decoder three stride-2 ConvTranspose2d layers (128 -> 32 -> 16 -> 1),\n",
"    so a (B, 1, H, W) input with H and W divisible by 8 comes back as\n",
"    (B, 1, H, W), e.g. 96 -> 48 -> 24 -> 12 -> 24 -> 48 -> 96. The final\n",
"    sigmoid keeps outputs in (0, 1).\n",
"\n",
"    Args:\n",
"        use_mlp: apply ``self.mlp`` to the encoded features before decoding.\n",
"            ``self.mlp`` is always constructed (its parameters live in the\n",
"            state_dict and are passed to the optimizer) but was never called\n",
"            in ``forward``; the default False keeps that forward pass.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, use_mlp=False):\n",
"        super(MaskedAutoencoder, self).__init__()\n",
"        self.encoder = nn.Sequential(\n",
"            Conv(1, 32, kernel_size=3, stride=2),\n",
"            nn.ReLU(),\n",
"            SEBlock(32, 32),\n",
"            ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
"            ResidualBlock(64, 64),\n",
"            SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
"            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
"            SEBlock(128, 128)\n",
"        )\n",
"        self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
"        self.use_mlp = use_mlp\n",
"        self.decoder = nn.Sequential(\n",
"            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
"            nn.ReLU(),\n",
"            DecoderAttentionBlock(32),\n",
"            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
"            nn.ReLU(),\n",
"            DecoderAttentionBlock(16),\n",
"            nn.ReLU(),\n",
"            # output_padding=1 makes each stride-2 layer exactly double H and W.\n",
"            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
"            nn.Sigmoid()\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        encoded = self.encoder(x)\n",
"        if self.use_mlp:\n",
"            encoded = self.mlp(encoded)\n",
"        return self.decoder(encoded)\n",
|
|||
|
"\n",
|
|||
|
"# Instantiate the model, loss and optimizer. Seed torch and NumPy first so\n",
"# that weight initialisation, DataLoader shuffling and NumPy draws made in the\n",
"# main process are reproducible across runs.\n",
"SEED = 0\n",
"torch.manual_seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"model = MaskedAutoencoder()\n",
"# NOTE: train_epoch/evaluate receive `criterion` but compute masked_mse_loss,\n",
"# so this MSELoss is currently unused.\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"id": "e9c804e0-6f5c-40a7-aba7-a03a496cf427",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def masked_mse_loss(preds, target, mask):\n",
"    \"\"\"Mean squared error over the pixels selected by a binary mask.\n",
"\n",
"    Args:\n",
"        preds: model output, shape (B, C, H, W).\n",
"        target: ground truth, same shape as ``preds``.\n",
"        mask: 0/1 tensor of shape (B, 1, H, W); 1 selects a pixel.\n",
"\n",
"    Returns:\n",
"        Scalar tensor ``sum(err * mask) / sum(mask)``, where ``err`` is the\n",
"        squared error averaged over the channel axis. An all-zero mask yields\n",
"        0 instead of NaN.\n",
"\n",
"    NOTE(review): NO2Dataset feeds ``image * mask`` to the model, so mask == 1\n",
"    marks the pixels the model can see and this loss scores only those. To\n",
"    score the hidden (to-be-filled) pixels, pass ``1 - mask``; confirm which\n",
"    is intended.\n",
"    \"\"\"\n",
"    # Average over channels only (dim=1). Averaging over dim=-1 would average\n",
"    # over image width, and the (B, C, H) result would then broadcast against\n",
"    # the (B, 1, H, W) mask to (B, B, H, W), mixing samples across the batch,\n",
"    # pairing row errors with mask columns, and failing outright when H != W.\n",
"    per_pixel = ((preds - target) ** 2).mean(dim=1, keepdim=True)  # (B, 1, H, W)\n",
"    # clamp(min=1) only matters for an empty mask: a non-empty 0/1 mask sums to >= 1.\n",
"    return (per_pixel * mask).sum() / mask.sum().clamp(min=1)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Training loop for one epoch.\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
"    \"\"\"Run one optimisation pass of ``model`` over ``data_loader``.\n",
"\n",
"    Batches are ``(X, y, mask, miss_rate)`` as produced by NO2Dataset; the\n",
"    loss is ``masked_mse_loss(model(X), y, mask)``. ``criterion`` is accepted\n",
"    for signature compatibility but not used.\n",
"\n",
"    Returns:\n",
"        ``(mean of the per-batch losses, list of per-batch miss_rate entries)``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``data_loader`` yields no batches.\n",
"    \"\"\"\n",
"    model.train()\n",
"    running_loss = 0.0\n",
"    num_batches = 0\n",
"    miss_counts = []\n",
"    for X, y, mask, miss_rate in data_loader:\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        miss_counts.append(miss_rate)\n",
"        optimizer.zero_grad()\n",
"        reconstructed = model(X)\n",
"        loss = masked_mse_loss(reconstructed, y, mask)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item()\n",
"        num_batches += 1\n",
"    if num_batches == 0:\n",
"        # Fail with a clear message instead of an undefined loop variable.\n",
"        raise ValueError(\"data_loader yielded no batches\")\n",
"    return running_loss / num_batches, miss_counts"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Evaluation loop (no parameter updates).\n",
"def evaluate(model, device, data_loader, criterion, visualize_batch=None):\n",
"    \"\"\"Compute the mean masked MSE of ``model`` over ``data_loader``.\n",
"\n",
"    Batches are ``(X, y, mask, miss_rate)`` as produced by NO2Dataset; the\n",
"    loss is ``masked_mse_loss(model(X), y, mask)``. ``criterion`` is accepted\n",
"    for signature compatibility but not used.\n",
"\n",
"    Args:\n",
"        visualize_batch: if an int, plot one random sample of that batch with\n",
"            ``visualize_feature``. ``None`` (default) plots nothing.\n",
"\n",
"    Returns:\n",
"        ``(mean of the per-batch losses, list of per-batch miss_rate entries)``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``data_loader`` yields no batches.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    running_loss = 0.0\n",
"    num_batches = 0\n",
"    miss_counts = []\n",
"    with torch.no_grad():\n",
"        for batch_idx, (X, y, mask, miss_rate) in enumerate(data_loader):\n",
"            X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"            miss_counts.append(miss_rate)\n",
"            reconstructed = model(X)\n",
"            # Optional debug plot; off by default so no RNG draw or figure happens.\n",
"            if visualize_batch is not None and batch_idx == visualize_batch:\n",
"                rand_ind = np.random.randint(0, len(y))\n",
"                visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
"            loss = masked_mse_loss(reconstructed, y, mask)\n",
"            running_loss += loss.item()\n",
"            num_batches += 1\n",
"    if num_batches == 0:\n",
"        # Fail with a clear message instead of an undefined loop variable.\n",
"        raise ValueError(\"data_loader yielded no batches\")\n",
"    return running_loss / num_batches, miss_counts"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"id": "296ba6bd-2239-4948-b278-7edcb29bfd14",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"cuda\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Select the compute device: CUDA when available, otherwise CPU.\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"print(device)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 20,
|
|||
|
"id": "743d1000-561e-4444-8b49-88346c14f28b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n",
|
|||
|
" return F.conv2d(input, weight, bias, self.stride,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Epoch 1, Train Loss: 3.759178739843186, Val Loss: 0.1379857260122228\n",
|
|||
|
"Epoch 2, Train Loss: 0.09902132054764118, Val Loss: 0.066096671370428\n",
|
|||
|
"Epoch 3, Train Loss: 0.060244543255088434, Val Loss: 0.05034376319442222\n",
|
|||
|
"Epoch 4, Train Loss: 0.04942069527956002, Val Loss: 0.04460687851950304\n",
|
|||
|
"Epoch 5, Train Loss: 0.04382758207940029, Val Loss: 0.0369152329417307\n",
|
|||
|
"Epoch 6, Train Loss: 0.03961431584432365, Val Loss: 0.033898868973353015\n",
|
|||
|
"Epoch 7, Train Loss: 0.03653587933861468, Val Loss: 0.03190647060079361\n",
|
|||
|
"Epoch 8, Train Loss: 0.03421460006956421, Val Loss: 0.030460054360663714\n",
|
|||
|
"Epoch 9, Train Loss: 0.03215051376434604, Val Loss: 0.03062500929765737\n",
|
|||
|
"Epoch 10, Train Loss: 0.031739671104119724, Val Loss: 0.029085035394154378\n",
|
|||
|
"Epoch 11, Train Loss: 0.030470874753188004, Val Loss: 0.03185694292187691\n",
|
|||
|
"Epoch 12, Train Loss: 0.029636846623566162, Val Loss: 0.029310374951629498\n",
|
|||
|
"Epoch 13, Train Loss: 0.028289151542851228, Val Loss: 0.02720484949314772\n",
|
|||
|
"Epoch 14, Train Loss: 0.027910822102327666, Val Loss: 0.028894296833383504\n",
|
|||
|
"Epoch 15, Train Loss: 0.027092363841332602, Val Loss: 0.02946079163742599\n",
|
|||
|
"Epoch 16, Train Loss: 0.025776214282692334, Val Loss: 0.024672900368251018\n",
|
|||
|
"Epoch 17, Train Loss: 0.025803192848402063, Val Loss: 0.02488229790730263\n",
|
|||
|
"Epoch 18, Train Loss: 0.025352436108915716, Val Loss: 0.02426056825180552\n",
|
|||
|
"Epoch 19, Train Loss: 0.024724755284675, Val Loss: 0.023613420885000656\n",
|
|||
|
"Epoch 20, Train Loss: 0.02373662724663196, Val Loss: 0.023868454147630662\n",
|
|||
|
"Epoch 21, Train Loss: 0.023606173005668026, Val Loss: 0.022293920976234907\n",
|
|||
|
"Epoch 22, Train Loss: 0.02291965261814697, Val Loss: 0.0231649547036904\n",
|
|||
|
"Epoch 23, Train Loss: 0.022957429811180208, Val Loss: 0.022116250789432385\n",
|
|||
|
"Epoch 24, Train Loss: 0.022525311819763416, Val Loss: 0.02422845282994989\n",
|
|||
|
"Epoch 25, Train Loss: 0.02231395777101009, Val Loss: 0.02212312592388089\n",
|
|||
|
"Epoch 26, Train Loss: 0.02209535693420035, Val Loss: 0.02158943160589951\n",
|
|||
|
"Epoch 27, Train Loss: 0.021671999831857722, Val Loss: 0.022256974825885758\n",
|
|||
|
"Epoch 28, Train Loss: 0.021378441671417517, Val Loss: 0.021293755787522045\n",
|
|||
|
"Epoch 29, Train Loss: 0.021532584222381194, Val Loss: 0.021740848698945187\n",
|
|||
|
"Epoch 30, Train Loss: 0.02089789963625906, Val Loss: 0.022172707369300857\n",
|
|||
|
"Epoch 31, Train Loss: 0.020911543732553578, Val Loss: 0.020904658445671423\n",
|
|||
|
"Epoch 32, Train Loss: 0.020589363574090472, Val Loss: 0.021264061137144245\n",
|
|||
|
"Epoch 33, Train Loss: 0.02011841800037112, Val Loss: 0.022388043521500346\n",
|
|||
|
"Epoch 34, Train Loss: 0.020350060138281025, Val Loss: 0.020872680664952122\n",
|
|||
|
"Epoch 35, Train Loss: 0.019910728570038193, Val Loss: 0.02008631487668895\n",
|
|||
|
"Epoch 36, Train Loss: 0.01966284622291201, Val Loss: 0.02018301992385245\n",
|
|||
|
"Epoch 37, Train Loss: 0.019478668659283785, Val Loss: 0.020117887351383913\n",
|
|||
|
"Epoch 38, Train Loss: 0.019168558606262983, Val Loss: 0.020217864148652377\n",
|
|||
|
"Epoch 39, Train Loss: 0.018900538525102956, Val Loss: 0.019784750694881625\n",
|
|||
|
"Epoch 40, Train Loss: 0.019068713380139695, Val Loss: 0.020406662806201337\n",
|
|||
|
"Epoch 41, Train Loss: 0.01922704772488994, Val Loss: 0.019463480088804195\n",
|
|||
|
"Epoch 42, Train Loss: 0.018683298484257392, Val Loss: 0.019570431866641366\n",
|
|||
|
"Epoch 43, Train Loss: 0.018411033715535863, Val Loss: 0.019696261789371717\n",
|
|||
|
"Epoch 44, Train Loss: 0.018502752826901142, Val Loss: 0.0193116083574384\n",
|
|||
|
"Epoch 45, Train Loss: 0.01851825592772028, Val Loss: 0.021103291230192826\n",
|
|||
|
"Epoch 46, Train Loss: 0.01816361720125641, Val Loss: 0.020114433075954664\n",
|
|||
|
"Epoch 47, Train Loss: 0.018051497934555464, Val Loss: 0.020221358179045256\n",
|
|||
|
"Epoch 48, Train Loss: 0.01811225383885597, Val Loss: 0.01961083782475386\n",
|
|||
|
"Epoch 49, Train Loss: 0.017867776890548224, Val Loss: 0.018948225665893128\n",
|
|||
|
"Epoch 50, Train Loss: 0.01761771424152135, Val Loss: 0.01865902607009482\n",
|
|||
|
"Epoch 51, Train Loss: 0.01793021524467608, Val Loss: 0.018359918592136298\n",
|
|||
|
"Epoch 52, Train Loss: 0.017610817650805393, Val Loss: 0.018650228838756014\n",
|
|||
|
"Epoch 53, Train Loss: 0.017737194443451305, Val Loss: 0.018363466583637158\n",
|
|||
|
"Epoch 54, Train Loss: 0.017543190524302886, Val Loss: 0.019013355055184505\n",
|
|||
|
"Epoch 55, Train Loss: 0.01778105637236859, Val Loss: 0.018212769875553116\n",
|
|||
|
"Epoch 56, Train Loss: 0.017451271454861576, Val Loss: 0.018818481644587732\n",
|
|||
|
"Epoch 57, Train Loss: 0.017273589150989026, Val Loss: 0.01801557773585195\n",
|
|||
|
"Epoch 58, Train Loss: 0.01728663447816549, Val Loss: 0.01771288837737112\n",
|
|||
|
"Epoch 59, Train Loss: 0.017209396768878237, Val Loss: 0.018658861782012592\n",
|
|||
|
"Epoch 60, Train Loss: 0.017015971490694434, Val Loss: 0.01875163140748419\n",
|
|||
|
"Epoch 61, Train Loss: 0.01697286305744112, Val Loss: 0.01831459281827087\n",
|
|||
|
"Epoch 62, Train Loss: 0.01689975440466518, Val Loss: 0.018071504671182206\n",
|
|||
|
"Epoch 63, Train Loss: 0.016585711293974133, Val Loss: 0.01783462390025605\n",
|
|||
|
"Epoch 64, Train Loss: 0.016933080276839756, Val Loss: 0.018715852857636873\n",
|
|||
|
"Epoch 65, Train Loss: 0.016899143777894633, Val Loss: 0.019256604974394412\n",
|
|||
|
"Epoch 66, Train Loss: 0.016631374423031173, Val Loss: 0.018876284666693034\n",
|
|||
|
"Epoch 67, Train Loss: 0.016569798094839855, Val Loss: 0.018378769520169765\n",
|
|||
|
"Epoch 68, Train Loss: 0.016539438030544366, Val Loss: 0.018459608500350767\n",
|
|||
|
"Epoch 69, Train Loss: 0.01645555520323261, Val Loss: 0.01851357322241833\n",
|
|||
|
"Epoch 70, Train Loss: 0.01667448620726332, Val Loss: 0.017527391814362647\n",
|
|||
|
"Epoch 71, Train Loss: 0.01630861950708491, Val Loss: 0.01862382395331984\n",
|
|||
|
"Epoch 72, Train Loss: 0.016292595119621053, Val Loss: 0.01898773131308271\n",
|
|||
|
"Epoch 73, Train Loss: 0.016312802497867904, Val Loss: 0.017515668033886312\n",
|
|||
|
"Epoch 74, Train Loss: 0.01634560936714331, Val Loss: 0.017603496631690814\n",
|
|||
|
"Epoch 75, Train Loss: 0.016150180214757556, Val Loss: 0.0177685193606277\n",
|
|||
|
"Epoch 76, Train Loss: 0.016183897565479912, Val Loss: 0.01790037954142734\n",
|
|||
|
"Epoch 77, Train Loss: 0.016441928089092794, Val Loss: 0.0177356356671497\n",
|
|||
|
"Epoch 78, Train Loss: 0.016029272553773875, Val Loss: 0.01720855048676925\n",
|
|||
|
"Epoch 79, Train Loss: 0.015830894611312443, Val Loss: 0.017439508657735674\n",
|
|||
|
"Epoch 80, Train Loss: 0.015893817865891318, Val Loss: 0.017185933985260884\n",
|
|||
|
"Epoch 81, Train Loss: 0.01587246311160081, Val Loss: 0.017182132229208946\n",
|
|||
|
"Epoch 82, Train Loss: 0.015938340017848322, Val Loss: 0.01732705053942862\n",
|
|||
|
"Epoch 83, Train Loss: 0.015770130625894767, Val Loss: 0.01730423010607709\n",
|
|||
|
"Epoch 84, Train Loss: 0.015774958316931886, Val Loss: 0.01693567380642713\n",
|
|||
|
"Epoch 85, Train Loss: 0.015681640634928166, Val Loss: 0.01731172299929964\n",
|
|||
|
"Epoch 86, Train Loss: 0.015522310860080725, Val Loss: 0.01708351758155805\n",
|
|||
|
"Epoch 87, Train Loss: 0.015825702162664473, Val Loss: 0.01767030195680572\n",
|
|||
|
"Epoch 88, Train Loss: 0.015465608916053789, Val Loss: 0.0169600204689734\n",
|
|||
|
"Epoch 89, Train Loss: 0.015413585239263812, Val Loss: 0.016799337550330518\n",
|
|||
|
"Epoch 90, Train Loss: 0.015661140533975153, Val Loss: 0.017084516890680614\n",
|
|||
|
"Epoch 91, Train Loss: 0.015471032805045684, Val Loss: 0.017242409135979502\n",
|
|||
|
"Epoch 92, Train Loss: 0.015306838647725337, Val Loss: 0.016721693103882804\n",
|
|||
|
"Epoch 93, Train Loss: 0.01516885641721661, Val Loss: 0.01838143560479381\n",
|
|||
|
"Epoch 94, Train Loss: 0.015182504183100314, Val Loss: 0.017020777451680666\n",
|
|||
|
"Epoch 95, Train Loss: 0.01524644939264541, Val Loss: 0.01649292297105291\n",
|
|||
|
"Epoch 96, Train Loss: 0.015118425159434382, Val Loss: 0.017190173087613798\n",
|
|||
|
"Epoch 97, Train Loss: 0.015101557916128322, Val Loss: 0.016093250461367527\n",
|
|||
|
"Epoch 98, Train Loss: 0.01503138992775, Val Loss: 0.016338717831826922\n",
|
|||
|
"Epoch 99, Train Loss: 0.015078757967550361, Val Loss: 0.016478037350435754\n",
|
|||
|
"Epoch 100, Train Loss: 0.014985626251503611, Val Loss: 0.01633207424919107\n",
|
|||
|
"Epoch 101, Train Loss: 0.014759322786570023, Val Loss: 0.01683194490511026\n",
|
|||
|
"Epoch 102, Train Loss: 0.014856852341496774, Val Loss: 0.016027600129148854\n",
|
|||
|
"Epoch 103, Train Loss: 0.014765939864655289, Val Loss: 0.016350745793376396\n",
|
|||
|
"Epoch 104, Train Loss: 0.01478316887330852, Val Loss: 0.016033862258738547\n",
|
|||
|
"Epoch 105, Train Loss: 0.014725807853684755, Val Loss: 0.015603851276769568\n",
|
|||
|
"Epoch 106, Train Loss: 0.014806732724746021, Val Loss: 0.015736672651967896\n",
|
|||
|
"Epoch 107, Train Loss: 0.014543344516253642, Val Loss: 0.015925641963953404\n",
|
|||
|
"Epoch 108, Train Loss: 0.014782626121683696, Val Loss: 0.016552887453850525\n",
|
|||
|
"Epoch 109, Train Loss: 0.014329457426060472, Val Loss: 0.01566976616020078\n",
|
|||
|
"Epoch 110, Train Loss: 0.014614671502155408, Val Loss: 0.016271342245389276\n",
|
|||
|
"Epoch 111, Train Loss: 0.014544662480291567, Val Loss: 0.01549402935736215\n",
|
|||
|
"Epoch 112, Train Loss: 0.01446673739478705, Val Loss: 0.015960639662373422\n",
|
|||
|
"Epoch 113, Train Loss: 0.014492520645849015, Val Loss: 0.015249295007270663\n",
|
|||
|
"Epoch 114, Train Loss: 0.014440985597028402, Val Loss: 0.01671606713711326\n",
|
|||
|
"Epoch 115, Train Loss: 0.014369557464593336, Val Loss: 0.016106587264742424\n",
|
|||
|
"Epoch 116, Train Loss: 0.01432103816972395, Val Loss: 0.015263923374352171\n",
|
|||
|
"Epoch 117, Train Loss: 0.014226941607945987, Val Loss: 0.015028324297893404\n",
|
|||
|
"Epoch 118, Train Loss: 0.01423997960485625, Val Loss: 0.014743029529145404\n",
|
|||
|
"Epoch 119, Train Loss: 0.014351020645100677, Val Loss: 0.01581134552608675\n",
|
|||
|
"Epoch 120, Train Loss: 0.014202667741131696, Val Loss: 0.015378265266320598\n",
|
|||
|
"Epoch 121, Train Loss: 0.013911791727321142, Val Loss: 0.01487369868737548\n",
|
|||
|
"Epoch 122, Train Loss: 0.013906272411186017, Val Loss: 0.01551159023682573\n",
|
|||
|
"Epoch 123, Train Loss: 0.013943794016329723, Val Loss: 0.015357211718697156\n",
|
|||
|
"Epoch 124, Train Loss: 0.01389588224233694, Val Loss: 0.015303193772239472\n",
|
|||
|
"Epoch 125, Train Loss: 0.014016644986854359, Val Loss: 0.014799274629287755\n",
|
|||
|
"Epoch 126, Train Loss: 0.013944415422379258, Val Loss: 0.014797273328277603\n",
|
|||
|
"Epoch 127, Train Loss: 0.013957360926480812, Val Loss: 0.014890457517397938\n",
|
|||
|
"Epoch 128, Train Loss: 0.013801010133939211, Val Loss: 0.015028401750570802\n",
|
|||
|
"Epoch 129, Train Loss: 0.013806760874821952, Val Loss: 0.016021162049094245\n",
|
|||
|
"Epoch 130, Train Loss: 0.014049455859925616, Val Loss: 0.015217644565585834\n",
|
|||
|
"Epoch 131, Train Loss: 0.013769885206497029, Val Loss: 0.015085379940582745\n",
|
|||
|
"Epoch 132, Train Loss: 0.013684874973103903, Val Loss: 0.014550712029102133\n",
|
|||
|
"Epoch 133, Train Loss: 0.013696547392666625, Val Loss: 0.014757407259251645\n",
|
|||
|
"Epoch 134, Train Loss: 0.01369966242827796, Val Loss: 0.014638274657859732\n",
|
|||
|
"Epoch 135, Train Loss: 0.013533816318602511, Val Loss: 0.014734907506673193\n",
|
|||
|
"Epoch 136, Train Loss: 0.013603145677738926, Val Loss: 0.014580759831440093\n",
|
|||
|
"Epoch 137, Train Loss: 0.013541612814238482, Val Loss: 0.01570955854354065\n",
|
|||
|
"Epoch 138, Train Loss: 0.013723757467789656, Val Loss: 0.016205344780056335\n",
|
|||
|
"Epoch 139, Train Loss: 0.013546007516031916, Val Loss: 0.0152104031572591\n",
|
|||
|
"Epoch 140, Train Loss: 0.013532601969771123, Val Loss: 0.015342667142846692\n",
|
|||
|
"Epoch 141, Train Loss: 0.013450533512569786, Val Loss: 0.014644546336980898\n",
|
|||
|
"Epoch 142, Train Loss: 0.013607010434706959, Val Loss: 0.014687455078559135\n",
|
|||
|
"Epoch 143, Train Loss: 0.013542775672962934, Val Loss: 0.014521264234807953\n",
|
|||
|
"Epoch 144, Train Loss: 0.013417973114026078, Val Loss: 0.014601941859877822\n",
|
|||
|
"Epoch 145, Train Loss: 0.013331704906691489, Val Loss: 0.01485029947179467\n",
|
|||
|
"Epoch 146, Train Loss: 0.013418046530318316, Val Loss: 0.014630124362102195\n",
|
|||
|
"Epoch 147, Train Loss: 0.013351045589020663, Val Loss: 0.01494142015589707\n",
|
|||
|
"Epoch 148, Train Loss: 0.013260266191045348, Val Loss: 0.015414885175761893\n",
|
|||
|
"Epoch 149, Train Loss: 0.013240087648149598, Val Loss: 0.014419331771335494\n",
|
|||
|
"Epoch 150, Train Loss: 0.01334052808297593, Val Loss: 0.01435606328965123\n",
|
|||
|
"Test Loss: 0.008245683658557634\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"model = model.to(device)\n",
"\n",
"# Fixed-length training run; per-epoch train/validation losses are kept for plotting.\n",
"num_epochs = 150\n",
"train_losses, val_losses = [], []\n",
"for epoch in range(num_epochs):\n",
"    # train_counts / val_counts hold the batches of per-sample missing-rate labels\n",
"    # returned for this epoch; only the last epoch's values survive the loop.\n",
"    train_loss, train_counts = train_epoch(model, device, dataloader, criterion, optimizer)\n",
"    train_losses.append(train_loss)\n",
"    val_loss, val_counts = evaluate(model, device, val_loader, criterion)\n",
"    val_losses.append(val_loss)\n",
"    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# Test the model; evaluate() returns a (loss, counts) pair, so [0] is the loss.\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss[0]}')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 21,
|
|||
|
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<matplotlib.legend.Legend at 0x7fb5a9a95fa0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 21,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABbIElEQVR4nO3dd3zV1eH/8dfd2RsSwgooMiMgS0BFBQ2CA6yKlAqo1WrFRUsV66p+lboobmr7s9bWgQuLC0VERLYsRZYyQhhJgJA97s29n98fn+SGSCA3IckN4f18PO6D5HPP/dxzwrhvzrQYhmEgIiIi0oxZg10BERERkdoosIiIiEizp8AiIiIizZ4Ci4iIiDR7CiwiIiLS7CmwiIiISLOnwCIiIiLNngKLiIiINHv2YFegIfh8Pvbt20dkZCQWiyXY1REREZEAGIZBQUEBycnJWK3H70NpEYFl3759tG/fPtjVEBERkXrIyMigXbt2xy3TIgJLZGQkYDY4KioqyLURERGRQOTn59O+fXv/5/jxtIjAUjkMFBUVpcAiIiJykglkOocm3YqIiEizp8AiIiIizZ4Ci4iIiDR7LWIOi4iItDyGYVBeXo7X6w12VeQE2Gw27Hb7CW87osAiIiLNjtvtZv/+/RQXFwe7KtIAwsLCaNOmDU6ns973qFdgefHFF3nqqafIzMykd+/ePP/88wwcOLDGsj/++CMPPvgga9asIT09nb/97W/cddddJ3RPERFpuXw+Hzt37sRms5GcnIzT6dSmoCcpwzBwu90cOHCAnTt30qVLl1o3iDuWOgeWOXPmMHXqVGbPns2gQYOYNWsWaWlpbN26ldatWx9Vvri4mM6dO3P11Vdz9913N8g9RUSk5XK73fh8Ptq3b09YWFiwqyMnKDQ0FIfDQXp6Om63m5CQkHrdp84xZ+bMmdx0001cf/319OjRg9mzZxMWFsarr75aY/kBAwbw1FNPce211+JyuRrkniIi0vLV93/i0vw0xO9lne7gdrtZs2YNI0aMqFaJESNGsHz58npVoDHuKSIiIi1LnYaEDh48iNfrJTExsdr1xMREtmzZUq8K1OeeZWVllJWV+b/Pz8+v13uLiIjIyeGk7G+bMWMG0dHR/ocOPhQRkZYmJSWFWbNmNci9vv76aywWC7m5uQ1yv2CoU2BJSEjAZrORlZVV7XpWVhZJSUn1qkB97jl9+nTy8vL8j4yMjHq9t4iISEM6//zza1wJWx+rV6/m5ptvbpB7tQR1CixOp5N+/fqxcOFC/zWfz8fChQsZPHhwvSpQn3u6XC7/QYeNeeChu9zHIx9t4sH/baSsXBsXiYjIiancDC8QrVq10iqpI9R5SGjq1Kn84x//4N///jebN2/m1ltvpaioiOuvvx6AiRMnMn36dH95t9vN+vXrWb9+PW63m71797J+/Xp+/vnngO8ZLAYGry7dyevL0yn1+IJaFxGRU5lhGBS7y4PyMAwjoDpOnjyZxYsX8+yzz2KxWLBYLLz22mtYLBY+++wz+vXrh8vl4ttvv2X79u1cccUVJCYmEhERwYABA/jyyy+r3e+XQ0IWi4V//vOfjB07lrCwMLp06cK8efPq/TN9//336dmzJy6Xi5SUFJ555plqz7/00kt06dKFkJAQEhMTueqqq/zPvffee6SmphIaGkp8fDwjRoygqKio3nUJRJ33YRk3bhwHDhzgwQcfJDMzkz59+jB//nz/pNndu3dXW760b98++vbt6//+6aef5umnn2bYsGF8/fXXAd0zWBxHtKPcq8AiIhIsJR4vPR78PCjvvemRNMKctX9cPvvss2zbto1evXrxyCOPAObmqQD33nsvTz/9NJ07dyY2NpaMjAxGjRrFY489hsvl4vXXX+eyyy5j69atdOjQ4Zjv8Ze//IUnn3ySp556iueff54JEyaQnp5OXFxcndq0Zs0arrnmGh5++GHGjRvHsmXL+P3vf098fDyTJ0/mu+
++44477uA///kPQ4YMIScnhyVLlgCwf/9+xo8fz5NPPsnYsWMpKChgyZIlAQe7+qrXTrdTpkxhypQpNT5XGUIqpaSkBNSI490zWKxWCzarBa/PoNzXuL8RIiJycouOjsbpdBIWFuafg1m52vWRRx7hoosu8peNi4ujd+/e/u8fffRR5s6dy7x58477WTh58mTGjx8PwOOPP85zzz3HqlWrGDlyZJ3qOnPmTIYPH84DDzwAwBlnnMGmTZt46qmnmDx5Mrt37yY8PJxLL72UyMhIOnbs6O982L9/P+Xl5Vx55ZV07NgRgNTU1Dq9f33oLKFa2CsCi0c9LCIiQRPqsLHpkbSgvfeJ6t+/f7XvCwsLefjhh/nkk0/8AaCkpITdu3cf9z5nnnmm/+vw8HCioqLIzs6uc302b97MFVdcUe3a0KFDmTVrFl6vl4suuoiOHTvSuXNnRo4cyciRI/1DUb1792b48OGkpqaSlpbGxRdfzFVXXUVsbGyd61EXJ+Wy5qbksJk/onKvelhERILFYrEQ5rQH5dEQ5xiFh4dX+/6Pf/wjc+fO5fHHH2fJkiWsX7+e1NRU3G73ce/jcDiO+rn4fA3/H+rIyEjWrl3LW2+9RZs2bXjwwQfp3bs3ubm52Gw2FixYwGeffUaPHj14/vnn6dq1Kzt37mzwehxJgaUWdpv5B7W8Ef5AiIhIy+J0OvF6a19VunTpUiZPnszYsWNJTU0lKSmJXbt2NX4FK3Tv3p2lS5ceVaczzjgDm83sUbLb7YwYMYInn3yS77//nl27dvHVV18BZlAaOnQof/nLX1i3bh1Op5O5c+c2ap01JFQLe8XEW496WEREpBYpKSmsXLmSXbt2ERERcczejy5duvDBBx9w2WWXYbFYeOCBBxqlp+RY/vCHPzBgwAAeffRRxo0bx/Lly3nhhRd46aWXAPj444/ZsWMH5513HrGxsXz66af4fD66du3KypUrWbhwIRdffDGtW7dm5cqVHDhwgO7duzdqndXDUgtHZQ+LAouIiNTij3/8IzabjR49etCqVatjzkmZOXMmsbGxDBkyhMsuu4y0tDTOOuusJqvnWWedxTvvvMPbb79Nr169ePDBB3nkkUeYPHkyADExMXzwwQdceOGFdO/endmzZ/PWW2/Rs2dPoqKi+Oabbxg1ahRnnHEG999/P8888wyXXHJJo9bZYjT2OqQmkJ+fT3R0NHl5eQ2+idy5T35FRk4J7986hH4dG3dCkYiIQGlpKTt37qRTp06EhIQEuzrSAI71e1qXz2/1sNSiatKt5rCIiIgEiwJLLSo3j9M+LCIi0lzdcsstRERE1Pi45ZZbgl29BqFJt7WoXCWkfVhERKS5euSRR/jjH/9Y43ONdd5eU1NgqYVd+7CIiEgz17p1a1q3bh3sajQqDQnVwmHVPiwiIiLBpsBSi6ohIfWwiIiIBIsCSy0qVwlpDouIiEjwKLDUQmcJiYiIBJ8CSy3sFXNYPJrDIiIiEjQKLLVQD4uIiDSVlJQUZs2aFVBZi8XChx9+2Kj1aU4UWGqhfVhERESCT4GlFnbtdCsiIhJ0Ciy1qDqtWT0sIiJBYxjgLgrOI8Azgl955RWSk5Px/WLO4xVXXMENN9zA9u3bueKKK0hMTCQiIoIBAwbw5ZdfNtiP6IcffuDCCy8kNDSU+Ph4br75ZgoLC/3Pf/311wwcOJDw8HBiYmIYOnQo6enpAGzYsIELLriAyMhIoqKi6NevH999912D1a0haKfbWmgfFhGRZsBTDI8nB+e979sHzvBai1199dXcfvvtLFq0iOHDhwOQk5PD/Pnz+fTTTyksLGTUqFE89thjuFwuXn/9dS677DK2bt1Khw4dTqiKRUVFpKWlMXjwYFavXk12dja//e1vmTJlCq+99hrl5eWMGTOGm266ibfeegu3282qVauwWMzPuAkTJtC3b19efvllbDYb69evx+FwnFCdGpoCSy0qh4Q0h0VERI4nNjaWSy65hD
fffNMfWN577z0SEhK44IILsFqt9O7d21/+0UcfZe7cucybN48pU6ac0Hu/+eablJaW8vrrrxMeboarF154gcsuu4w
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Loss curves from epoch 2 onward. Epoch 1 is left out, presumably because its\n",
"# much larger loss would compress the rest of the plot. The x-axis uses the\n",
"# actual 1-based epoch numbers instead of list positions.\n",
"first_epoch = 2\n",
"fig, ax = plt.subplots()\n",
"ax.plot(range(first_epoch, len(train_losses) + 1), train_losses[first_epoch - 1:], label='train_loss')\n",
"ax.plot(range(first_epoch, len(val_losses) + 1), val_losses[first_epoch - 1:], label='val_loss')\n",
"ax.set_xlabel('Epoch')\n",
"ax.set_ylabel('Loss')\n",
"ax.legend(loc='best');"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 22,
|
|||
|
"id": "a8467686-0655-4056-8e01-56299eb89d7c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"id": "dae7427e-548e-4276-a4ea-bc9b279d44e8",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def cal_ioa(y_true, y_pred):\n",
"    \"\"\"Willmott's index of agreement (IoA) between observations and predictions.\n",
"\n",
"    d = 1 - sum((O - P)^2) / sum((|P - mean(O)| + |O - mean(O)|)^2)\n",
"\n",
"    Both deviation terms are measured about the OBSERVED mean. By the triangle\n",
"    inequality this keeps d in [0, 1]. Measuring the prediction term about the\n",
"    predicted mean instead is not Willmott's d and can produce negative values.\n",
"\n",
"    Args:\n",
"        y_true: observed values (array-like).\n",
"        y_pred: predicted values, same shape as y_true.\n",
"\n",
"    Returns:\n",
"        The IoA as a float. Returns nan for empty input, and 1.0 when the\n",
"        denominator is zero (every observation and prediction equals mean(O),\n",
"        which is perfect agreement).\n",
"    \"\"\"\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"    if y_true.size == 0:\n",
"        return float('nan')\n",
"    mean_observed = np.mean(y_true)\n",
"    numerator = np.sum((y_true - y_pred) ** 2)\n",
"    denominator = np.sum((np.abs(y_pred - mean_observed) + np.abs(y_true - mean_observed)) ** 2)\n",
"    if denominator == 0:\n",
"        return 1.0\n",
"    return 1 - numerator / denominator"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"id": "2744f422-bdd2-4101-9c45-197ad32e8c22",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Score reconstructions on the test set, using ONLY the pixels that were\n",
"# masked out (the region the model had to repair).\n",
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"# Put dropout / BatchNorm layers (if any) in inference mode. evaluate() may\n",
"# already have done this, but this cell should not rely on that.\n",
"model.eval()\n",
"# Start at +inf so the best sample is always recorded, even if every MAPE >= 1.\n",
"best_mape = float('inf')\n",
"best_img = None\n",
"best_mask = None\n",
"best_recov = None\n",
"test_miss_counts = list()\n",
"with torch.no_grad():\n",
"    for batch_idx, (X, y, mask, r) in enumerate(test_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        # Invert the mask to get the region to be repaired (1 = pixel to score).\n",
"        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1\n",
"        # r: this batch's per-sample missing-rate labels.\n",
"        test_miss_counts.append(r)\n",
"        reconstructed = model(X)\n",
"        # Scale back by max_pixel_value (presumably the normalisation divisor).\n",
"        rev_data = y * max_pixel_value\n",
"        rev_recon = reconstructed * max_pixel_value\n",
"        for i, sample in enumerate(rev_data):\n",
"            used_mask = mask_rev[i]\n",
"            selected = used_mask == 1\n",
"            data_label = sample[0][selected].numpy()\n",
"            recon_no2 = rev_recon[i][0][selected].numpy()\n",
"            # r2 and Pearson r need at least two repaired pixels; sklearn raises on empty input.\n",
"            if data_label.size < 2:\n",
"                continue\n",
"            mae = mean_absolute_error(data_label, recon_no2)\n",
"            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
"            mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
"            r2 = r2_score(data_label, recon_no2)\n",
"            ioa = cal_ioa(data_label, recon_no2)\n",
"            # Named corr (not r) so it does not clobber the batch's missing-rate labels.\n",
"            corr = np.corrcoef(data_label, recon_no2)[0, 1]\n",
"            eva_list_frame.append([mae, rmse, mape, r2, ioa, corr])\n",
"            if mape < best_mape:\n",
"                best_recov = rev_recon[i][0].numpy()\n",
"                best_mask = used_mask.numpy()\n",
"                best_img = sample[0].numpy()\n",
"                best_mape = mape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 25,
|
|||
|
"id": "e959a28a-840f-4b34-befc-c233f20635cc",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import pandas as pd"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 30,
|
|||
|
"id": "6ef3ffdf-72ea-4c88-8118-1103a81205f3",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" <th>r</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>1.264791</td>\n",
|
|||
|
" <td>1.798069</td>\n",
|
|||
|
" <td>0.161384</td>\n",
|
|||
|
" <td>0.680643</td>\n",
|
|||
|
" <td>0.889222</td>\n",
|
|||
|
" <td>0.836726</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>0.601222</td>\n",
|
|||
|
" <td>0.894735</td>\n",
|
|||
|
" <td>0.092427</td>\n",
|
|||
|
" <td>0.227477</td>\n",
|
|||
|
" <td>0.104377</td>\n",
|
|||
|
" <td>0.122876</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>0.377890</td>\n",
|
|||
|
" <td>0.487859</td>\n",
|
|||
|
" <td>0.045982</td>\n",
|
|||
|
" <td>-2.265916</td>\n",
|
|||
|
" <td>-0.146766</td>\n",
|
|||
|
" <td>0.002855</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>0.831340</td>\n",
|
|||
|
" <td>1.149141</td>\n",
|
|||
|
" <td>0.110199</td>\n",
|
|||
|
" <td>0.579173</td>\n",
|
|||
|
" <td>0.859047</td>\n",
|
|||
|
" <td>0.785617</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>1.126114</td>\n",
|
|||
|
" <td>1.609603</td>\n",
|
|||
|
" <td>0.142398</td>\n",
|
|||
|
" <td>0.736236</td>\n",
|
|||
|
" <td>0.922370</td>\n",
|
|||
|
" <td>0.869874</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>1.541714</td>\n",
|
|||
|
" <td>2.221009</td>\n",
|
|||
|
" <td>0.185216</td>\n",
|
|||
|
" <td>0.840757</td>\n",
|
|||
|
" <td>0.955571</td>\n",
|
|||
|
" <td>0.922865</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>4.765854</td>\n",
|
|||
|
" <td>8.694316</td>\n",
|
|||
|
" <td>1.285374</td>\n",
|
|||
|
" <td>0.988738</td>\n",
|
|||
|
" <td>0.997125</td>\n",
|
|||
|
" <td>0.994878</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa \\\n",
|
|||
|
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
|
|||
|
"mean 1.264791 1.798069 0.161384 0.680643 0.889222 \n",
|
|||
|
"std 0.601222 0.894735 0.092427 0.227477 0.104377 \n",
|
|||
|
"min 0.377890 0.487859 0.045982 -2.265916 -0.146766 \n",
|
|||
|
"25% 0.831340 1.149141 0.110199 0.579173 0.859047 \n",
|
|||
|
"50% 1.126114 1.609603 0.142398 0.736236 0.922370 \n",
|
|||
|
"75% 1.541714 2.221009 0.185216 0.840757 0.955571 \n",
|
|||
|
"max 4.765854 8.694316 1.285374 0.988738 0.997125 \n",
|
|||
|
"\n",
|
|||
|
" r \n",
|
|||
|
"count 4739.000000 \n",
|
|||
|
"mean 0.836726 \n",
|
|||
|
"std 0.122876 \n",
|
|||
|
"min 0.002855 \n",
|
|||
|
"25% 0.785617 \n",
|
|||
|
"50% 0.869874 \n",
|
|||
|
"75% 0.922865 \n",
|
|||
|
"max 0.994878 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 30,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Summary statistics of the per-sample test metrics.\n",
"metric_columns = ['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']\n",
"pd.DataFrame(eva_list_frame, columns=metric_columns).describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"id": "403385cd-0a5a-46ee-84a5-5c37848b87bf",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def _flatten_to_ints(batches):\n",
"    # Flatten a sequence of per-batch label collections into one flat list of ints,\n",
"    # preserving batch order and within-batch order.\n",
"    flat = []\n",
"    for batch in batches:\n",
"        flat.extend(int(value) for value in batch)\n",
"    return flat\n",
"\n",
"train_counts_int = _flatten_to_ints(train_counts)\n",
"val_counts_int = _flatten_to_ints(val_counts)\n",
"test_counts_int = _flatten_to_ints(test_miss_counts)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 32,
|
|||
|
"id": "e5a52567-71d1-4438-b89c-12ee499e3fb7",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"26749"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 32,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Number of flattened training missing-rate labels (from the last training epoch)\n",
"len(train_counts_int)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 33,
|
|||
|
"id": "dbc0d21e-f303-4838-b9a5-8a3976c311ab",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from collections import Counter"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 34,
|
|||
|
"id": "c674e143-5f70-4628-9adf-97f080617730",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Tally how many samples fall under each missing rate, per split.\n",
"counts_train, counts_valid, counts_test = (\n",
"    Counter(labels) for labels in (train_counts_int, val_counts_int, test_counts_int)\n",
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 35,
|
|||
|
"id": "03bff9cc-8c7a-4cb9-bdf0-c163ed4763a7",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def _counter_to_frame(counter):\n",
"    # One-column frame (column 0) of counts, indexed by label in ascending order.\n",
"    return pd.DataFrame.from_dict(dict(counter), orient='index').sort_index()\n",
"\n",
"counts_df_train = _counter_to_frame(counts_train)\n",
"counts_df_test = _counter_to_frame(counts_test)\n",
"counts_df_valid = _counter_to_frame(counts_valid)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 36,
|
|||
|
"id": "a8f0bcf8-e33a-4603-a3a1-0676594ec54f",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Place the per-split counts side by side, in train / validation / test order.\n",
"rst = pd.concat((counts_df_train, counts_df_valid, counts_df_test), axis='columns')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 37,
|
|||
|
"id": "b91c40dc-9d20-400a-866e-472c9e4d81c3",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Name the split columns (concat left them all labelled 0).\n",
"rst = rst.set_axis(['train', 'validation', 'test'], axis='columns')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 56,
|
|||
|
"id": "528dd935-881e-4e37-95dd-89c1ae23566e",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Keep the index: it is the missing rate (%) that labels each row.\n",
"# With index=False the CSV would contain four unlabeled rows.\n",
"rst.to_csv('./mix_eva.csv', index_label='missing_rate', encoding='utf-8-sig')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 57,
|
|||
|
"id": "f0c39db0-92f7-4fe3-a826-8185186c78c2",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>train</th>\n",
|
|||
|
" <th>validation</th>\n",
|
|||
|
" <th>test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>10</th>\n",
|
|||
|
" <td>9624</td>\n",
|
|||
|
" <td>1500</td>\n",
|
|||
|
" <td>1743</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>20</th>\n",
|
|||
|
" <td>6534</td>\n",
|
|||
|
" <td>1117</td>\n",
|
|||
|
" <td>1150</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>30</th>\n",
|
|||
|
" <td>5380</td>\n",
|
|||
|
" <td>840</td>\n",
|
|||
|
" <td>956</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>40</th>\n",
|
|||
|
" <td>5211</td>\n",
|
|||
|
" <td>818</td>\n",
|
|||
|
" <td>890</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" train validation test\n",
|
|||
|
"10 9624 1500 1743\n",
|
|||
|
"20 6534 1117 1150\n",
|
|||
|
"30 5380 840 956\n",
|
|||
|
"40 5211 818 890"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 57,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Sample counts per missing rate (rows) and data split (columns)\n",
"rst"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 53,
|
|||
|
"id": "9b900f09-65b3-45d8-99fd-486a80b51a3d",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "6095f434-bc4d-4c90-9abd-e6e12c555f16",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Bar chart of sample counts per missing rate for each split.\n",
"# Draw into an explicitly sized Axes. Calling plt.figure() and then rst.plot.bar()\n",
"# leaves the 16x9 figure empty, because pandas opens a new figure when no ax is given.\n",
"fig, ax = plt.subplots(figsize=(16, 9))\n",
"rst.plot.bar(ax=ax)\n",
"ax.set_xlabel('Missing Rate(%)', fontsize=16, fontproperties='Times New Roman')\n",
"ax.set_ylabel('Sample Counts', fontsize=16, fontproperties='Times New Roman')\n",
"plt.setp(ax.get_xticklabels(), rotation=45, fontproperties='Times New Roman')\n",
"ax.legend(loc='best', fontsize=16)\n",
"# Lay out only after the legend exists, so it is taken into account.\n",
"fig.tight_layout()\n",
"fig.savefig('./miss_counts.png')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "72bb4d0c-3fce-4b20-b5fa-7fca52cbb511",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3 (ipykernel)",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.16"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 5
|
|||
|
}
|