{ "cells": [ { "cell_type": "code", "execution_count": 25, "id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, Dataset, random_split\n", "from PIL import Image\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import cv2\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "id": "c28cc123-71be-47ff-b78f-3a4d5592df39", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Maximum pixel value in the dataset: 107.49169921875\n" ] } ], "source": [ "\n", "max_pixel_value = 107.49169921875\n", "\n", "print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "dbfe80ce-4394-449c-a9a4-22ed15b2b8f2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "checkpoint before Generator is OK\n" ] } ], "source": [ "class NO2Dataset(Dataset):\n", " \n", " def __init__(self, image_dir, mask_dir):\n", " \n", " self.image_dir = image_dir\n", " self.mask_dir = mask_dir\n", " self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n", " self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n", " \n", " def __len__(self):\n", " \n", " return len(self.image_filenames)\n", " \n", " def __getitem__(self, idx):\n", " \n", " image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n", " mask_idx = np.random.choice(self.mask_filenames)\n", " mask_path = os.path.join(self.mask_dir, mask_idx)\n", "\n", " # 加载图像数据 (.npy 文件)\n", " image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n", "\n", " # 加载掩码数据 (.jpg 文件)\n", " mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n", "\n", " # 将掩码数据中非0值设为1,0值保持不变\n", " mask = np.where(mask != 0, 1.0, 0.0)\n", "\n", " # 保持掩码数据形状为 (96, 96, 1)\n", " mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n", "\n", " # 应用掩码\n", " masked_image = image.copy()\n", " masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n", "\n", " # cGAN的输入和目标\n", " X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n", " y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n", "\n", " # 转换形状为 (channels, height, width)\n", " X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n", " y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n", " mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n", "\n", " return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n", "\n", "# 实例化数据集和数据加载器\n", "image_dir = './out_mat/96/train/'\n", "mask_dir = './out_mat/96/mask/40/'\n", "\n", "print(f\"checkpoint before Generator is OK\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "41da7319-9795-441d-bde8-8cf390365099", "metadata": {}, "outputs": [], "source": [ "dataset = NO2Dataset(image_dir, mask_dir)\n", "dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n", "val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n", "val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n", "test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n", "test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)" ] }, { "cell_type": "code", "execution_count": 5, "id": "70797703-1619-4be7-b965-5506b3d1e775", "metadata": {}, "outputs": [], "source": [ "# 可视化特定特征的函数\n", "def visualize_feature(input_feature,masked_feature, output_feature, title):\n", " plt.figure(figsize=(12, 6))\n", " plt.subplot(1, 3, 1)\n", " plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n", " plt.title(title + \" Input\")\n", " plt.subplot(1, 3, 2)\n", " plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n", " plt.title(title + \" Masked\")\n", " plt.subplot(1, 3, 3)\n", " plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n", " plt.title(title + \" Recovery\")\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 6, "id": "645114e8-65a4-4867-b3fe-23395288e855", "metadata": {}, "outputs": [], "source": [ "class Conv(nn.Sequential):\n", " def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n", " super(Conv, self).__init__(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n", " dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n", " )" ] }, { "cell_type": "code", "execution_count": 7, "id": "2af52d0e-b785-4a84-838c-6fcfe2568722", "metadata": {}, "outputs": [], "source": [ "class ConvBNReLU(nn.Sequential):\n", " def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n", " bias=False):\n", " super(ConvBNReLU, self).__init__(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n", " dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n", " norm_layer(out_channels),\n", " nn.ReLU()\n", " )" ] }, { "cell_type": "code", "execution_count": 8, "id": "31ecf247-e98b-4977-a145-782914a042bd", "metadata": {}, "outputs": [], "source": [ "class SeparableBNReLU(nn.Sequential):\n", " def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n", " super(SeparableBNReLU, self).__init__(\n", " nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n", " padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n", " # 分离卷积,仅调整空间信息\n", " norm_layer(in_channels), # 对输入通道进行归一化\n", " nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), # 这里进行升维操作\n", " nn.ReLU6()\n", " )" ] }, { "cell_type": "code", "execution_count": 9, "id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9", "metadata": {}, "outputs": [], "source": [ "class ResidualBlock(nn.Module):\n", " def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n", " super(ResidualBlock, self).__init__()\n", " self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(out_channels)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n", " self.bn2 = nn.BatchNorm2d(out_channels)\n", "\n", " # 如果输入和输出通道不一致,进行降采样操作\n", " self.downsample = downsample\n", " if in_channels != out_channels or stride != 1:\n", " self.downsample = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n", " nn.BatchNorm2d(out_channels)\n", " )\n", "\n", " def forward(self, x):\n", " identity = x\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", "\n", " out += identity\n", " out = self.relu(out)\n", " return out\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "7853bf62-02f5-4917-b950-6fdfe467df4a", "metadata": {}, "outputs": [], "source": [ "class Mlp(nn.Module):\n", " def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n", " super().__init__()\n", " out_features = out_features or in_features\n", " hidden_features = hidden_features or in_features\n", " self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n", "\n", " self.act = act_layer()\n", " self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n", " self.drop = nn.Dropout(drop, inplace=True)\n", "\n", " def forward(self, x):\n", " x = self.fc1(x)\n", " x = self.act(x)\n", " x = self.drop(x)\n", " x = self.fc2(x)\n", " x = self.drop(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 11, "id": "e2375881-a11b-47a7-8f56-2eadb25010b0", "metadata": {}, "outputs": [], "source": [ "class MultiHeadAttentionBlock(nn.Module):\n", " def __init__(self, embed_dim, num_heads, dropout=0.1):\n", " super(MultiHeadAttentionBlock, self).__init__()\n", " self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n", " self.norm = nn.LayerNorm(embed_dim)\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x):\n", " # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n", " B, C, H, W = x.shape\n", " x = x.view(B, C, H * W).permute(2, 0, 1) # (B, C, H, W) -> (HW, B, C)\n", "\n", " # Apply multihead attention\n", " attn_output, _ = self.attention(x, x, x)\n", "\n", " # Apply normalization and dropout\n", " attn_output = self.norm(attn_output)\n", " attn_output = self.dropout(attn_output)\n", "\n", " # Reshape back to (B, C, H, W)\n", " attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n", "\n", " return attn_output" ] }, { "cell_type": "code", "execution_count": 12, "id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384", "metadata": {}, "outputs": [], "source": [ "class SpatialAttentionBlock(nn.Module):\n", " def __init__(self):\n", " super(SpatialAttentionBlock, self).__init__()\n", " self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n", "\n", " def forward(self, x): #(B, 64, H, W)\n", " avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)\n", " max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)\n", " out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)\n", " out = torch.sigmoid(self.conv(out))#(B, 1, H, W)\n", " return x * out #(B, C, H, W)" ] }, { "cell_type": "code", "execution_count": 13, "id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9", "metadata": {}, "outputs": [], "source": [ "class DecoderAttentionBlock(nn.Module):\n", " def __init__(self, in_channels):\n", " super(DecoderAttentionBlock, self).__init__()\n", " self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n", " self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n", " self.spatial_attention = SpatialAttentionBlock()\n", "\n", " def forward(self, x):\n", " # 通道注意力\n", " b, c, h, w = x.size()\n", " avg_pool = F.adaptive_avg_pool2d(x, 1)\n", " max_pool = F.adaptive_max_pool2d(x, 1)\n", "\n", " avg_out = self.conv1(avg_pool)\n", " max_out = self.conv1(max_pool)\n", "\n", " out = avg_out + max_out\n", " out = torch.sigmoid(self.conv2(out))\n", "\n", " # 添加空间注意力\n", " out = x * out\n", " out = self.spatial_attention(out)\n", " return out" ] }, { "cell_type": "code", "execution_count": 14, "id": "15b9d453-d8d9-43b8-aca2-904735fb3a99", "metadata": {}, "outputs": [], "source": [ "class SEBlock(nn.Module):\n", " def __init__(self, in_channels, reduced_dim):\n", " super(SEBlock, self).__init__()\n", " self.se = nn.Sequential(\n", " nn.AdaptiveAvgPool2d(1), # 全局平均池化\n", " nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n", " nn.ReLU(),\n", " nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n", " nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n", " )\n", "\n", " def forward(self, x):\n", " return x * self.se(x)" ] }, { "cell_type": "code", "execution_count": 15, "id": "6379adb7-8a87-4dd8-a695-4013a7b37830", "metadata": { "tags": [] }, "outputs": [], "source": [ "# 定义Masked Autoencoder模型\n", "class MaskedAutoencoder(nn.Module):\n", " def __init__(self):\n", " super(MaskedAutoencoder, self).__init__()\n", " self.encoder = nn.Sequential(\n", " Conv(1, 32, kernel_size=3, stride=2),\n", " \n", " nn.ReLU(),\n", " \n", " SEBlock(32,32),\n", " \n", " ConvBNReLU(32, 64, kernel_size=3, stride=2),\n", " \n", " ResidualBlock(64,64),\n", " \n", " SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n", " \n", " MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n", " \n", " SEBlock(128, 128)\n", " )\n", " self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n", " self.decoder = nn.Sequential(\n", " nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n", " nn.ReLU(),\n", " \n", " DecoderAttentionBlock(32),\n", " nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n", " nn.ReLU(),\n", " \n", " DecoderAttentionBlock(16),\n", " nn.ReLU(),\n", " \n", " nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n", " nn.Sigmoid()\n", " )\n", "\n", " def forward(self, x):\n", " encoded = self.encoder(x)\n", " decoded = self.decoder(encoded)\n", " return decoded\n", "\n", "# 实例化模型、损失函数和优化器\n", "model = MaskedAutoencoder()\n", "criterion = nn.MSELoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": 16, "id": "e9c804e0-6f5c-40a7-aba7-a03a496cf427", "metadata": {}, "outputs": [], "source": [ "def masked_mse_loss(preds, target, mask):\n", " loss = (preds - target) ** 2\n", " loss = loss.mean(dim=-1) # 对每个像素点求平均\n", " loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n", " return loss" ] }, { "cell_type": "code", "execution_count": 17, "id": "404a8bfb-4976-4cce-b989-c5e401bce0d7", "metadata": {}, "outputs": [], "source": [ "# 训练函数\n", "def train_epoch(model, device, data_loader, criterion, optimizer):\n", " model.train()\n", " running_loss = 0.0\n", " for batch_idx, (X, y, mask) in enumerate(data_loader):\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " optimizer.zero_grad()\n", " reconstructed = model(X)\n", " loss = masked_mse_loss(reconstructed, y, mask)\n", " # loss = criterion(reconstructed, y)\n", " loss.backward()\n", " optimizer.step()\n", " running_loss += loss.item()\n", " return running_loss / (batch_idx + 1)" ] }, { "cell_type": "code", "execution_count": 18, "id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16", "metadata": {}, "outputs": [], "source": [ "# 评估函数\n", "def evaluate(model, device, data_loader, criterion):\n", " model.eval()\n", " running_loss = 0.0\n", " with torch.no_grad():\n", " for batch_idx, (X, y, mask) in enumerate(data_loader):\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " reconstructed = model(X)\n", " if batch_idx == 8:\n", " rand_ind = np.random.randint(0, len(y))\n", " # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n", " loss = masked_mse_loss(reconstructed, y, mask)\n", " running_loss += loss.item()\n", " return running_loss / (batch_idx + 1)" ] }, { "cell_type": "code", "execution_count": 19, "id": "296ba6bd-2239-4948-b278-7edcb29bfd14", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "# 数据准备\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "743d1000-561e-4444-8b49-88346c14f28b", "metadata": {}, "outputs": [], "source": [ "model = model.to(device)\n", "\n", "num_epochs = 100\n", "train_losses = list()\n", "val_losses = list()\n", "for epoch in range(num_epochs):\n", " train_loss = train_epoch(model, device, dataloader, criterion, optimizer)\n", " train_losses.append(train_loss)\n", " val_loss = evaluate(model, device, val_loader, criterion)\n", " val_losses.append(val_loss)\n", " print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n", "\n", "# 测试模型\n", "test_loss = evaluate(model, device, test_loader, criterion)\n", "print(f'Test Loss: {test_loss}')" ] }, { "cell_type": "code", "execution_count": 38, "id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18", "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'train_losses' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[38], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m tr_ind \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(\u001b[43mtrain_losses\u001b[49m)))\n\u001b[1;32m 2\u001b[0m val_ind \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(val_losses)))\n\u001b[1;32m 3\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(train_losses[\u001b[38;5;241m1\u001b[39m:], label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_loss\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", "\u001b[0;31mNameError\u001b[0m: name 'train_losses' is not defined" ] } ], "source": [ "tr_ind = list(range(len(train_losses)))\n", "val_ind = list(range(len(val_losses)))\n", "plt.plot(train_losses[1:], label='train_loss')\n", "plt.plot(val_losses[1:], label='val_loss')\n", "plt.legend(loc='best')" ] }, { "cell_type": "code", "execution_count": null, "id": "849b1706-1a98-4571-989f-da06d949c843", "metadata": {}, "outputs": [], "source": [ "torch.save(model, './models/MAE/final_40.pt')" ] }, { "cell_type": "code", "execution_count": 20, "id": "40a803b2-4891-4d47-ab61-cf88db8007a0", "metadata": {}, "outputs": [], "source": [ "model = torch.load('./models/MAE/final_40.pt')" ] }, { "cell_type": "code", "execution_count": 21, "id": "a8467686-0655-4056-8e01-56299eb89d7c", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error" ] }, { "cell_type": "code", "execution_count": 22, "id": "016c3045-0312-462f-82ae-7272944ed92d", "metadata": {}, "outputs": [], "source": [ "def cal_ioa(y_true, y_pred):\n", " # 计算平均值\n", " mean_observed = np.mean(y_true)\n", " mean_predicted = np.mean(y_pred)\n", "\n", " # 计算IoA\n", " numerator = np.sum((y_true - y_pred) ** 2)\n", " denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n", " IoA = 1 - (numerator / denominator)\n", "\n", " return IoA" ] }, { "cell_type": "code", "execution_count": null, "id": "dae7427e-548e-4276-a4ea-bc9b279d44e8", "metadata": {}, "outputs": [], "source": [ "eva_list = list()\n", "device = 'cpu'\n", "model = model.to(device)\n", "with torch.no_grad():\n", " for batch_idx, (X, y, mask) in enumerate(test_loader):\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n", " reconstructed = model(X)\n", " rev_data = y * max_pixel_value\n", " rev_recon = reconstructed * max_pixel_value\n", " # todo: 这里需要只评估修补出来的模块\n", " data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n", " data_label = data_label[mask_rev==1]\n", " recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n", " recon_no2 = recon_no2[mask_rev==1]\n", " mae = mean_absolute_error(data_label, recon_no2)\n", " rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n", " mape = mean_absolute_percentage_error(data_label, recon_no2)\n", " r2 = r2_score(data_label, recon_no2)\n", " ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n", " r = np.corrcoef(data_label, recon_no2)[0, 1]\n", " eva_list.append([mae, rmse, mape, r2, ioa, r])" ] }, { "cell_type": "code", "execution_count": null, "id": "d5b01834-ca18-4ec3-bc9d-64382d0fab34", "metadata": {}, "outputs": [], "source": [ "pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()" ] }, { "cell_type": "code", "execution_count": 23, "id": "0887481a-764e-4fd5-9580-45aa813a4391", "metadata": {}, "outputs": [], "source": [ "eva_list_frame = list()\n", "device = 'cpu'\n", "model = model.to(device)\n", "best_mape = 1\n", "best_img = None\n", "best_mask = None\n", "best_recov = None\n", "with torch.no_grad():\n", " for batch_idx, (X, y, mask) in enumerate(test_loader):\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n", " reconstructed = model(X)\n", " rev_data = y * max_pixel_value\n", " rev_recon = reconstructed * max_pixel_value\n", " # todo: 这里需要只评估修补出来的模块\n", " for i, sample in enumerate(rev_data):\n", " used_mask = mask_rev[i]\n", " data_label = sample[0] * used_mask\n", " recon_no2 = rev_recon[i][0] * used_mask\n", " data_label = data_label[used_mask==1]\n", " recon_no2 = recon_no2[used_mask==1]\n", " mae = mean_absolute_error(data_label, recon_no2)\n", " rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n", " mape = mean_absolute_percentage_error(data_label, recon_no2)\n", " r2 = r2_score(data_label, recon_no2)\n", " ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n", " r = np.corrcoef(data_label, recon_no2)[0, 1]\n", " eva_list_frame.append([mae, rmse, mape, r2, ioa, r])\n", " if mape < best_mape:\n", " best_recov = rev_recon[i][0].numpy()\n", " best_mask = used_mask.numpy()\n", " best_img = sample[0].numpy()\n", " best_mape = mape" ] }, { "cell_type": "code", "execution_count": 26, "id": "f7355895-ffde-458f-b4e6-b8afd95ea663", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | mae | \n", "rmse | \n", "mape | \n", "r2 | \n", "ioa | \n", "r | \n", "
---|---|---|---|---|---|---|
count | \n", "4739.000000 | \n", "4739.000000 | \n", "4739.000000 | \n", "4739.000000 | \n", "4739.000000 | \n", "4739.000000 | \n", "
mean | \n", "1.540401 | \n", "2.199879 | \n", "0.195554 | \n", "0.585799 | \n", "0.848016 | \n", "0.778401 | \n", "
std | \n", "0.647315 | \n", "0.909418 | \n", "0.092239 | \n", "0.213993 | \n", "0.106987 | \n", "0.127430 | \n", "
min | \n", "0.462070 | \n", "0.593854 | \n", "0.068942 | \n", "-0.551587 | \n", "0.218504 | \n", "0.145717 | \n", "
25% | \n", "1.021385 | \n", "1.472757 | \n", "0.144170 | \n", "0.460184 | \n", "0.805011 | \n", "0.711952 | \n", "
50% | \n", "1.367836 | \n", "2.056208 | \n", "0.176119 | \n", "0.624006 | \n", "0.876770 | \n", "0.805993 | \n", "
75% | \n", "1.975825 | \n", "2.792777 | \n", "0.217419 | \n", "0.745375 | \n", "0.923944 | \n", "0.871612 | \n", "
max | \n", "5.186517 | \n", "9.158884 | \n", "0.960081 | \n", "0.968376 | \n", "0.992196 | \n", "0.985054 | \n", "