{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, Dataset, random_split\n", "from PIL import Image\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import cv2\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "id": "15b9ced8-7282-4f97-a079-f31bf9405145", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.random.seed(0)\n", "torch.random.manual_seed(0)" ] }, { "cell_type": "code", "execution_count": 3, "id": "7f83e6c7-8207-41b3-908b-6b1fad78ecd5", "metadata": {}, "outputs": [], "source": [ "max_pixel_value = 107.49169921875" ] }, { "cell_type": "code", "execution_count": 4, "id": "c66f2b9f-fcad-4237-abb2-d7f918d74116", "metadata": {}, "outputs": [], "source": [ "class NO2Dataset(Dataset):\n", " \n", " def __init__(self, image_dir, mask_dir):\n", " \n", " self.image_dir = image_dir\n", " self.mask_dir = mask_dir\n", " self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n", " self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n", " \n", " def __len__(self):\n", " \n", " return len(self.image_filenames)\n", " \n", " def __getitem__(self, idx):\n", " \n", " image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n", " mask_idx = np.random.choice(self.mask_filenames)\n", " mask_path = os.path.join(self.mask_dir, mask_idx)\n", "\n", " # 加载图像数据 (.npy 文件)\n", " image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n", "\n", " # 加载掩码数据 (.jpg 文件)\n", " mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n", "\n", " # 
将掩码数据中非0值设为1,0值保持不变\n", " mask = np.where(mask != 0, 1.0, 0.0)\n", "\n", " # 保持掩码数据形状为 (96, 96, 1)\n", " mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n", "\n", " # 应用掩码\n", " masked_image = image.copy()\n", " masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n", "\n", " # cGAN的输入和目标\n", " X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n", " y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n", "\n", " # 转换形状为 (channels, height, width)\n", " X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n", " y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n", " mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n", "\n", " return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n", "\n", "# 实例化数据集和数据加载器\n", "image_dir = './out_mat/96/train/'\n", "mask_dir = './out_mat/96/mask/20/'" ] }, { "cell_type": "code", "execution_count": 5, "id": "e3354304-f6de-44bf-adbf-bbff557a8c93", "metadata": {}, "outputs": [], "source": [ "train_set = NO2Dataset(image_dir, mask_dir)\n", "train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)\n", "val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n", "val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n", "test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n", "test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)" ] }, { "cell_type": "code", "execution_count": 6, "id": "70797703-1619-4be7-b965-5506b3d1e775", "metadata": {}, "outputs": [], "source": [ "# 可视化特定特征的函数\n", "def visualize_feature(input_feature,masked_feature, output_feature, title):\n", " plt.figure(figsize=(12, 6))\n", " plt.subplot(1, 3, 1)\n", " plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n", " plt.title(title + \" Input\")\n", " plt.subplot(1, 3, 2)\n", " plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n", " plt.title(title + \" Masked\")\n", " plt.subplot(1, 3, 3)\n", " 
plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n", " plt.title(title + \" Recovery\")\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 7, "id": "645114e8-65a4-4867-b3fe-23395288e855", "metadata": {}, "outputs": [], "source": [ "class Conv(nn.Sequential):\n", " def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n", " super(Conv, self).__init__(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n", " dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n", " )" ] }, { "cell_type": "code", "execution_count": 8, "id": "2af52d0e-b785-4a84-838c-6fcfe2568722", "metadata": {}, "outputs": [], "source": [ "class ConvBNReLU(nn.Sequential):\n", " def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n", " bias=False):\n", " super(ConvBNReLU, self).__init__(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n", " dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n", " norm_layer(out_channels),\n", " nn.ReLU()\n", " )" ] }, { "cell_type": "code", "execution_count": 9, "id": "31ecf247-e98b-4977-a145-782914a042bd", "metadata": {}, "outputs": [], "source": [ "class SeparableBNReLU(nn.Sequential):\n", " def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n", " super(SeparableBNReLU, self).__init__(\n", " nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n", " padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n", " # 分离卷积,仅调整空间信息\n", " norm_layer(in_channels), # 对输入通道进行归一化\n", " nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), # 这里进行升维操作\n", " nn.ReLU6()\n", " )" ] }, { "cell_type": "code", "execution_count": 10, "id": 
"7827bee2-74f7-4e47-b8c6-e41d5670e8b9", "metadata": {}, "outputs": [], "source": [ "class ResidualBlock(nn.Module):\n", " def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n", " super(ResidualBlock, self).__init__()\n", " self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(out_channels)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n", " self.bn2 = nn.BatchNorm2d(out_channels)\n", "\n", " # 如果输入和输出通道不一致,进行降采样操作\n", " self.downsample = downsample\n", " if in_channels != out_channels or stride != 1:\n", " self.downsample = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n", " nn.BatchNorm2d(out_channels)\n", " )\n", "\n", " def forward(self, x):\n", " identity = x\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", "\n", " out += identity\n", " out = self.relu(out)\n", " return out\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "7853bf62-02f5-4917-b950-6fdfe467df4a", "metadata": {}, "outputs": [], "source": [ "class Mlp(nn.Module):\n", " def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n", " super().__init__()\n", " out_features = out_features or in_features\n", " hidden_features = hidden_features or in_features\n", " self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n", "\n", " self.act = act_layer()\n", " self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n", " self.drop = nn.Dropout(drop, inplace=True)\n", "\n", " def forward(self, x):\n", " x = self.fc1(x)\n", " x = self.act(x)\n", " x = self.drop(x)\n", " x = self.fc2(x)\n", " x = 
self.drop(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 12, "id": "e2375881-a11b-47a7-8f56-2eadb25010b0", "metadata": {}, "outputs": [], "source": [ "class MultiHeadAttentionBlock(nn.Module):\n", " def __init__(self, embed_dim, num_heads, dropout=0.1):\n", " super(MultiHeadAttentionBlock, self).__init__()\n", " self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n", " self.norm = nn.LayerNorm(embed_dim)\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x):\n", " # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n", " B, C, H, W = x.shape\n", " x = x.view(B, C, H * W).permute(2, 0, 1) # (B, C, H, W) -> (HW, B, C)\n", "\n", " # Apply multihead attention\n", " attn_output, _ = self.attention(x, x, x)\n", "\n", " # Apply normalization and dropout\n", " attn_output = self.norm(attn_output)\n", " attn_output = self.dropout(attn_output)\n", "\n", " # Reshape back to (B, C, H, W)\n", " attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n", "\n", " return attn_output" ] }, { "cell_type": "code", "execution_count": 13, "id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384", "metadata": {}, "outputs": [], "source": [ "class SpatialAttentionBlock(nn.Module):\n", " def __init__(self):\n", " super(SpatialAttentionBlock, self).__init__()\n", " self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n", "\n", " def forward(self, x): #(B, 64, H, W)\n", " avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)\n", " max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)\n", " out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)\n", " out = torch.sigmoid(self.conv(out))#(B, 1, H, W)\n", " return x * out #(B, C, H, W)" ] }, { "cell_type": "code", "execution_count": 14, "id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9", "metadata": {}, "outputs": [], "source": [ "class DecoderAttentionBlock(nn.Module):\n", " def __init__(self, in_channels):\n", " super(DecoderAttentionBlock, 
self).__init__()\n", "        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n", "        self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n", "        self.spatial_attention = SpatialAttentionBlock()\n", "\n", "    def forward(self, x):\n", "        # channel attention\n", "        b, c, h, w = x.size()\n", "        avg_pool = F.adaptive_avg_pool2d(x, 1)\n", "        max_pool = F.adaptive_max_pool2d(x, 1)\n", "\n", "        avg_out = self.conv1(avg_pool)\n", "        max_out = self.conv1(max_pool)\n", "\n", "        out = avg_out + max_out\n", "        out = torch.sigmoid(self.conv2(out))\n", "\n", "        # then apply spatial attention\n", "        out = x * out\n", "        out = self.spatial_attention(out)\n", "        return out" ] }, { "cell_type": "code", "execution_count": 15, "id": "15b9d453-d8d9-43b8-aca2-904735fb3a99", "metadata": {}, "outputs": [], "source": [ "class SEBlock(nn.Module):\n", "    def __init__(self, in_channels, reduced_dim):\n", "        super(SEBlock, self).__init__()\n", "        self.se = nn.Sequential(\n", "            nn.AdaptiveAvgPool2d(1),  # global average pooling\n", "            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n", "            nn.ReLU(),\n", "            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n", "            nn.Sigmoid()  # per-channel weights in (0, 1)\n", "        )\n", "\n", "    def forward(self, x):\n", "        return x * self.se(x)" ] }, { "cell_type": "code", "execution_count": 16, "id": "a382ed1b-cc88-4f03-95c2-843981ee81f1", "metadata": {}, "outputs": [], "source": [ "def masked_mse_loss(preds, target, mask):\n", "    \"\"\"Mean squared error restricted to the pixels selected by `mask`.\n", "\n", "    Args:\n", "        preds: model output, e.g. (B, 1, H, W).\n", "        target: ground truth with the same shape as `preds`.\n", "        mask: weights broadcastable to `preds`; pixels with weight 0 are\n", "            ignored. In NO2Dataset, mask == 1 marks the pixels left visible\n", "            to the model and mask == 0 the hidden region, so pass\n", "            `1 - mask` to score only the hidden region.\n", "\n", "    Returns:\n", "        Scalar tensor sum(mask * err**2) / sum(mask).\n", "    \"\"\"\n", "    sq_err = (preds - target) ** 2\n", "    # Weight every pixel directly. The old `mean(dim=-1)` averaged over the\n", "    # width axis, and broadcasting the (B, 1, H) result against the\n", "    # (B, 1, H, W) mask produced a (B, B, H, W) tensor mixing samples.\n", "    weights = mask.expand_as(sq_err)\n", "    # clamp avoids 0/0 when no pixel is selected\n", "    return (sq_err * weights).sum() / weights.sum().clamp_min(1e-8)" ] }, { "cell_type": "code", "execution_count": 17, "id": "6379adb7-8a87-4dd8-a695-4013a7b37830", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Define the Masked Autoencoder model\n", "class MaskedAutoencoder(nn.Module):\n", "    def __init__(self):\n", "        super(MaskedAutoencoder, self).__init__()\n", "        self.encoder = nn.Sequential(\n", "            Conv(1, 32, kernel_size=3, stride=2),\n", "            nn.ReLU(),\n", "            SEBlock(32,32),\n", "
ConvBNReLU(32, 64, kernel_size=3, stride=2),\n", " ResidualBlock(64,64),\n", " SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n", " MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n", " SEBlock(128, 128)\n", " )\n", " self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n", " self.decoder = nn.Sequential(\n", " nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n", " nn.ReLU(),\n", " \n", " # DecoderAttentionBlock(32),\n", " nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n", " nn.ReLU(),\n", " \n", " # DecoderAttentionBlock(16),\n", " nn.ReLU(),\n", " \n", " nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n", " nn.Sigmoid()\n", " )\n", "\n", " def forward(self, x):\n", " encoded = self.encoder(x)\n", " decoded = self.decoder(encoded)\n", " return decoded\n", "\n", "# 实例化模型、损失函数和优化器\n", "model = MaskedAutoencoder()\n", "criterion = nn.MSELoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": 18, "id": "404a8bfb-4976-4cce-b989-c5e401bce0d7", "metadata": {}, "outputs": [], "source": [ "# 训练函数\n", "def train_epoch(model, device, data_loader, criterion, optimizer):\n", " model.train()\n", " running_loss = 0.0\n", " for batch_idx, (X, y, mask) in enumerate(data_loader):\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " optimizer.zero_grad()\n", " reconstructed = model(X)\n", " # loss = criterion(reconstructed, y)\n", " loss = masked_mse_loss(reconstructed, y, mask)\n", " loss.backward()\n", " optimizer.step()\n", " running_loss += loss.item()\n", " return running_loss / (batch_idx + 1)" ] }, { "cell_type": "code", "execution_count": 19, "id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16", "metadata": {}, "outputs": [], "source": [ "# 评估函数\n", "def evaluate(model, device, data_loader, criterion):\n", " model.eval()\n", " 
running_loss = 0.0\n", "    with torch.no_grad():\n", "        for batch_idx, (X, y, mask) in enumerate(data_loader):\n", "            X, y, mask = X.to(device), y.to(device), mask.to(device)\n", "            reconstructed = model(X)\n", "            if batch_idx == 8:\n", "                rand_ind = np.random.randint(0, len(y))\n", "                # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n", "            # loss = criterion(reconstructed, y)\n", "            loss = masked_mse_loss(reconstructed, y, mask)\n", "            running_loss += loss.item()\n", "    return running_loss / (batch_idx + 1)" ] }, { "cell_type": "code", "execution_count": 20, "id": "6094b6c8-8211-4557-9944-7eef977ea9ec", "metadata": {}, "outputs": [], "source": [ "def masked_mae_loss(preds, target, mask):\n", "    \"\"\"Mean absolute error restricted to the pixels selected by `mask`.\n", "\n", "    Pixels whose mask weight is 0 are ignored; the result is the scalar\n", "    sum(mask * |preds - target|) / sum(mask). `mask` must be broadcastable\n", "    to `preds`.\n", "    \"\"\"\n", "    # The previous body squared the error (an MSE, not an MAE) and averaged\n", "    # over the width axis before broadcasting against the full-size mask.\n", "    abs_err = (preds - target).abs()\n", "    weights = mask.expand_as(abs_err)\n", "    # clamp avoids 0/0 when no pixel is selected\n", "    return (abs_err * weights).sum() / weights.sum().clamp_min(1e-8)" ] }, { "cell_type": "code", "execution_count": 29, "id": "781f558e-d41c-4721-94fd-564cd6c2b347", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "# Select the compute device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "code", "execution_count": 30, "id": "743d1000-561e-4444-8b49-88346c14f28b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, Train Loss: 0.013549078723781131, Val Loss: 0.014539383435204847\n", "Epoch 2, Train Loss: 0.013641111095966192, Val Loss: 0.014635173200782555\n", "Epoch 3, Train Loss: 0.013503858572290988, Val Loss: 0.01476309893291388\n", "Epoch 4, Train Loss: 0.013455510417970887, Val Loss: 0.014315864057349624\n", "Epoch 5, Train Loss: 0.01339626228704193, Val Loss: 0.01442837900023407\n", "Epoch 6, Train Loss: 0.013295360569035608, Val Loss: 0.015184532503472336\n", "Epoch 12, Train Loss: 0.012901031857793125, Val Loss: 0.013935101566030018\n", "Epoch 13, Train Loss: 0.01295265725158761, Val Loss: 
0.013862666924164366\n", "Epoch 14, Train Loss: 0.013010161795149865, Val Loss: 0.013880979492148357\n", "Epoch 15, Train Loss: 0.012936625905940977, Val Loss: 0.013813913021403463\n", "Epoch 16, Train Loss: 0.01287072714926167, Val Loss: 0.01403502803017844\n", "Epoch 17, Train Loss: 0.012832806871214695, Val Loss: 0.014388528165977393\n", "Epoch 18, Train Loss: 0.012794200125992583, Val Loss: 0.01383661480147892\n", "Epoch 19, Train Loss: 0.01294981115208003, Val Loss: 0.01408140508652623\n", "Epoch 20, Train Loss: 0.012662894464583631, Val Loss: 0.01359965718949019\n", "Test Loss: 0.007365767304242279\n" ] } ], "source": [ "model = model.to(device)\n", "\n", "num_epochs = 20\n", "train_losses = list()\n", "val_losses = list()\n", "for epoch in range(num_epochs):\n", " train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n", " train_losses.append(train_loss)\n", " val_loss = evaluate(model, device, val_loader, criterion)\n", " val_losses.append(val_loss)\n", " print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n", "\n", "# 测试模型\n", "test_loss = evaluate(model, device, test_loader, criterion)\n", "print(f'Test Loss: {test_loss}')" ] }, { "cell_type": "code", "execution_count": 31, "id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(train_losses, label='train_loss')\n", "plt.plot(val_losses, label='val_loss')\n", "plt.xlabel('epoch index')\n", "plt.ylabel('masked MSE loss')\n", "plt.legend(loc='best')" ] }, { "cell_type": "code", "execution_count": 32, "id": "1f48acd7-70e8-46db-9148-6a2df3153f08", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error" ] }, { "cell_type": "code", "execution_count": 33, "id": "313fa420-c856-4db1-80ae-b543e1fb73ef", "metadata": {}, "outputs": [], "source": [ "# Batch-level metrics on the hidden (reconstructed) region of the test set.\n", "# Run on the CPU explicitly: the inputs must be on the model's device and the\n", "# sklearn metrics below need CPU tensors. (Previously the model was moved to\n", "# the CPU while the batches were still sent to the global `device`.)\n", "eval_device = torch.device('cpu')\n", "model = model.to(eval_device)\n", "model.eval()\n", "eva_list = list()\n", "with torch.no_grad():\n", "    for batch_idx, (X, y, mask) in enumerate(test_loader):\n", "        X, y, mask = X.to(eval_device), y.to(eval_device), mask.to(eval_device)\n", "        # invert the mask: 1 marks the pixels hidden from the model\n", "        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1\n", "        reconstructed = model(X)\n", "        # undo the division by max_pixel_value done in NO2Dataset\n", "        rev_data = y * max_pixel_value\n", "        rev_recon = reconstructed * max_pixel_value\n", "        # score only the hidden pixels\n", "        data_label = torch.squeeze(rev_data, dim=1)[mask_rev == 1]\n", "        recon_no2 = torch.squeeze(rev_recon, dim=1)[mask_rev == 1]\n", "        mae = mean_absolute_error(data_label, recon_no2)\n", "        rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n", "        mape = mean_absolute_percentage_error(data_label, recon_no2)\n", "        r2 = r2_score(data_label, recon_no2)\n", "        eva_list.append([mae, rmse, mape, r2])" ] }, { "cell_type": "code", "execution_count": 34, "id": "5c6d5e5a-90f6-4e9a-882f-c2f160b0cb15", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
maermsemaper2
count75.00000075.00000075.00000075.000000
mean1.2969062.0223620.1676940.904339
std0.0757610.1370410.0131710.010395
min1.1212841.7162750.1436670.875878
25%1.2383781.9179070.1564290.898060
50%1.2871932.0118280.1666790.904941
75%1.3530452.1024090.1769960.911137
max1.4460462.4145320.2021420.924070
\n", "
" ], "text/plain": [ "             mae       rmse       mape         r2\n", "count  75.000000  75.000000  75.000000  75.000000\n", "mean    1.296906   2.022362   0.167694   0.904339\n", "std     0.075761   0.137041   0.013171   0.010395\n", "min     1.121284   1.716275   0.143667   0.875878\n", "25%     1.238378   1.917907   0.156429   0.898060\n", "50%     1.287193   2.011828   0.166679   0.904941\n", "75%     1.353045   2.102409   0.176996   0.911137\n", "max     1.446046   2.414532   0.202142   0.924070" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2']).describe()" ] }, { "cell_type": "code", "execution_count": 35, "id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e", "metadata": {}, "outputs": [], "source": [ "def cal_ioa(y_true, y_pred):\n", "    \"\"\"Willmott's index of agreement d.\n", "\n", "        d = 1 - sum((P - O)**2) / sum((|P - mean(O)| + |O - mean(O)|)**2)\n", "\n", "    Both deviation terms in the denominator are measured from the mean of\n", "    the observations O (the previous version measured the predictions from\n", "    their own mean, which is not the standard definition and allowed d < 0).\n", "\n", "    Args:\n", "        y_true: observed values O (array-like).\n", "        y_pred: predicted values P, same length as y_true.\n", "\n", "    Returns:\n", "        d in [0, 1], or nan when the denominator is zero (constant\n", "        observations that are predicted exactly).\n", "    \"\"\"\n", "    obs = np.asarray(y_true, dtype=np.float64)\n", "    pred = np.asarray(y_pred, dtype=np.float64)\n", "    mean_observed = np.mean(obs)\n", "    numerator = np.sum((pred - obs) ** 2)\n", "    denominator = np.sum((np.abs(pred - mean_observed) + np.abs(obs - mean_observed)) ** 2)\n", "    if denominator == 0:\n", "        return float('nan')\n", "    return 1 - numerator / denominator" ] }, { "cell_type": "code", "execution_count": 36, "id": "b4250d45-b430-40a0-ace7-f59d3451aebd", "metadata": {}, "outputs": [], "source": [ "# Per-sample metrics on the hidden (reconstructed) region of each test image.\n", "eva_list_frame = list()\n", "device = 'cpu'\n", "model = model.to(device)\n", "model.eval()\n", "with torch.no_grad():\n", "    for batch_idx, (X, y, mask) in enumerate(test_loader):\n", "        X, y, mask = X.to(device), y.to(device), mask.to(device)\n", "        # invert the mask: 1 marks the pixels hidden from the model\n", "        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1\n", "        reconstructed = model(X)\n", "        rev_data = y * max_pixel_value\n", "        rev_recon = reconstructed * max_pixel_value\n", "        for i, sample in enumerate(rev_data):\n", "            used_mask = mask_rev[i]\n", "            # hidden pixels of this sample, as numpy arrays\n", "            data_label = sample[0][used_mask == 1].numpy()\n", "            recon_no2 = rev_recon[i][0][used_mask == 1].numpy()\n", "            if data_label.size < 2:\n", "                continue  # r2 / corrcoef need at least two hidden pixels\n", "            mae = mean_absolute_error(data_label, recon_no2)\n", "            rmse = 
np.sqrt(mean_squared_error(data_label, recon_no2))\n", "            mape = mean_absolute_percentage_error(data_label, recon_no2)\n", "            r2 = r2_score(data_label, recon_no2)\n", "            ioa = cal_ioa(data_label, recon_no2)\n", "            r = np.corrcoef(data_label, recon_no2)[0, 1]\n", "            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])" ] }, { "cell_type": "code", "execution_count": 37, "id": "039d0041-4573-4645-aeb0-686eabfe8b6f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
maermsemaper2ioar
count4739.0000004739.0000004739.0000004739.0000004739.0000004739.000000
mean1.3068171.8458190.1668760.6705190.8866460.836323
std0.6236450.9026190.1070250.2407520.1111420.121726
min0.4329910.5683190.050612-1.539424-0.2675690.022258
25%0.8355791.1723220.1133020.5837130.8647560.794922
50%1.1617101.6581950.1433860.7358600.9213410.869860
75%1.6173822.2997310.1850390.8272420.9512850.916741
max5.3382309.9369511.9299860.9832080.9957670.992588
\n", "
" ], "text/plain": [ " mae rmse mape r2 ioa \\\n", "count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n", "mean 1.306817 1.845819 0.166876 0.670519 0.886646 \n", "std 0.623645 0.902619 0.107025 0.240752 0.111142 \n", "min 0.432991 0.568319 0.050612 -1.539424 -0.267569 \n", "25% 0.835579 1.172322 0.113302 0.583713 0.864756 \n", "50% 1.161710 1.658195 0.143386 0.735860 0.921341 \n", "75% 1.617382 2.299731 0.185039 0.827242 0.951285 \n", "max 5.338230 9.936951 1.929986 0.983208 0.995767 \n", "\n", " r \n", "count 4739.000000 \n", "mean 0.836323 \n", "std 0.121726 \n", "min 0.022258 \n", "25% 0.794922 \n", "50% 0.869860 \n", "75% 0.916741 \n", "max 0.992588 " ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()" ] }, { "cell_type": "code", "execution_count": null, "id": "83c7e465-bbd0-4c56-8cb4-9d1122fe695f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" } }, "nbformat": 4, "nbformat_minor": 5 }