{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# NO2 grid reconstruction with a masked autoencoder\n",
    "\n",
    "A ViT-style encoder and a transposed-convolution decoder are pre-trained on unlabelled NO2 grids by hiding random square patches and reconstructing them. The trained model then fills in missing-data masks on a held-out test set, and only the filled-in pixels are scored (MAE, RMSE, MAPE, $R^2$ and Willmott's index of agreement).\n",
    "\n",
    "**Note:** stored outputs below may predate the current code; use *Restart & Run All* to regenerate them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa295d87-946f-402b-9d97-1127ee9a33a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from PIL import Image\n",
    "from sklearn.metrics import mean_absolute_error, mean_absolute_percentage_error, mean_squared_error, r2_score\n",
    "from torch.utils.data import DataLoader, Dataset, random_split\n",
    "\n",
    "# Raw NO2 values are divided by this constant (presumably the dataset maximum) before entering the model.\n",
    "MAX_VALUE = 107.49169921875\n",
    "\n",
    "# Seed every RNG used below (weight init, shuffling, patch masking) for reproducible runs.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c6dd8e35-02e3-491c-b4be-a874cf1054ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-data",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "`GrayScaleDataset` supplies unlabelled grids for pre-training, `NO2Dataset` pairs each test grid with a missing-data mask, and `PatchMasking` generates the random patch masks used during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f151caf-43d1-4d59-a111-96ad5e6bc38b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GrayScaleDataset(Dataset):\n",
    "    \"\"\"Single-channel grids for masked-autoencoder pre-training.\n",
    "\n",
    "    Each item is channel 0 of a `.npy` array divided by `MAX_VALUE`, returned\n",
    "    as a float32 tensor of shape (1, H, W).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_dir):\n",
    "        self.data_dir = data_dir\n",
    "        # Sorted so the sample order does not depend on the filesystem's listing order.\n",
    "        self.file_list = sorted(x for x in os.listdir(data_dir) if x.endswith('.npy'))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        file_path = os.path.join(self.data_dir, self.file_list[idx])\n",
    "        data = np.load(file_path)[:, :, 0] / MAX_VALUE\n",
    "        return torch.tensor(data, dtype=torch.float32).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ecd7bd0-15a0-4420-95e1-066e4d023cd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NO2Dataset(Dataset):\n",
    "    \"\"\"Test grids paired with missing-data masks.\n",
    "\n",
    "    Masks are `.jpg` images: non-zero pixels are treated as observed (1) and zero\n",
    "    pixels as missing (0). Masks are reused cyclically (`idx % n_masks`) when there\n",
    "    are fewer masks than grids.\n",
    "\n",
    "    Returns `(X, y, mask)`, each a float32 tensor of shape (1, H, W): `X` is the\n",
    "    grid with missing pixels zeroed, `y` the full grid and `mask` the 0/1 mask.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_dir, mask_dir):\n",
    "        self.image_dir = image_dir\n",
    "        self.mask_dir = mask_dir\n",
    "        # Sorted so the grid -> mask pairing below is deterministic across machines.\n",
    "        self.image_filenames = sorted(f for f in os.listdir(image_dir) if f.endswith('.npy'))\n",
    "        self.mask_filenames = sorted(f for f in os.listdir(mask_dir) if f.endswith('.jpg'))\n",
    "        if not self.mask_filenames:\n",
    "            raise ValueError(f'no .jpg mask files found in {mask_dir}')\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_filenames)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
    "        mask_idx = idx % len(self.mask_filenames)\n",
    "        mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])\n",
    "\n",
    "        # First channel of the grid, scaled: (H, W, 1).\n",
    "        image = np.load(image_path).astype(np.float32)[:, :, :1] / MAX_VALUE\n",
    "\n",
    "        # Binarise the grayscale mask (non-zero -> 1, zero -> 0): (H, W, 1).\n",
    "        mask = np.array(Image.open(mask_path).convert('L'))\n",
    "        mask = (mask != 0).astype(np.float32)[:, :, np.newaxis]\n",
    "\n",
    "        # Model input = grid with the missing pixels zeroed; target = full grid.\n",
    "        X = image * mask\n",
    "        y = image\n",
    "\n",
    "        # (H, W, C) -> (C, H, W)\n",
    "        X = np.transpose(X, (2, 0, 1))\n",
    "        y = np.transpose(y, (2, 0, 1))\n",
    "        mask = np.transpose(mask, (2, 0, 1))\n",
    "\n",
    "        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36752a6d-329a-464d-a329-f02206bf63b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PatchMasking:\n",
    "    \"\"\"Hide a random subset of square patches, independently for every sample.\n",
    "\n",
    "    Args:\n",
    "        patch_size: side length of a patch in pixels; H and W must be divisible by it.\n",
    "        mask_ratio: fraction of patches to hide (the count is rounded down).\n",
    "\n",
    "    Calling with `x` of shape (B, C, H, W) returns `(masked_x, masks)`: `masks` is a\n",
    "    bool tensor of shape (B, 1, H, W) that is True on hidden pixels, and `masked_x`\n",
    "    is `x` with those pixels set to 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, patch_size, mask_ratio):\n",
    "        self.patch_size = patch_size\n",
    "        self.mask_ratio = mask_ratio\n",
    "\n",
    "    def __call__(self, x):\n",
    "        batch_size, _, H, W = x.shape\n",
    "        if H % self.patch_size or W % self.patch_size:\n",
    "            raise ValueError(f'image size {H}x{W} is not divisible by patch_size={self.patch_size}')\n",
    "        grid_h, grid_w = H // self.patch_size, W // self.patch_size\n",
    "        num_patches = grid_h * grid_w\n",
    "        num_masked = int(num_patches * self.mask_ratio)\n",
    "\n",
    "        # Ranking i.i.d. uniform noise yields an independent random permutation per sample\n",
    "        # (all on x.device); the num_masked lowest-ranked patches are hidden.\n",
    "        ranks = torch.rand(batch_size, num_patches, device=x.device).argsort(dim=1).argsort(dim=1)\n",
    "        masks = (ranks < num_masked).view(batch_size, 1, grid_h, grid_w)\n",
    "        masks = masks.repeat_interleave(self.patch_size, dim=2).repeat_interleave(self.patch_size, dim=3)\n",
    "\n",
    "        masked_x = x * (1 - masks.float())\n",
    "        return masked_x, masks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "ViT-style encoder (patch embedding + learned positional embedding + Transformer) and a transposed-convolution decoder that maps the patch tokens back to a full-resolution image. `masked_mse_loss` scores only the hidden pixels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0db0d920-8de2-4bad-9b99-67eed152644d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mlp(nn.Module):\n",
    "    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
    "        super().__init__()\n",
    "        out_features = out_features or in_features\n",
    "        hidden_features = hidden_features or in_features\n",
    "        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
    "\n",
    "        self.act = act_layer()\n",
    "        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
    "        self.drop = nn.Dropout(drop, inplace=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.act(x)\n",
    "        x = self.drop(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.drop(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb27d3a7-77ed-4110-96bd-bcc4880964d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ViTEncoder(nn.Module):\n",
    "    \"\"\"Patchify a single-channel image and encode the patch tokens with a Transformer.\n",
    "\n",
    "    Input (B, 1, img_size, img_size) -> output (B, num_patches, dim), where\n",
    "    num_patches = (img_size // patch_size) ** 2.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256):\n",
    "        super().__init__()\n",
    "        self.patch_size = patch_size\n",
    "        self.dim = dim\n",
    "        self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)\n",
    "\n",
    "        # Learned positional embedding: without it self-attention is permutation-invariant\n",
    "        # and cannot tell where in the image a patch came from.\n",
    "        num_patches = (img_size // patch_size) ** 2\n",
    "        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, dim))\n",
    "        nn.init.trunc_normal_(self.pos_embedding, std=0.02)\n",
    "\n",
    "        # batch_first=True because tokens are laid out as (B, num_patches, dim); with the\n",
    "        # default (seq, batch, dim) layout attention would mix samples across the batch.\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim, batch_first=True)\n",
    "        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.patch_embedding(x)       # (B, dim, H/p, W/p)\n",
    "        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, dim)\n",
    "        x = x + self.pos_embedding\n",
    "        return self.transformer_encoder(x)\n",
    "\n",
    "\n",
    "class ConvDecoder(nn.Module):\n",
    "    \"\"\"Map encoder tokens (B, num_patches, dim) back to an image (B, 1, img_size, img_size).\"\"\"\n",
    "\n",
    "    def __init__(self, dim=128, patch_size=8, img_size=96):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.patch_size = patch_size\n",
    "        self.img_size = img_size\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(dim, 128, kernel_size=patch_size, stride=patch_size),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(128, 1, kernel_size=3, stride=1, padding=1),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Tokens -> (B, dim, grid, grid) feature map. reshape, not view: transpose makes x\n",
    "        # non-contiguous.\n",
    "        grid = self.img_size // self.patch_size\n",
    "        x = x.transpose(1, 2).reshape(x.shape[0], self.dim, grid, grid)\n",
    "        return self.decoder(x)\n",
    "\n",
    "\n",
    "class MAEModel(nn.Module):\n",
    "    \"\"\"Encoder followed by decoder; returns the reconstructed image.\"\"\"\n",
    "\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super().__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.decoder(self.encoder(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "783e62af-7f6a-40bd-a423-be63fe98a655",
   "metadata": {},
   "outputs": [],
   "source": [
    "def masked_mse_loss(preds, target, mask):\n",
    "    \"\"\"Mean squared error over the masked (hidden) pixels only.\n",
    "\n",
    "    Args:\n",
    "        preds, target: tensors of shape (B, C, H, W).\n",
    "        mask: (B, 1, H, W) or (B, C, H, W); non-zero / True marks the pixels to score.\n",
    "\n",
    "    The squared error is averaged over channels first (keeping the pixel grid), then\n",
    "    averaged over the selected pixels. Returns 0 when the mask selects nothing.\n",
    "    \"\"\"\n",
    "    per_pixel = ((preds - target) ** 2).mean(dim=1, keepdim=True)  # (B, 1, H, W)\n",
    "    mask = mask.to(per_pixel.dtype)\n",
    "    return (per_pixel * mask).sum() / mask.sum().clamp_min(1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-train",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Each batch is masked with `PatchMasking` and the model is optimised on the hidden pixels only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baeffdf0-cdc2-44c4-972a-e2e671635d6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device,\n",
    "                patch_size=16, mask_ratio=0.2):\n",
    "    \"\"\"Train a masked autoencoder and print train/val masked-pixel MSE every epoch.\n",
    "\n",
    "    Every batch is masked with `PatchMasking(patch_size, mask_ratio)` and scored with\n",
    "    `masked_mse_loss` on the hidden pixels. `criterion` is kept for backward\n",
    "    compatibility but is not used.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    masker = PatchMasking(patch_size=patch_size, mask_ratio=mask_ratio)\n",
    "\n",
    "    def run_epoch(loader, train):\n",
    "        # One pass over `loader`; updates weights only when train=True.\n",
    "        model.train(train)\n",
    "        total = 0.0\n",
    "        with torch.set_grad_enabled(train):\n",
    "            for data in loader:\n",
    "                data = data.to(device)\n",
    "                masked_data, mask = masker(data)\n",
    "                loss = masked_mse_loss(model(masked_data), data, mask)\n",
    "                if train:\n",
    "                    optimizer.zero_grad()\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "                total += loss.item()\n",
    "        return total / max(len(loader), 1)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        train_loss = run_epoch(train_loader, train=True)\n",
    "        val_loss = run_epoch(val_loader, train=False)\n",
    "        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bb524f86-aa7d-44ee-b13e-b9ba4e5b3a0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dir = './out_mat/96/train/'\n",
    "train_dataset = GrayScaleDataset(train_dir)\n",
    "\n",
    "val_dir = './out_mat/96/valid/'\n",
    "val_dataset = GrayScaleDataset(val_dir)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7d6d07a4-31f1-4350-a487-b583db979381",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = ViTEncoder()\n",
    "decoder = ConvDecoder()\n",
    "model = MAEModel(encoder, decoder)\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09b04e16-3257-4890-b736-a6c7274561e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_model(model, train_loader, val_loader, epochs=100, criterion=criterion, optimizer=optimizer, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-eval",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "The trained model fills the missing pixels of each test grid. Metrics are computed on those filled-in pixels only, after rescaling by `MAX_VALUE`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b0c5cf4b-aca2-4781-8b47-bf2a46269635",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_set = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')\n",
    "test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1ecbd05-7aa3-43ae-8bc2-aa44d19689b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cal_ioa(y_true, y_pred):\n",
    "    \"\"\"Willmott's index of agreement d, in [0, 1] (1 = perfect agreement).\n",
    "\n",
    "        d = 1 - sum((P - O)^2) / sum((|P - mean(O)| + |O - mean(O)|)^2)\n",
    "\n",
    "    Both deviation terms are measured from the mean of the observations O. Returns 1.0\n",
    "    when the denominator is 0, which only happens if every P and O equals the same value.\n",
    "    \"\"\"\n",
    "    y_true = np.asarray(y_true, dtype=np.float64)\n",
    "    y_pred = np.asarray(y_pred, dtype=np.float64)\n",
    "    mean_observed = y_true.mean()\n",
    "    numerator = np.sum((y_pred - y_true) ** 2)\n",
    "    denominator = np.sum((np.abs(y_pred - mean_observed) + np.abs(y_true - mean_observed)) ** 2)\n",
    "    if denominator == 0:\n",
    "        return 1.0\n",
    "    return float(1.0 - numerator / denominator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e840b789-bf68-4b4d-a8d3-c5362c310349",
   "metadata": {},
   "outputs": [],
   "source": [
    "eva_list = []\n",
    "device = 'cpu'\n",
    "model = model.to(device)\n",
    "model.eval()  # make sure dropout in the Transformer layers is off\n",
    "with torch.no_grad():\n",
    "    for X, y, mask in test_loader:\n",
    "        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
    "        reconstructed = model(X)\n",
    "        # Score only the reconstructed pixels (mask == 0), rescaled by MAX_VALUE.\n",
    "        missing = torch.squeeze(mask, dim=1) == 0\n",
    "        data_label = (torch.squeeze(y, dim=1) * MAX_VALUE)[missing].numpy()\n",
    "        recon_no2 = (torch.squeeze(reconstructed, dim=1) * MAX_VALUE)[missing].numpy()\n",
    "        if data_label.size == 0:\n",
    "            continue  # nothing missing in this batch; sklearn metrics reject empty input\n",
    "        mae = mean_absolute_error(data_label, recon_no2)\n",
    "        rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
    "        mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
    "        r2 = r2_score(data_label, recon_no2)\n",
    "        ioa = cal_ioa(data_label, recon_no2)\n",
    "        eva_list.append([mae, rmse, mape, r2, ioa])"
   ]
  },
  { "cell_type": "code", "execution_count": 23, "id": "41fa754d-1eee-43a2-9e39-a0254719be30", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | mae | \n", "rmse | \n", "mape | \n", "r2 | \n", "ioa | \n", "
---|---|---|---|---|---|
count | \n", "149.000000 | \n", "149.000000 | \n", "149.000000 | \n", "149.000000 | \n", "149.000000 | \n", "
mean | \n", "7.068207 | \n", "9.016465 | \n", "0.814727 | \n", "-0.952793 | \n", "0.564749 | \n", "
std | \n", "0.659118 | \n", "0.774556 | \n", "0.054147 | \n", "0.162851 | \n", "0.033048 | \n", "
min | \n", "5.609327 | \n", "7.113544 | \n", "0.599120 | \n", "-1.402735 | \n", "0.461420 | \n", "
25% | \n", "6.613351 | \n", "8.499699 | \n", "0.782008 | \n", "-1.049951 | \n", "0.544980 | \n", "
50% | \n", "7.086443 | \n", "9.045812 | \n", "0.811261 | \n", "-0.938765 | \n", "0.567080 | \n", "
75% | \n", "7.495309 | \n", "9.530408 | \n", "0.848900 | \n", "-0.849266 | \n", "0.586134 | \n", "
max | \n", "8.663801 | \n", "10.995004 | \n", "0.984343 | \n", "-0.591799 | \n", "0.630479 | \n", "