MAE_ATMO/torch_MAE_1d_decoder.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"from PIL import Image\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c724bfe5-69a4-441c-9571-02e736037bea",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fb6e75377f0>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(0)\n",
"torch.random.manual_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5e6cd4e9-6594-4eeb-82b8-94a5fc308b4b",
"metadata": {},
"outputs": [],
"source": [
"max_pixel_value = 107.49169921875"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7183fc4f-d0b2-4bdc-9ed3-52933d899686",
"metadata": {},
"outputs": [],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = np.random.choice(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# 实例化数据集和数据加载器\n",
"image_dir = './out_mat/96/train/'\n",
"mask_dir = './out_mat/96/mask/20/'"
]
},
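{
"cell_type": "markdown",
"id": "mask-demo-md",
"metadata": {},
"source": [
"A minimal sanity check of the masking logic above, added as a sketch: it binarizes a synthetic mask the same way as `__getitem__` (`np.where(mask != 0, 1.0, 0.0)`) and verifies that masked-out pixels are zeroed. The `demo_image`/`demo_mask` names are illustrative only; no files on disk are needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mask-demo",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check of the mask binarization and application, on synthetic data\n",
"# (shapes mirror the (96, 96, 1) samples used above).\n",
"demo_image = np.random.rand(96, 96, 1).astype(np.float32)\n",
"demo_mask = (np.random.rand(96, 96) > 0.2).astype(np.float32)  # ~20% of pixels masked out\n",
"demo_mask = np.where(demo_mask != 0, 1.0, 0.0)[:, :, np.newaxis]\n",
"\n",
"masked = demo_image * demo_mask\n",
"assert masked.shape == (96, 96, 1)\n",
"assert np.all(masked[demo_mask == 0] == 0)  # masked-out pixels are zeroed\n",
"print('masked fraction:', 1 - demo_mask.mean())"
]
},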
{
"cell_type": "code",
"execution_count": 5,
"id": "8f76e514-7a5e-46f2-808a-07a33f212443",
"metadata": {},
"outputs": [],
"source": [
"train_set = NO2Dataset(image_dir, mask_dir)\n",
"train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)\n",
"val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n",
"val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
"test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Input\")\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Masked\")\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Recovery\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "645114e8-65a4-4867-b3fe-23395288e855",
"metadata": {},
"outputs": [],
"source": [
"class Conv(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
" super(Conv, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
" )"
]
},
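{
"cell_type": "markdown",
"id": "conv-pad-check-md",
"metadata": {},
"source": [
"The padding expression `((stride - 1) + dilation * (kernel_size - 1)) // 2` reduces to the usual 'same' padding for odd kernels when `stride == 1`. A quick shape check, added as a sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "conv-pad-check",
"metadata": {},
"outputs": [],
"source": [
"# With stride=1, padding = dilation * (kernel_size - 1) // 2 keeps H and W unchanged.\n",
"x = torch.randn(1, 8, 96, 96)\n",
"for k, d in [(3, 1), (3, 2), (5, 1)]:\n",
"    out = Conv(8, 16, kernel_size=k, dilation=d)(x)\n",
"    print(k, d, tuple(out.shape))  # always (1, 16, 96, 96)"
]
},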
{
"cell_type": "code",
"execution_count": 8,
"id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
"metadata": {},
"outputs": [],
"source": [
"class ConvBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
" bias=False):\n",
" super(ConvBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
" norm_layer(out_channels),\n",
" nn.ReLU()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "31ecf247-e98b-4977-a145-782914a042bd",
"metadata": {},
"outputs": [],
"source": [
"class SeparableBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
" super(SeparableBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
" padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
" # 分离卷积,仅调整空间信息\n",
" norm_layer(in_channels), # 对输入通道进行归一化\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), # 这里进行升维操作\n",
" nn.ReLU6()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
"metadata": {},
"outputs": [],
"source": [
"class ResidualBlock(nn.Module):\n",
" def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
" super(ResidualBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(out_channels)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(out_channels)\n",
"\n",
" # 如果输入和输出通道不一致,进行降采样操作\n",
" self.downsample = downsample\n",
" if in_channels != out_channels or stride != 1:\n",
" self.downsample = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
" nn.BatchNorm2d(out_channels)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" identity = x\n",
" if self.downsample is not None:\n",
" identity = self.downsample(x)\n",
"\n",
" out = self.conv1(x)\n",
" out = self.bn1(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv2(out)\n",
" out = self.bn2(out)\n",
"\n",
" out += identity\n",
" out = self.relu(out)\n",
" return out\n"
]
},
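{
"cell_type": "markdown",
"id": "resblock-check-md",
"metadata": {},
"source": [
"When `in_channels != out_channels` or `stride != 1`, the 1x1 `downsample` branch keeps the skip connection shape-compatible with the main path. A quick check, added as a sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "resblock-check",
"metadata": {},
"outputs": [],
"source": [
"# Channel change (32 -> 64) and stride 2 both trigger the downsample branch.\n",
"block = ResidualBlock(32, 64, stride=2)\n",
"x = torch.randn(1, 32, 24, 24)\n",
"print(tuple(block(x).shape))  # (1, 64, 12, 12)"
]
},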
{
"cell_type": "code",
"execution_count": 11,
"id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
"metadata": {},
"outputs": [],
"source": [
"class Mlp(nn.Module):\n",
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
" super().__init__()\n",
" out_features = out_features or in_features\n",
" hidden_features = hidden_features or in_features\n",
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
"\n",
" self.act = act_layer()\n",
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
" self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = self.act(x)\n",
" x = self.drop(x)\n",
" x = self.fc2(x)\n",
" x = self.drop(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
"metadata": {},
"outputs": [],
"source": [
"class MultiHeadAttentionBlock(nn.Module):\n",
" def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
" super(MultiHeadAttentionBlock, self).__init__()\n",
" self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
" self.norm = nn.LayerNorm(embed_dim)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x):\n",
" # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
" B, C, H, W = x.shape\n",
" x = x.view(B, C, H * W).permute(2, 0, 1) # (B, C, H, W) -> (HW, B, C)\n",
"\n",
" # Apply multihead attention\n",
" attn_output, _ = self.attention(x, x, x)\n",
"\n",
" # Apply normalization and dropout\n",
" attn_output = self.norm(attn_output)\n",
" attn_output = self.dropout(attn_output)\n",
"\n",
" # Reshape back to (B, C, H, W)\n",
" attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
"\n",
" return attn_output"
]
},
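{
"cell_type": "markdown",
"id": "mha-shape-check-md",
"metadata": {},
"source": [
"The block flattens the spatial grid into a sequence of HW tokens (the layout `nn.MultiheadAttention` expects without `batch_first`) and reshapes the result back afterwards. A small shape round-trip check, added as a sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mha-shape-check",
"metadata": {},
"outputs": [],
"source": [
"# (B, C, H, W) -> (HW, B, C) -> attention -> (B, C, H, W): shapes must round-trip.\n",
"attn = MultiHeadAttentionBlock(embed_dim=64, num_heads=4)\n",
"x = torch.randn(2, 64, 12, 12)\n",
"print(tuple(attn(x).shape))  # (2, 64, 12, 12)"
]
},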
{
"cell_type": "code",
"execution_count": 13,
"id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
"metadata": {},
"outputs": [],
"source": [
"class SpatialAttentionBlock(nn.Module):\n",
" def __init__(self):\n",
" super(SpatialAttentionBlock, self).__init__()\n",
" self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
"\n",
" def forward(self, x): #(B, 64, H, W)\n",
" avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)\n",
" max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)\n",
" out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)\n",
" out = torch.sigmoid(self.conv(out))#(B, 1, H, W)\n",
" return x * out #(B, C, H, W)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
"metadata": {},
"outputs": [],
"source": [
"class DecoderAttentionBlock(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super(DecoderAttentionBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
" self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
" self.spatial_attention = SpatialAttentionBlock()\n",
"\n",
" def forward(self, x):\n",
" # 通道注意力\n",
" b, c, h, w = x.size()\n",
" avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
" max_pool = F.adaptive_max_pool2d(x, 1)\n",
"\n",
" avg_out = self.conv1(avg_pool)\n",
" max_out = self.conv1(max_pool)\n",
"\n",
" out = avg_out + max_out\n",
" out = torch.sigmoid(self.conv2(out))\n",
"\n",
" # 添加空间注意力\n",
" out = x * out\n",
" out = self.spatial_attention(out)\n",
" return out"
]
},
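{
"cell_type": "markdown",
"id": "dab-check-md",
"metadata": {},
"source": [
"`DecoderAttentionBlock` builds a channel-attention map from average- and max-pooled features (shared `conv1`, summed, squashed by a sigmoid) and then applies `SpatialAttentionBlock`; both steps preserve the input shape. A quick check, added as a sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dab-check",
"metadata": {},
"outputs": [],
"source": [
"# Channel attention followed by spatial attention; output shape equals input shape.\n",
"dab = DecoderAttentionBlock(16)\n",
"x = torch.randn(2, 16, 48, 48)\n",
"print(tuple(dab(x).shape))  # (2, 16, 48, 48)"
]
},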
{
"cell_type": "code",
"execution_count": 15,
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
" def __init__(self, in_channels, reduced_dim):\n",
" super(SEBlock, self).__init__()\n",
" self.se = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return x * self.se(x)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "08810d47-3af3-47de-81cc-0377c5cab16e",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
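{
"cell_type": "markdown",
"id": "masked-loss-demo-md",
"metadata": {},
"source": [
"A tiny worked example of the masked loss, added as a sketch: only positions with `mask == 1` contribute, so the large error at the masked-out pixel is ignored."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "masked-loss-demo",
"metadata": {},
"outputs": [],
"source": [
"# 2x2 example: the (0, 0) pixel is masked out (mask == 0) and must not contribute.\n",
"preds = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])\n",
"target = torch.tensor([[[[9.0, 2.0], [3.0, 6.0]]]])\n",
"mask = torch.tensor([[[[0.0, 1.0], [1.0, 1.0]]]])\n",
"print(masked_mse_loss(preds, target, mask))  # ((2-2)**2 + (3-3)**2 + (4-6)**2) / 3 = 4/3"
]
},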
{
"cell_type": "code",
"execution_count": 17,
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 定义Masked Autoencoder模型\n",
"class MaskedAutoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(MaskedAutoencoder, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" SEBlock(128, 128)\n",
" )\n",
" self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" DecoderAttentionBlock(32),\n",
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" DecoderAttentionBlock(16),\n",
" nn.ReLU(),\n",
" \n",
" nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded\n",
"\n",
"# 实例化模型、损失函数和优化器\n",
"model = MaskedAutoencoder()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
]
},
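{
"cell_type": "markdown",
"id": "mae-shape-check-md",
"metadata": {},
"source": [
"The three stride-2 convolutions shrink 96x96 to 12x12, and the three transposed convolutions bring it back, so the reconstruction matches the input shape. A quick forward-pass check, added as a sketch (run while the model is still on the CPU):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mae-shape-check",
"metadata": {},
"outputs": [],
"source": [
"# Encoder: 96 -> 48 -> 24 -> 12; decoder: 12 -> 24 -> 48 -> 96.\n",
"with torch.no_grad():\n",
"    demo = torch.randn(2, 1, 96, 96)\n",
"    enc = model.encoder(demo)\n",
"    out = model(demo)\n",
"print(tuple(enc.shape), tuple(out.shape))  # (2, 128, 12, 12) (2, 1, 96, 96)"
]
},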
{
"cell_type": "code",
"execution_count": 18,
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
"metadata": {},
"outputs": [],
"source": [
"# 训练函数\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
" model.train()\n",
" running_loss = 0.0\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" optimizer.zero_grad()\n",
" reconstructed = model(X)\n",
" # loss = criterion(reconstructed, y)\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
"metadata": {},
"outputs": [],
"source": [
"# 评估函数\n",
"def evaluate(model, device, data_loader, criterion):\n",
" model.eval()\n",
" running_loss = 0.0\n",
" with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" reconstructed = model(X)\n",
" if batch_idx == 8:\n",
" rand_ind = np.random.randint(0, len(y))\n",
" # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "2cb2da06-9180-43be-95bb-4ba06654bfc8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# 数据准备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "743d1000-561e-4444-8b49-88346c14f28b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Train Loss: 2.1122538542205636, Val Loss: 0.17511643736220117\n",
"Epoch 2, Train Loss: 0.09455115371272324, Val Loss: 0.07173499481669113\n",
"Epoch 3, Train Loss: 0.05875080322142708, Val Loss: 0.05522163668230398\n",
"Epoch 4, Train Loss: 0.04709345177618083, Val Loss: 0.046923332583548416\n",
"Epoch 5, Train Loss: 0.04048821633975757, Val Loss: 0.04223129592502295\n",
"Epoch 6, Train Loss: 0.03651897717071207, Val Loss: 0.038725908567656335\n",
"Epoch 7, Train Loss: 0.03371973283606711, Val Loss: 0.03591352106252713\n",
"Epoch 8, Train Loss: 0.030995923611357737, Val Loss: 0.033181621734775714\n",
"Epoch 9, Train Loss: 0.02894393834575084, Val Loss: 0.031025866519159347\n",
"Epoch 10, Train Loss: 0.026934354539301122, Val Loss: 0.028885239290434923\n",
"Epoch 11, Train Loss: 0.025755781114422248, Val Loss: 0.027564026443148728\n",
"Epoch 12, Train Loss: 0.024294818880740535, Val Loss: 0.02660573101532993\n",
"Epoch 13, Train Loss: 0.023547336254179763, Val Loss: 0.025523469658262694\n",
"Epoch 14, Train Loss: 0.02263737249335176, Val Loss: 0.024892248685902625\n",
"Epoch 15, Train Loss: 0.02204986723389423, Val Loss: 0.02482297744101553\n",
"Epoch 16, Train Loss: 0.021457266258566005, Val Loss: 0.024080637119599242\n",
"Epoch 17, Train Loss: 0.020942402789681153, Val Loss: 0.023763289508312496\n",
"Epoch 18, Train Loss: 0.02059948215769096, Val Loss: 0.023712928865605325\n",
"Epoch 19, Train Loss: 0.020213669665050848, Val Loss: 0.022951017092190572\n",
"Epoch 20, Train Loss: 0.02002489379647246, Val Loss: 0.022396566457490424\n",
"Epoch 21, Train Loss: 0.019488899257818337, Val Loss: 0.02220052338914195\n",
"Epoch 22, Train Loss: 0.019191946226069657, Val Loss: 0.021812534682563882\n",
"Epoch 23, Train Loss: 0.018820160999894142, Val Loss: 0.021094122540150115\n",
"Epoch 24, Train Loss: 0.01841514516826808, Val Loss: 0.021011906689894732\n",
"Epoch 25, Train Loss: 0.01826861325392954, Val Loss: 0.020965722514622247\n",
"Epoch 26, Train Loss: 0.01783664010768159, Val Loss: 0.02035376571341237\n",
"Epoch 27, Train Loss: 0.01773165784883157, Val Loss: 0.020316684896599\n",
"Epoch 28, Train Loss: 0.017462643957362647, Val Loss: 0.020199675196364744\n",
"Epoch 29, Train Loss: 0.01726480335237806, Val Loss: 0.019924583983843894\n",
"Epoch 30, Train Loss: 0.017130774285412577, Val Loss: 0.019827198264981385\n",
"Epoch 31, Train Loss: 0.016821091141302192, Val Loss: 0.01998631670070228\n",
"Epoch 32, Train Loss: 0.016754478447887886, Val Loss: 0.019008648901510595\n",
"Epoch 33, Train Loss: 0.01657688988452893, Val Loss: 0.01900591877803429\n",
"Epoch 34, Train Loss: 0.016496175670613084, Val Loss: 0.019055584264891363\n",
"Epoch 35, Train Loss: 0.01644454181470583, Val Loss: 0.018636108959899908\n",
"Epoch 36, Train Loss: 0.01607896311823546, Val Loss: 0.018534055174286686\n",
"Epoch 37, Train Loss: 0.01588705154224945, Val Loss: 0.018062156513889333\n",
"Epoch 38, Train Loss: 0.015864519495962626, Val Loss: 0.018233197171296647\n",
"Epoch 39, Train Loss: 0.015855632771394755, Val Loss: 0.018038090332341727\n",
"Epoch 40, Train Loss: 0.015651265439982905, Val Loss: 0.01822574678530444\n",
"Epoch 41, Train Loss: 0.015510451237996372, Val Loss: 0.017679256400955256\n",
"Epoch 42, Train Loss: 0.015349842104436963, Val Loss: 0.018203645916794662\n",
"Epoch 43, Train Loss: 0.01543403383451358, Val Loss: 0.017195541675744663\n",
"Epoch 44, Train Loss: 0.015325402941233947, Val Loss: 0.017411370608788817\n",
"Epoch 45, Train Loss: 0.01518570597876202, Val Loss: 0.017076766354712978\n",
"Epoch 46, Train Loss: 0.014841953983182827, Val Loss: 0.016906344637608352\n",
"Epoch 47, Train Loss: 0.014843696093356068, Val Loss: 0.016789415712232022\n",
"Epoch 48, Train Loss: 0.014590430285104296, Val Loss: 0.01671677505347266\n",
"Epoch 49, Train Loss: 0.014620297918158569, Val Loss: 0.01652295997282907\n",
"Epoch 50, Train Loss: 0.014581651776654726, Val Loss: 0.01616852485866689\n",
"Epoch 51, Train Loss: 0.014414639787026569, Val Loss: 0.016296155653449138\n",
"Epoch 52, Train Loss: 0.01424450205157747, Val Loss: 0.016307457906207933\n",
"Epoch 53, Train Loss: 0.014137028997238173, Val Loss: 0.01646944234119867\n",
"Epoch 54, Train Loss: 0.014159051344939395, Val Loss: 0.016026857336844082\n",
"Epoch 55, Train Loss: 0.014192796753425347, Val Loss: 0.01584606984658028\n",
"Epoch 56, Train Loss: 0.013916373460076785, Val Loss: 0.015976423856371373\n",
"Epoch 57, Train Loss: 0.013736099040394195, Val Loss: 0.015810697172671112\n",
"Epoch 58, Train Loss: 0.013836662209276377, Val Loss: 0.015620186396721584\n",
"Epoch 59, Train Loss: 0.013784786091413367, Val Loss: 0.015319373792231972\n",
"Epoch 60, Train Loss: 0.013611769829497954, Val Loss: 0.015367041216857398\n",
"Epoch 61, Train Loss: 0.01358566418931815, Val Loss: 0.015289715783142331\n",
"Epoch 62, Train Loss: 0.013467149546093633, Val Loss: 0.015166739780289023\n",
"Epoch 63, Train Loss: 0.013366587792019668, Val Loss: 0.014960003544145556\n",
"Epoch 64, Train Loss: 0.013362093665971282, Val Loss: 0.015207788253675646\n",
"Epoch 65, Train Loss: 0.013282296849352322, Val Loss: 0.015704237049751317\n",
"Epoch 66, Train Loss: 0.013314912690553796, Val Loss: 0.015118209617351419\n",
"Epoch 67, Train Loss: 0.01314743113610448, Val Loss: 0.014853793154679128\n",
"Epoch 68, Train Loss: 0.013220271071125018, Val Loss: 0.015044791985358765\n",
"Epoch 69, Train Loss: 0.013089903819700035, Val Loss: 0.014621049485433458\n",
"Epoch 70, Train Loss: 0.013003655555591201, Val Loss: 0.015181626902142567\n",
"Epoch 71, Train Loss: 0.013071733119153377, Val Loss: 0.014468084979079553\n",
"Epoch 72, Train Loss: 0.013008178180555979, Val Loss: 0.014925862592992499\n",
"Epoch 73, Train Loss: 0.01300788912521096, Val Loss: 0.015519192122590186\n",
"Epoch 74, Train Loss: 0.012897961314001153, Val Loss: 0.014994534872361083\n",
"Epoch 75, Train Loss: 0.012850848984632766, Val Loss: 0.014727158249536557\n",
"Epoch 76, Train Loss: 0.012889095829380899, Val Loss: 0.014613447293861588\n",
"Epoch 77, Train Loss: 0.01279138982447497, Val Loss: 0.014250260944575516\n"
]
}
],
"source": [
"model = model.to(device)\n",
"\n",
"num_epochs = 100\n",
"train_losses = list()\n",
"val_losses = list()\n",
"for epoch in range(num_epochs):\n",
" train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
" train_losses.append(train_loss)\n",
" val_loss = evaluate(model, device, val_loader, criterion)\n",
" val_losses.append(val_loss)\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# 测试模型\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fb64e455b50>"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1m0lEQVR4nO3deXxU9b3/8fc5M8mEQBa2JKBhsUYWoYBsBu7PNQqIFLW1lnILaLXXW2i11LZiK3V5aGy9WKwb1+tDuVoRawXqdUcQBYxsEisu1AUBlQSUkhAIJDPn+/tjziwRgoQk5wDzej4e8wg5c2bOZ77AzHu+53u+X8sYYwQAAOAT2+8CAABAaiOMAAAAXxFGAACArwgjAADAV4QRAADgK8IIAADwFWEEAAD4ijACAAB8FfS7gMPhOI6++OILZWVlybIsv8sBAACHwRij3bt3q2vXrrLtxvs/jokw8sUXX6iwsNDvMgAAwBHYunWrTjzxxEbvPybCSFZWlqToi8nOzva5GgAAcDiqq6tVWFgY/xxvzDERRmKnZrKzswkjAAAcY75piAUDWAEAgK8IIwAAwFeEEQAA4KtjYswIAOD4E4lEVF9f73cZaIZAIKBgMNjsaTcIIwAAz9XU1Oizzz6TMcbvUtBMmZmZ6tKli9LT04/4OQgjAABPRSIRffbZZ8rMzFTnzp2ZzPIYZYxRXV2dduzYoU2bNqmoqOiQE5sdCmEEAOCp+vp6GWPUuXNntWnTxu9y0Axt2rRRWlqaNm/erLq6OmVkZBzR8zCAFQDgC3pEjg9H2hvS4DlaoA4AAIAjRhgBAAC+IowAAOCxHj16aPbs2S3yXMuWLZNlWdq1a1eLPJ8fGMAKAMBhOOusszRw4MAWCRFr1qxR27Ztm1/UcSKlw8hDyz/RZ/+q1YRh3dSr4NArCgIAcCjGGEUiEQWD3/zR2rlzZw8qOnak9Gma597ZprlvfKrNX+3xuxQASFnGGO2tC/tyO9xJ16ZMmaLXXntNd999tyzLkmVZmjt3rizL0gsvvKDBgwcrFAppxYoV+vjjjzV+/Hjl5+erXbt2Gjp0qF555ZUGz/f10zSWZemhhx7SxRdfrMzMTBUVFemZZ5454jZ9+umndeqppyoUCqlHjx6aNWtWg/vvv/9+FRUVKSMjQ/n5+fre974Xv+9vf/ub+vfvrzZt2qhjx44qKSnRnj2t+zmZ0j0jQTt6WZnDDIAA4Jva+oj6znzJl2O/d8soZaZ/80fh3XffrX/+85/q16+fbrnlFknSu+++K0m6/vrr9V//9V866aST1L59e23dulUXXHCBbrvtNoVCIT366KMaN26cNm7cqG7dujV6jJtvvll//OMfdeedd+qee+7RxIkTtXnzZnXo0KFJr2ndunX6/ve/r5tuukmXXXaZ3njjDf30pz9Vx44dNWXKFK1du1Y///nP9dhjj2nEiBHauXOnli9fLknatm2bJkyYoD/+8Y+6+OKLtXv3bi1fvrzVZ8pN6TBiu9e4hx3CCACgcTk5OUpPT1dmZqYKCgokSR988IEk6ZZbbtF5550X37dDhw4aMGBA/Pdbb71VCxcu1DPPPKNp06Y1eowpU6ZowoQJkqTbb79df/7zn7V69WqNHj26SbXeddddOvfcc3XjjTdKkk455RS99957uvPOOzVlyhRt2bJFbdu21YUXXqisrCx1795dgwYNkhQNI+FwWJdccom6d+8uSerfv3+Tjn8kUjqMBAPRMBIhjACAb9qkBfTeLaN8O3ZzDRkypMHvNTU1uummm/Tcc8/FP9xra2u1ZcuWQz7Pt7/97fif27Ztq+zsbG3fvr3J9bz//vsaP358g20jR47U7NmzFYlEdN5556l79+466aSTNHr0aI0ePTp+emjAgAE699xz1b9/f40aNUrnn3++vve976l9+/ZNrqMpUnrMSMCdNY4wAgD+sSxLmelBX24tMQvs16+Kue6667Rw4ULdfvvtWr58ucrLy9W/f3/V1dUd8nnS0tIOaBfHcZpd39dlZWXprbfe0hNPPKEuXbpo5syZGjBggHbt2qVAIKDFixfrhRdeUN++fXXPPfeoV69e2rRpU4vXkSy1w4j7b5DTNACAb5Kenq5IJPKN+61cuVJTpkzRxRdfrP79+6ugoECffvpp6xfo6tOnj1auXHlATaeccooCgWhPUDAYVElJif74xz/qH//4hz799FMtXbpUUjQEjRw5UjfffLPWr1+v9PR0LVy4sFVrTunTNLGeEYcwAgD4Bj169NCqVav06aefql27do32WhQVFWnBggUaN26cLMvSjTfe2Co9HI355S9/qaFDh+rWW2/VZZddprKyMt177726//77JUnPPvusPvnkE51xxhlq3769nn/+eTmOo169emnVqlVasmSJzj//fOXl5WnVqlXasWOH+vTp06o1p3bPiPvq6RkBAHyT6667ToFAQH379lXnzp0bHQNy1113qX379hoxYoTGjRunUaNG6bTTTvOsztNOO01//etfNX/+fPXr108zZ87ULbfcoilTpkiScnNztWDBAp1zzjnq06eP5syZoyeeeEKnnnqqsrOz9frrr+uCCy7QKaecot/97neaNWuWxowZ06o1W6a1r9dpAdXV1crJyVFVVZWys7Nb7HmnPv6Wnntnm27+zqmaPKJHiz0vAKBx+/bt06ZNm9SzZ88jXnIeR49D/X0e7ud3iveMcDUNAAB+I4yIMAIAOHpdffXVateu3UFvV199td/ltYgUH8DqhpGj/0wVACBF3XLLLbruuusOel9LDl3wU2qHEYueEQDA0S0vL095eXl+l9GqUvs0jTvRSDhCGAEAwC9NCiOlpaUaOnSosrKylJeXp4suukgbN278xsc99dRT6t27tzIyMtS/f389//zzR1xwSwpymgYAAN81KYy89tprmjp1qt58800tXrxY9fX1Ov/88w+5tPAbb7yhCRMm6Mc//rHWr1+viy66SBdddJE2bNjQ7OKby46fpvFuMhoAANBQk8aMvPjiiw1+nzt3rvLy8rRu3TqdccYZB33M3XffrdGjR+tXv/qVpOjqhYsXL9a9996rOXPmHGHZLSPeM0IWAQDAN80aM1JVVSUpulxyY8rKylRSUtJg26hRo1RWVtboY/bv36/q6uoGt9YQCNAzAgCA3444jDiOo2uvvVYjR45Uv379Gt2voqJC+fn5Dbbl5+eroqKi0ceUlpYqJycnfissLDzSMg8pdjUN08EDALzQo0cPzZ49+7D2tSxLixYtatV6jhZHHEamTp2qDRs2aP78+S1ZjyRpxowZqqqqit+2bt3a4seQEqdpWCgPAAD/HNE8I9OmTdOzzz6r119/XSeeeOIh9y0oKFBlZWWDbZWVlSooKGj0MaFQSKFQ6EhKaxLbpmcEAAC/NalnxBijadOmaeHChVq6dKl69uz5jY8pLi7WkiVLGmxbvHixiouLm1ZpK4j3jHBpLwD4xxipbo8/tya8/z/44IPq2rWrnK+NMxw/fryuuOIKffzxxxo/frzy8/PVrl07D
R06VK+88kqLNdM777yjc845R23atFHHjh31k5/8RDU1NfH7ly1bpmHDhqlt27bKzc3VyJEjtXnzZknS22+/rbPPPltZWVnKzs7W4MGDtXbt2harrbma1DMydepUzZs3T3//+9+VlZUVH/eRk5OjNm3aSJImTZqkE044QaWlpZKka665RmeeeaZmzZqlsWPHav78+Vq7dq0efPDBFn4pTRewo1mMSc8AwEf1e6Xbu/pz7Bu+kNLbHtaul156qX72s5/p1Vdf1bnnnitJ2rlzp1588UU9//zzqqmp0QUXXKDbbrtNoVBIjz76qMaNG6eNGzeqW7duzSpzz549GjVqlIqLi7VmzRpt375dV155paZNm6a5c+cqHA7roosu0lVXXaUnnnhCdXV1Wr16tSx3bOTEiRM1aNAgPfDAAwoEAiovL1daWlqzampJTQojDzzwgCTprLPOarD9kUce0ZQpUyRJW7ZskW0nOlxGjBihefPm6Xe/+51uuOEGFRUVadGiRYcc9OqVgFsm08EDAL5J+/btNWbMGM2bNy8eRv72t7+pU6dOOvvss2XbtgYMGBDf/9Zbb9XChQv1zDPPaNq0ac069rx587Rv3z49+uijats2Gp7uvfdejRs3Tn/4wx+UlpamqqoqXXjhhfrWt74lSerTp0/88Vu2bNGvfvUr9e7dW5JUVFTUrHpaWpPCiDmM7qxly5YdsO3SSy/VpZde2pRDeSLWM8IMrADgo7TMaA+FX8dugokTJ+qqq67S/fffr1AopMcff1w/+MEPZNu2ampqdNNNN+m5557Ttm3bFA6HVVtbqy1btjS7zPfff18DBgyIBxFJGjlypBzH0caNG3XGGWdoypQpGjVqlM477zyVlJTo+9//vrp06SJJmj59uq688ko99thjKikp0aWXXhoPLUeD1F6bJtp7xQBWAPCTZUVPlfhxc09jHK5x48bJGKPnnntOW7du1fLlyzVx4kRJ0nXXXaeFCxfq9ttv1/Lly1VeXq7+/furrq6uNVrtAI888ojKyso0YsQIPfnkkzrllFP05ptvSpJuuukmvfvuuxo7dqyWLl2qvn37auHChZ7UdThSO4y452kijBkBAByGjIwMXXLJJXr88cf1xBNPqFevXjrttNMkSStXrtSUKVN08cUXq3///iooKNCnn37aIsft06eP3n777QbLr6xcuVK2batXr17xbYMGDdKMGTP0xhtvqF+/fpo3b178vlNOOUW/+MUv9PLLL+uSSy7RI4880iK1tYSUDiMslAcAaKqJEyfqueee08MPPxzvFZGi4zAWLFig8vJyvf322/rhD394wJU3zTlmRkaGJk+erA0bNujVV1/Vz372M/3oRz9Sfn6+Nm3apBkzZqisrEybN2/Wyy+/rA8//FB9+vRRbW2tpk2bpmXLlmnz5s1auXKl1qxZ02BMid+OaJ6R40UgvlAeYQQAcHjOOeccdejQQRs3btQPf/jD+Pa77rpLV1xxhUaMGKFOnTrpN7/5TYstZ5KZmamXXnpJ11xzjYYOHarMzEx997vf1V133RW//4MPPtD//u//6quvvlKXLl00depU/cd//IfC4bC++uorTZo0SZWVlerUqZMuueQS3XzzzS1SW0uwzOGMSvVZdXW1cnJyVFVVpezs7BZ73qfXfaZfPvW2zjyls/73imEt9rwAgMbt27dPmzZtUs+ePZWRkeF3OWimQ/19Hu7nd2qfpgnQMwIAgN9SOozY8YXyWLUXAOCdxx9/XO3atTvo7dRTT/W7PM+l9JiRxEJ5PhcCAEgp3/nOdzR8+PCD3nc0zYzqlZQOI4mF8kgjAADvZGVlKSsry+8yjhopfZomcWmvz4UAQAo6Bq6fwGFoib/HlA4jgVgYoWcEADwTCAQkybOZSdG69u7dK6l5p5dS+jRNLIywai8AeCcYDCozM1M7duxQWlpag8VVcewwxmjv3r3avn27cnNz4yHzSBBGJDl0FQKAZyzLUpcuXbRp0yZt3rzZ73LQTLm5uSooKGjWc6R2GIlf2ksYAQAvpaenq6ioiFM1x7i0tLRm9YjEpHQYYdIzAPCPbdvMwApJKT+A1V21lzACAIBvUjuMsFAeAAC+S+0wYhNGAADwG2FEhBEAAPxEGBFX0wAA4KeUDiOJhfIIIwAA+CWlwwg9IwAA+I8wIinCDKwAAPgmpcNIkAGsAAD4LqXDiJ0URljKGgAAf6R0GIn1jEgSnSMAAPgjpcOInRRGwo7jYyUAAKSulA4jyT0jjBsBAMAfKR1GAoQRAAB8l9phxCKMAADgt9QOI/SMAADgu5QOI5ZlKZZHCCMAAPgjpcOIJAXtaBMwJTwAAP5I+TASYBZWAAB8RRghjAAA4CvCCIvlAQDgq5QPIyyWBwCAv1I+jMSmhA9HCCMAAPgh5cNIrGfE4TQNAAC+SPkwYruzsHJpLwAA/kj5MBIMxMaMsGovAAB+SPkwkri01+dCAABIUYSR+Gka0ggAAH4gjMQGsJJFAADwBWHEpmcEAAA/pXwYYdIzAAD8lfJhhLVpAADwF2GEMAIAgK8IIyyUBwCArwgj9IwAAOArwogdbQIWygMAwB8pH0aCnKYBAMBXKR9GYgvlcZoGAAB/pHwYCdqs2gsAgJ9SPowEArHp4AkjAAD4gTBi0TMCAICfUj6MBG16RgAA8FPKhxGbMSMAAPgq5cNIYqE8Vu0FAMAPKR9GEjOw+lwIAAApijBCzwgAAL4ijDADKwAAviKMcGkvAAC+Ioy4k55FWCgPAABfpHwYYaE8AAD8lfJhJMBCeQAA+IowYkebgDACAIA/Uj6MBAP0jAAA4KeUDyM2V9MAAOCrJoeR119/XePGjVPXrl1lWZYWLVp0yP2XLVsmy7IOuFVUVBxpzS2KhfIAAPBXk8PInj17NGDAAN13331NetzGjRu1bdu2+C0vL6+ph24VLJQHAIC/gk19wJgxYzRmzJgmHygvL0+5ublNflxrSyyURxgBAMAPno0ZGThwoLp06aLzzjtPK1euPOS++/fvV3V1dYNbawkQRgAA8FWrh5EuXbpozpw5evrpp/X000+rsLBQZ511lt56661GH1NaWqqcnJz4rbCwsNXqC3CaBgAAXzX5NE1T9erVS7169Yr/PmLECH388cf605/+pMcee+ygj5kxY4amT58e/726urrVAkksjDjMwAoAgC9aPYwczLBhw7RixYpG7w+FQgqFQp7UwkJ5AAD4y5d5RsrLy9WlSxc/Dn2AxKRnjs+VAACQmprcM1JTU6OPPvoo/vumTZtUXl6uDh06qFu3bpoxY4Y+//xzPfroo5Kk2bNnq2fPnjr11FO1b98+PfTQQ1q6dKlefvnllnsVzcAAVgAA/NXkMLJ27VqdffbZ8d9jYzsmT56suXPnatu2bdqyZUv8/rq6Ov3yl7/U559/rszMTH3729/WK6+80uA5/MRCeQAA+Msy5ugfuVldXa2cnBxVVVUpOzu7RZ/75Xcr9JPH1um0brla8NORLfrcAACkssP9/E75tWlYKA8AAH+lfBhhoTwAAPyV8mEkaEeb
gJ4RAAD8kfJhxM0ihBEAAHyS8mGEnhEAAPyV8mEkPs/I0X9REQAAxyXCSGyhvAhhBAAAP6R8GAmyUB4AAL5K+TDCpb0AAPgr5cMIk54BAOCvlA8jLJQHAIC/CCMslAcAgK8II/SMAADgK8IIYQQAAF+lfBiJXdobdhyfKwEAIDWlfBgJxOcZkQxzjQAA4DnCiBtGJE7VAADgB8JIUhhh4jMAALyX8mEktmqvxJTwAAD4IeXDSFIWoWcEAAAfpHwYadAzQhgBAMBzKR9GkoaM0DMCAIAPUj6MWJbFxGcAAPgo5cOIxCysAAD4iTAiFssDAMBPhBElpoQnjAAA4D3CiCQ7vj4NYQQAAK8RRkTPCAAAfiKMiAGsAAD4iTAiwggAAH4ijCgRRsKO43MlAACkHsKIEmNGWCgPAADvEUaUdDVNhDACAIDXCCNKupqGnhEAADxHGJFkMwMrAAC+IYxICgaY9AwAAL8QRiQF7GgzOIQRAAA8RxiR5HaM0DMCAIAPCCOSgvSMAADgG8KIJDeL0DMCAIAPCCNK9IxwNQ0AAN4jjIi1aQAA8BNhRIQRAAD8RBhR8kJ5hBEAALxGGBHTwQMA4CfCiBIL5UUijs+VAACQeggjSu4Z8bkQAABSEGFEUiC+UB49IwAAeI0wIgawAgDgJ8KIEqv2Mh08AADeI4xIsi16RgAA8AthRIkBrPSMAADgPcKIEpf20jMCAID3CCNKurSXMAIAgOcII5ICrNoLAIBvCCOSAm4rcJoGAADvEUaU6BlxWJsGAADPEUaUmIGVnhEAALxHGFFi0rMIi9MAAOA5wogS08FHOE0DAIDnCCNKXiiPMAIAgNcII2KhPAAA/EQYEQvlAQDgJ8KIkhfKc3yuBACA1EMYUfJ08D4XAgBACiKMKLFQXoSeEQAAPEcYUaJnhAGsAAB4jzCixNU0TAcPAID3CCNKurSXGVgBAPBck8PI66+/rnHjxqlr166yLEuLFi36xscsW7ZMp512mkKhkE4++WTNnTv3CEptPUF6RgAA8E2Tw8iePXs0YMAA3XfffYe1/6ZNmzR27FidffbZKi8v17XXXqsrr7xSL730UpOLbS02C+UBAOCbYFMfMGbMGI0ZM+aw958zZ4569uypWbNmSZL69OmjFStW6E9/+pNGjRrV1MO3ivhCeYQRAAA81+pjRsrKylRSUtJg26hRo1RWVtboY/bv36/q6uoGt9YUsKPNQBgBAMB7rR5GKioqlJ+f32Bbfn6+qqurVVtbe9DHlJaWKicnJ34rLCxs1RpZKA8AAP8clVfTzJgxQ1VVVfHb1q1bW/V4LJQHAIB/mjxmpKkKCgpUWVnZYFtlZaWys7PVpk2bgz4mFAopFAq1dmlxLJQHAIB/Wr1npLi4WEuWLGmwbfHixSouLm7tQx82rqYBAMA/TQ4jNTU1Ki8vV3l5uaTopbvl5eXasmWLpOgplkmTJsX3v/rqq/XJJ5/o17/+tT744APdf//9+utf/6pf/OIXLfMKWkBioTzCCAAAXmtyGFm7dq0GDRqkQYMGSZKmT5+uQYMGaebMmZKkbdu2xYOJJPXs2VPPPfecFi9erAEDBmjWrFl66KGHjprLeqXEmBHCCAAA3mvymJGzzjpL5hAzlR5sdtWzzjpL69evb+qhPMMAVgAA/HNUXk3jNaaDBwDAP4QRSXZ8oTzH50oAAEg9hBEl94z4XAgAACmIMKLkS3vpGQEAwGuEEbFQHgAAfiKMiEt7AQDwE2FEiYXyHMOU8AAAeI0wIiloJ5ohwuW9AAB4ijAiKSmLcKoGAACPEUb0tZ4RwggAAJ4ijCgxgFXiNA0AAF4jjOhrYSRCGAEAwEuEEUlJWYTF8gAA8BhhRJJlWSyWBwCATwgjrvhiefSMAADgKcKIK94zQhgBAMBThBFXwKJnBAAAPxBGXIH4Ynms3AsAgJcII65gfLE8nwsBACDFEEZcdvw0DWkEAAAvEUZciZ4RxowAAOAlwojLJowAAOALwoiLnhEAAPxBGHEFCCMAAPiCMOIijAAA4A/CiCtgR5uCSc8AAPAWYcQVHzPCQnkAAHiKMOKKX00TIYwAAOAlwoiLnhEAAPxBGHHFFspjACsAAN4ijLhiV9MwgBUAAG8RRlxBd9VehzACAICnCCOuxEJ5hBEAALxEGHElpoNn1V4AALxEGHElFsrzuRAAAFIMYcRFzwgAAP4gjLhYmwYAAH8QRlxc2gsAgD8IIy56RgAA8AdhxMV08AAA+IMw4gqwUB4AAL4gjLgC9IwAAOALwoiLhfIAAPAHYcQVsKNNwdU0AAB4izDiYqE8AAD8QRhxsVAeAAD+IIy4gswzAgCALwgjLpswAgCALwgjriDTwQMA4AvCiCs2zwgDWAEA8BZhxMVCeQAA+IMw4koMYHV8rgQAgNRCGHHFLu1laRoAALxFGHHFJj2jZwQAAG8RRlwBLu0FAMAXhBEXC+UBAOAPwoiLq2kAAPAHYcSVGDNCGAEAwEuEEZfNaRoAAHxBGHEF7WhTEEYAAPAWYcQVcFuCMAIAgLcII66A2zPCAFYAALxFGHHFpoN3DGEEAAAvEUZcduzSXuaDBwDAU4QRV5AZWAEA8AVhxJVYKI8wAgCAlwgjLiY9AwDAH4QRFwvlAQDgD8KIi4XyAADwxxGFkfvuu089evRQRkaGhg8frtWrVze679y5c2VZVoNbRkbGERfcWhIL5Tk+VwIAQGppchh58sknNX36dP3+97/XW2+9pQEDBmjUqFHavn17o4/Jzs7Wtm3b4rfNmzc3q+jWkBgz4nMhAACkmCaHkbvuuktXXXWVLr/8cvXt21dz5sxRZmamHn744UYfY1mWCgoK4rf8/PxmFd0aEqdpSCMAAHipSWGkrq5O69atU0lJSeIJbFslJSUqKytr9HE1NTXq3r27CgsLNX78eL377ruHPM7+/ftVXV3d4NbaGMAKAIA/mhRGvvzyS0UikQN6NvLz81VRUXHQx/Tq1UsPP/yw/v73v+svf/mLHMfRiBEj9NlnnzV6nNLSUuXk5MRvhYWFTSnziBBGAADwR6tfTVNcXKxJkyZp4MCBOvPMM7VgwQJ17txZ//3f/93oY2bMmKGqqqr4bevWra1dZtIAVsIIAABeCjZl506dOikQCKiysrLB9srKShUUFBzWc6SlpWnQoEH66KOPGt0nFAopFAo1pbRmC7qr9rJQHgAA3mpSz0h6eroGDx6sJUuWxLc5jqMlS5aouLj4sJ4jEononXfeUZcuXZpWaStzswg9IwAAeKxJPSOSNH36dE2ePFlDhgzRsGHDNHv2bO3Zs0eXX365JGnSpEk64YQTVFpaKkm65ZZbdPrpp+vkk0/Wrl27dOedd2rz5s268sorW/aVNFOsZ8QYyXFMfBVfAADQupocRi677DLt2LFDM2fOVEVFhQYOHKgXX3wxPqh1y5Ytsu1Eh8u//vUvXXXVVaqoqFD79u01ePBgvfHGG+rbt2/LvYoWELu0V4oulmeLMAIAgBcsY47+QRLV1dXKyclRVVWVsrOzW+UYNfvD6vf7lyRJH9w
6WhlpgVY5DgAAqeJwP79Zm8YVTDotw+W9AAB4hzDispNO0zCIFQAA7xBGXPSMAADgD8KIy7YtxTpHCCMAAHiHMJIksVgeYQQAAK8QRpLE16c5+i8wAgDguEEYSRIPIxHCCAAAXiGMJEksluf4XAkAAKmDMJIkdkUNi+UBAOAdwkiSRM8IYQQAAK8QRpLEwwhjRgAA8AxhJEns0l5O0wAA4B3CSJJAgNM0AAB4jTCSJGhHm8MhjAAA4BnCSJLY8jT0jAAA4B3CSJJYzwjTwQMA4B3CSBLbZm0aAAC8RhhJEiSMAADgOcJIkgBhBAAAzxFGkjADKwAA3iOMJKFnBAAA7xFGksTHjDADKwAAniGMJEn0jDg+VwIAQOogjCRhoTwAALxHGEnCQnkAAHiPMJKEq2kAAPAeYSRJ0F21l4XyAADwDmEkiW3RMwIAgNcII0mYDh4AAO8RRpKwUB4AAN4jjCQJMoAVAADPEUaSBOxoczCAFQAA7xBGkgTc1qBnBAAA7xBGkgTdnhHGjAAA4B3CSJIAC+UBAOA5wkiSAFfTAADgOcJIEsIIAADeI4wkiS2URxgBAMA7hJEkiYXyHJ8rAQAgdRBGkiSmg/e5EAAAUghhJEliOnjSCAAAXiGMJGE6eAAAvEcYSRIbM8J08AAAeIcwkiRAzwgAAJ5L3TBijLTkVmnOv0m7tkhKnKZxmIEVAADPpG4YsSzp0+VSxTvSx69KSgxgDUcIIwAAeCV1w4gkfeuc6M+Pl0hKvrSXMAIAgFdSPIycG/35yTLJici2WCgPAACvpXYY6TpIysiR9lVJn7+lYICeEQAAvJbaYSQQlE46K/rnj5cqYEebgzACAIB3UjuMSA3GjcQWyuPSXgAAvEMYiYWRz9YqFN4tiZ4RAAC8RBjJ7SZ1LJJMRJ2/fFMSYQQAAC8RRqR470inypWSCCMAAHiJMCJJJ0cv8e1YsUKSYcwIAAAeIoxIUveRkp2mjD2fqYdVwUJ5AAB4iDAiSaF2UrfTJUln2P9Q2HF8LggAgNRBGIlxx438P/sd1bM2DQAAniGMxLjjRort97RtZ7X+tu4znwsCACA1EEZi8vtLmZ3Uztqn06wPdcPCd/TOZ1V+VwUAwHGPMBJj2/FTNZPyPlJd2NHVf1mnr2r2+1wYAADHN8JIMjeMjKl7WVflrNHnu/Zq2rz1CkcY0AoAQGshjCTrPVbq1Et27Vf67f4/aWHoZtVsWqM7XvjA78oAADhuWcaYo/7SkerqauXk5KiqqkrZ2dmte7D6fdKb90mvz5Lq98gxlp6O/D992HW8hp0xRmf37aqAbbVuDQAAHAcO9/ObMNLoQb+QXrlJ+seT8U1VJlNrAgNlTj5fp464QF26Fcmy6VwCAOBgCCMtZetq1Sx/QPYnS5QZbnh1zR5laEeoh8IditT2xH7qcGKRQp16SDndpLadJIseFABA6iKMtDQnov2bV2tT2QKlf7JUhfWblGZFGt29zgppT0a+6trkybQrUCC7izI6dFVmh64KZHeR2uVLWflSRi6hBQBwXDrcz++ghzUd2+yAQj2L1btnsSRpb22tNrz/trZ9WK7aL95TZtVH6uzs0InWDuVpl9K1X+m1W6TaLdLOxp82rKBq03K1Py1H4VB7OW06yMrsoGC7jkpr10mh7E7KyO4oK72dlN5OSm8bvaVlSsGQlNZGsgMeNQIAAC2PnpEWtGtvnTZ/tVdbtv9LO7dt0r6vtkq7KxTYW6mMfTuUVf+lOqlKedYu5Vn/Uo61t0WOG1FA9XZIYTtD4UCGIoGQwoE2Mna6FEiTAulSIF1WIE2WZcm2JMuyZdmWLDsoBdNlBdJlBUOyg+myAkHZdkB2ICDLDsi2A7IsS5Zlxx+rQFAKhKRgevRnID26zU6LHtNOi87dIkuy7MTNDkh20P1zsEF9Coaiv8tK6i1yfyb/blnRx8aPAQA4GrVqz8h9992nO++8UxUVFRowYIDuueceDRs2rNH9n3rqKd1444369NNPVVRUpD/84Q+64IILjuTQR7XczHTlZqZrQGGupJ4H3B9xjHbuqdP23fu0fvd+fbmrWrX/qlB9zZeK1OyU9n4pe9+/lF63S+l1u9QmXKUcs1s5Vo0ytV+Z1n611T5lap9CVjj+vAFFFHD2Ss5eKXzAYY9rjgKK2EE5SvQOmQZBJhFsjCwZy3ZvAUm2u6+V9DhLlowsY2TJkSUjxwoqEggpEsiQY4cUsdMly25wds2SiR8xUYctY0WDl2MHJcuSZQXcn9GbrKCMHYyHNGPbsoyR3OPLGLf+aKhrUKMUrdE40fvstHjAM3ZAtonIcsLRmwnLkiUnmCETzJATiP6UFD2eHPc1KxrwrGD0px2QbduyLVuWbUdDrOW2qWXHa4u+HsVD6wH7xANpIPparYCMbct2j29Jsiwjq0ELRv/WZIxkHPfPTrQt7KB7Szxfg2PJSnqMif6UEvfF67G+Vt/X63TD7sG+szX4B5D8b+1gz2dHa3Ai0bqMI5mIFAlLTlhy6qM/LTsR6APu32dyLcn1xJ5HJqkt3HYxjhRxnzNSHz2WHWwY/C07cb8Tjtbm1EuRusQ24ySeM/4lI7mtAw1fe3L72kHvviw4jhTel7jZaVJ6ZrQHmdPgx4Qmh5Enn3xS06dP15w5czR8+HDNnj1bo0aN0saNG5WXl3fA/m+88YYmTJig0tJSXXjhhZo3b54uuugivfXWW+rXr1+LvIhjRcC21DkrpM5ZIXdLnqSTD/mYffURVdfWq7Y+ot11EW2vj6i2LqKaffu1b+8e7avdq/21e1W/f49Mfa1UXyurvlZWuFYK75cTrpMTqZfCdVKkThFj5Dgm/lNOWAGnXgFTp6CpV8Cpi364GUe2HAXkyI5/0MbekI3SFFG6FVa66pWusNIUVpoiClrRn2kKy45+/MuSkS0Tf76AHAWsiIJyFHQfl66wQlZ9k9vUVkS20/jYHQD+irj/85NF30tMPFTLfY84GMeN3bFHSJaMZcUfbRtHaWr8vWO/laF6K93d25FtnPh7WfydzYodw71Z0XctK/m9z8RqdL+oxL4wSHIs953N/Zn8hcgkBeyGX1nMAa8sttUk1XRwDe9v0HJGSa8ncTzb/XJlm0i8LZK/dFnGqPrSJ1XQu7jRtmxNTT5NM3z4cA0dOlT33nuvJMlxHBUWFupnP/uZrr/++gP2v+yyy7Rnzx49++yz8W2nn366Bg4cqDlz5hzWMY+V0zTHE2OMHCPVRxxFksJLJHYzSX92b0aSY4yMifYCOcaoPmIUjjgKO0ZhJ/ocjvvY6E/FnzscceRE6uS42yKRiMKOkXGc6BdBRZ/bGEeKhKMhK1LnfpsLR/9LO070P6C7nzGOW1P0eSx3m2Ui0W+C7muNfYM2xsiRpbBjKaLoFy5bEQWd/Uoz+5Vu9ivo1McfZ2KPN4k3hNh8vbZxZJuIAu7bcezYJlaPiW6zTcTdLyzbOHKspBgX+0LvviXbbm+Jk3gLkWOi9wfcYwVMWAE5CstWvQlET+MpIFuOMlSnDNUppDqFVC+3T0SO+1YlEz2G7b
5pBxSRZKIBVbE3tGgDWw3eTBu+7dpuK8SCaMCKVhsPo4ooIMd9lqQ3YPdRyZzEW6eM2ysUkKOgEoHWSjpmrMboa7Pj1SbXGQvHiR4ZEw/esdcf+z3xBtmg3yu+xfpaWyTf7NjNMnJM9DVE3JjuyFY4+reusIKKyJbtvq5YoA/GPsqtRj6ojftB08j9khQ2tiIKKKhwo8+T/Hz1CqouurccWW47R+J1HepYR4OwsRW0mDX7SHxw4QL1HnJuiz5nq5ymqaur07p16zRjxoz4Ntu2VVJSorKysoM+pqysTNOnT2+wbdSoUVq0aFGjx9m/f7/270+sCVNdXd2UMtECLMtSwJICDI6FKxZQnYN8fzHudidpn8a+05mkx8jd10huoDXutmj4dIwb+pIOGT9z9Q0cNyTGgm8s0DpOIthK7tkFWQ2eM3ZfIgBH/xxxEoU0eB2x/ZK2xdor9pq+qS2MDnaaKqkdHCcaoK3oKcDY6bHo440sJyKZsEy4Xsay5binAGMhz5IkE1HAqZft1MsyETl2WjRy2dFv9bIO/P8ee31O/IxXNObJRKIh1XFkWW6kNLHjOJIJJ05LOeED/tKMifYrRJzonyOxgOf+fcSawnIkIyf+pSF+Wi8ejC2FA+mKBDIUtkJy7KCM4ygQqZUdrlUgvFd2pE6yoyEwfiopqX0t97RX7IuCjCPjRGS5PTBye0riPQ6WnfjdMdEeWhORZSKyTdhtMyfp366JjreTZDf4d2YpIssNql8LtsY0qDH2l2HcLwZO7IuZJFuWLCtxjK+fnjQmGswjsuWY6BctYwUT7WFF+6ZG9Rh0wN+/V5oURr788ktFIhHl5+c32J6fn68PPjj4lOkVFRUH3b+ioqLR45SWlurmm29uSmkAWlk8oDYaMwDgyByVlyLMmDFDVVVV8dvWrVv9LgkAALSSJvWMdOrUSYFAQJWVlQ22V1ZWqqCg4KCPKSgoaNL+khQKhRQKhRq9HwAAHD+a1DOSnp6uwYMHa8mSJfFtjuNoyZIlKi4++Ajc4uLiBvtL0uLFixvdHwAApJYmX9o7ffp0TZ48WUOGDNGwYcM0e/Zs7dmzR5dffrkkadKkSTrhhBNUWloqSbrmmmt05plnatasWRo7dqzmz5+vtWvX6sEHH2zZVwIAAI5JTQ4jl112mXbs2KGZM2eqoqJCAwcO1IsvvhgfpLplyxbZSRPdjBgxQvPmzdPvfvc73XDDDSoqKtKiRYtSbo4RAABwcEwHDwAAWsXhfn4flVfTAACA1EEYAQAAviKMAAAAXxFGAACArwgjAADAV4QRAADgK8IIAADwVZMnPfNDbCqU6upqnysBAACHK/a5/U1Tmh0TYWT37t2SpMLCQp8rAQAATbV7927l5OQ0ev8xMQOr4zj64osvlJWVJcuyWux5q6urVVhYqK1btzKzayujrb1DW3uL9vYObe2dlmprY4x2796trl27Nlgq5uuOiZ4R27Z14oknttrzZ2dn8w/bI7S1d2hrb9He3qGtvdMSbX2oHpEYBrACAABfEUYAAICvUjqMhEIh/f73v1coFPK7lOMebe0d2tpbtLd3aGvveN3Wx8QAVgAAcPxK6Z4RAADgP8IIAADwFWEEAAD4ijACAAB8ldJh5L777lOPHj2UkZGh4cOHa/Xq1X6XdMwrLS3V0KFDlZWVpby8PF100UXauHFjg3327dunqVOnqmPHjmrXrp2++93vqrKy0qeKjw933HGHLMvStddeG99GO7eszz//XP/+7/+ujh07qk2bNurfv7/Wrl0bv98Yo5kzZ6pLly5q06aNSkpK9OGHH/pY8bEpEonoxhtvVM+ePdWmTRt961vf0q233tpgbRPa+si8/vrrGjdunLp27SrLsrRo0aIG9x9Ou+7cuVMTJ05Udna2cnNz9eMf/1g1NTXNL86kqPnz55v09HTz8MMPm3fffddcddVVJjc311RWVvpd2jFt1KhR5pFHHjEbNmww5eXl5oILLjDdunUzNTU18X2uvvpqU1hYaJYsWWLWrl1rTj/9dDNixAgfqz62rV692vTo0cN8+9vfNtdcc018O+3ccnbu3Gm6d+9upkyZYlatWmU++eQT89JLL5mPPvoovs8dd9xhcnJyzKJFi8zbb79tvvOd75iePXua2tpaHys/9tx2222mY8eO5tlnnzWbNm0yTz31lGnXrp25++674/vQ1kfm+eefN7/97W/NggULjCSzcOHCBvcfTruOHj3aDBgwwLz55ptm+fLl5uSTTzYTJkxodm0pG0aGDRtmpk6dGv89EomYrl27mtLSUh+rOv5s377dSDKvvfaaMcaYXbt2mbS0NPPUU0/F93n//feNJFNWVuZXmces3bt3m6KiIrN48WJz5plnxsMI7dyyfvOb35h/+7d/a/R+x3FMQUGBufPOO+Pbdu3aZUKhkHniiSe8KPG4MXbsWHPFFVc02HbJJZeYiRMnGmNo65by9TByOO363nvvGUlmzZo18X1eeOEFY1mW+fzzz5tVT0qepqmrq9O6detUUlIS32bbtkpKSlRWVuZjZcefqqoqSVKHDh0kSevWrVN9fX2Dtu/du7e6detG2x+BqVOnauzYsQ3aU6KdW9ozzzyjIUOG6NJLL1VeXp4GDRqk//mf/4nfv2nTJlVUVDRo75ycHA0fPpz2bqIRI0ZoyZIl+uc//ylJevvtt7VixQqNGTNGEm3dWg6nXcvKypSbm6shQ4bE9ykpKZFt21q1alWzjn9MLJTX0r788ktFIhHl5+c32J6fn68PPvjAp6qOP47j6Nprr9XIkSPVr18/SVJFRYXS09OVm5vbYN/8/HxVVFT4UOWxa/78+Xrrrbe0Zs2aA+6jnVvWJ598ogceeEDTp0/XDTfcoDVr1ujnP/+50tPTNXny5HibHuw9hfZumuuvv17V1dXq3bu3AoGAIpGIbrvtNk2cOFGSaOtWcjjtWlFRoby8vAb3B4NBdejQodltn5JhBN6YOnWqNmzYoBUrVvhdynFn69atuuaaa7R48WJlZGT4Xc5xz3EcDRkyRLfffrskadCgQdqwYYPmzJmjyZMn+1zd8eWvf/2rHn/8cc2bN0+nnnqqysvLde2116pr16609XEsJU/TdOrUSYFA4IArCyorK1VQUOBTVceXadOm6dlnn9Wrr76qE088Mb69oKBAdXV12rVrV4P9afumWbdunbZv367TTjtNwWBQwWBQr732mv785z8rGAwqPz+fdm5BXbp0Ud++fRts69Onj7Zs2SJJ8TblPaX5fvWrX+n666/XD37wA/Xv318/+tGP9Itf/EKlpaWSaOvWcjjtWlBQoO3btze4PxwOa+fOnc1u+5QMI+np6Ro8eLCWLFkS3+Y4jpYsWaLi4mIfKzv2GWM0bdo0LVy4UEuXLlXPnj0b3D948GClpaU1aPuNGzdqy5YttH0TnHvuuXrnnXdUXl4evw0ZMkQTJ06M/5l2bjkjR4484BL1f/7zn+revbskqWfPniooKGjQ3tXV1Vq1ahXt3UR79+6VbTf8aAoEAnIcRxJt3VoOp12Li4u1a
9curVu3Lr7P0qVL5TiOhg8f3rwCmjX89Rg2f/58EwqFzNy5c817771nfvKTn5jc3FxTUVHhd2nHtP/8z/80OTk5ZtmyZWbbtm3x2969e+P7XH311aZbt25m6dKlZu3ataa4uNgUFxf7WPXxIflqGmNo55a0evVqEwwGzW233WY+/PBD8/jjj5vMzEzzl7/8Jb7PHXfcYXJzc83f//53849//MOMHz+ey02PwOTJk80JJ5wQv7R3wYIFplOnTubXv/51fB/a+sjs3r3brF+/3qxfv95IMnfddZdZv3692bx5szHm8Np19OjRZtCgQWbVqlVmxYoVpqioiEt7m+uee+4x3bp1M+np6WbYsGHmzTff9LukY56kg94eeeSR+D61tbXmpz/9qWnfvr3JzMw0F198sdm2bZt/RR8nvh5GaOeW9X//93+mX79+JhQKmd69e5sHH3ywwf2O45gbb7zR5Ofnm1AoZM4991yzceNGn6o9dlVXV5trrrnGdOvWzWRkZJiTTjrJ/Pa3vzX79++P70NbH5lXX331oO/PkydPNsYcXrt+9dVXZsKECaZdu3YmOzvbXH755Wb37t3Nrs0yJmlaOwAAAI+l5JgRAABw9CCMAAAAXxFGAACArwgjAADAV4QRAADgK8IIAADwFWEEAAD4ijACAAB8RRgBAAC+IowAAABfEUYAAICvCCMAAMBX/x/duxyKSVWRCgAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tr_ind = list(range(len(train_losses)))\n",
"val_ind = list(range(len(val_losses)))\n",
"plt.plot(train_losses, label='train_loss')\n",
"plt.plot(val_losses, label='val_loss')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "cadb0e00-96bb-423b-9163-7c8010011dd1",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "4510b043-7808-4679-9be4-c61dcca6ecac",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" # tr_maxs = np.transpose(maxs, (2, 0, 1))\n",
" # tr_mins = np.transpose(mins, (2, 0, 1))\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" eva_list.append([mae, rmse, mape, r2])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "4d80bff2-3086-4e73-a597-f2fa812e2c28",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.548639</td>\n",
" <td>2.513043</td>\n",
" <td>0.190712</td>\n",
" <td>0.850014</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.104697</td>\n",
" <td>0.277761</td>\n",
" <td>0.018381</td>\n",
" <td>0.021919</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>1.372461</td>\n",
" <td>2.125686</td>\n",
" <td>0.158994</td>\n",
" <td>0.766183</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>1.492424</td>\n",
" <td>2.371325</td>\n",
" <td>0.177162</td>\n",
" <td>0.836254</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.553864</td>\n",
" <td>2.482061</td>\n",
" <td>0.187778</td>\n",
" <td>0.851790</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.600554</td>\n",
" <td>2.630040</td>\n",
" <td>0.201229</td>\n",
" <td>0.865281</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>2.036150</td>\n",
" <td>4.280405</td>\n",
" <td>0.259433</td>\n",
" <td>0.884967</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2\n",
"count 75.000000 75.000000 75.000000 75.000000\n",
"mean 1.548639 2.513043 0.190712 0.850014\n",
"std 0.104697 0.277761 0.018381 0.021919\n",
"min 1.372461 2.125686 0.158994 0.766183\n",
"25% 1.492424 2.371325 0.177162 0.836254\n",
"50% 1.553864 2.482061 0.187778 0.851790\n",
"75% 1.600554 2.630040 0.201229 0.865281\n",
"max 2.036150 4.280405 0.259433 0.884967"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "9732912d-4fa2-42c5-8c7d-27825e479faf",
"metadata": {},
"outputs": [],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2']).describe().to_csv('./eva_files/decoder+local_loss.csv', encoding='utf-8-sig')"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
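{
"cell_type": "markdown",
"id": "ioa-demo-md",
"metadata": {},
"source": [
"The index of agreement (IoA) is 1 minus the ratio of the squared error to a potential-error term built from absolute deviations around the means, so values near 1 indicate close agreement. A tiny worked example, added as a sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ioa-demo",
"metadata": {},
"outputs": [],
"source": [
"# Predictions close to the observations give an IoA near 1.\n",
"y_true = np.array([1.0, 2.0, 3.0, 4.0])\n",
"y_pred = np.array([1.1, 1.9, 3.2, 3.8])\n",
"print(round(cal_ioa(y_true, y_pred), 4))  # 0.9947\n",
"print(round(cal_ioa(y_true, y_true[::-1]), 4))  # reversed order: poor agreement, scores 0.0"
]
},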
{
"cell_type": "code",
"execution_count": 30,
"id": "699473c7-33b8-432d-861c-2628ad2614f0",
"metadata": {},
"outputs": [],
"source": [
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" for i, sample in enumerate(rev_data):\n",
" used_mask = mask_rev[i]\n",
" data_label = sample[0] * used_mask\n",
" recon_no2 = rev_recon[i][0] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "79731bcf-3ec2-4a9b-a58d-74c40212f738",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.553667</td>\n",
" <td>2.209092</td>\n",
" <td>0.188788</td>\n",
" <td>0.523867</td>\n",
" <td>0.829028</td>\n",
" <td>0.775553</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.821044</td>\n",
" <td>1.193856</td>\n",
" <td>0.121753</td>\n",
" <td>0.420704</td>\n",
" <td>0.182549</td>\n",
" <td>0.164661</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.525306</td>\n",
" <td>0.680506</td>\n",
" <td>0.061413</td>\n",
" <td>-4.738533</td>\n",
" <td>-0.916011</td>\n",
" <td>-0.197854</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.960099</td>\n",
" <td>1.333764</td>\n",
" <td>0.131694</td>\n",
" <td>0.429017</td>\n",
" <td>0.802631</td>\n",
" <td>0.715950</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.369256</td>\n",
" <td>1.958160</td>\n",
" <td>0.163652</td>\n",
" <td>0.646098</td>\n",
" <td>0.889664</td>\n",
" <td>0.824197</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.892561</td>\n",
" <td>2.704055</td>\n",
" <td>0.203364</td>\n",
" <td>0.768918</td>\n",
" <td>0.931843</td>\n",
" <td>0.886272</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>7.905261</td>\n",
" <td>11.196068</td>\n",
" <td>1.671224</td>\n",
" <td>0.972414</td>\n",
" <td>0.993103</td>\n",
" <td>0.986316</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
"mean 1.553667 2.209092 0.188788 0.523867 0.829028 \n",
"std 0.821044 1.193856 0.121753 0.420704 0.182549 \n",
"min 0.525306 0.680506 0.061413 -4.738533 -0.916011 \n",
"25% 0.960099 1.333764 0.131694 0.429017 0.802631 \n",
"50% 1.369256 1.958160 0.163652 0.646098 0.889664 \n",
"75% 1.892561 2.704055 0.203364 0.768918 0.931843 \n",
"max 7.905261 11.196068 1.671224 0.972414 0.993103 \n",
"\n",
" r \n",
"count 4739.000000 \n",
"mean 0.775553 \n",
"std 0.164661 \n",
"min -0.197854 \n",
"25% 0.715950 \n",
"50% 0.824197 \n",
"75% 0.886272 \n",
"max 0.986316 "
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8b5207d-e9ad-46e7-8d57-18528beee59b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}