MAE_ATMO/torch_MAE_1d_ViT.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "fa295d87-946f-402b-9d97-1127ee9a33a0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"from PIL import Image\n",
"\n",
"MAX_VALUE = 107.49169921875"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c6dd8e35-02e3-491c-b4be-a874cf1054ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2f151caf-43d1-4d59-a111-96ad5e6bc38b",
"metadata": {},
"outputs": [],
"source": [
"class GrayScaleDataset(Dataset):\n",
" def __init__(self, data_dir):\n",
" self.data_dir = data_dir\n",
" self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]\n",
"\n",
" def __len__(self):\n",
" return len(self.file_list)\n",
"\n",
" def __getitem__(self, idx):\n",
" file_path = os.path.join(self.data_dir, self.file_list[idx])\n",
" data = np.load(file_path)[:,:,0] / MAX_VALUE\n",
" return torch.tensor(data, dtype=torch.float32).unsqueeze(0)\n",
" "
]
},
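{
"cell_type": "markdown",
"id": "3d0c5a1e-1f7b-4f7e-9a2d-6c4b8e2f0a11",
"metadata": {},
"source": [
"A minimal sanity check of `GrayScaleDataset` with a synthetic `.npy` file in a temporary directory (an illustrative sketch; the real data lives under `./out_mat/96/`). It confirms that `__getitem__` returns a `(1, 96, 96)` tensor normalized to `[0, 1]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e2f8c3a-9b1d-4c6e-8f0a-2d7b4e9c1a22",
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"\n",
"# Write one fake sample into a temp dir (hypothetical data, illustration only)\n",
"tmp_dir = tempfile.mkdtemp()\n",
"np.save(os.path.join(tmp_dir, 'sample.npy'), (np.random.rand(96, 96, 1) * MAX_VALUE).astype(np.float32))\n",
"\n",
"ds = GrayScaleDataset(tmp_dir)\n",
"print(len(ds), ds[0].shape, float(ds[0].max()))  # 1 torch.Size([1, 96, 96]) <= 1.0"
]
},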
{
"cell_type": "code",
"execution_count": 4,
"id": "3ecd7bd0-15a0-4420-95e1-066e4d023cd3",
"metadata": {},
"outputs": [],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = idx % len(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / MAX_VALUE # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)"
]
},
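{
"cell_type": "markdown",
"id": "7a4d2b9c-0e3f-4a8b-9c1d-5f6e8a0b2c33",
"metadata": {},
"source": [
"A synthetic round-trip for `NO2Dataset`, again with hypothetical temp directories rather than the real `./out_mat/96/` data: one `.npy` image and one `.jpg` mask are enough to confirm the `(X, y, mask)` contract and the `(1, 96, 96)` layout."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c6e4f1a-2b8d-4e0c-a3f5-7d9b1c4e6a44",
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"\n",
"img_dir, msk_dir = tempfile.mkdtemp(), tempfile.mkdtemp()\n",
"np.save(os.path.join(img_dir, 'a.npy'), np.random.rand(96, 96, 1).astype(np.float32))\n",
"# A random binary mask saved as an 8-bit grayscale JPEG\n",
"Image.fromarray(((np.random.rand(96, 96) > 0.2) * 255).astype(np.uint8)).save(os.path.join(msk_dir, 'm.jpg'))\n",
"\n",
"X, y, m = NO2Dataset(img_dir, msk_dir)[0]\n",
"print(X.shape, y.shape, m.shape)  # each torch.Size([1, 96, 96])"
]
},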
{
"cell_type": "code",
"execution_count": 5,
"id": "36752a6d-329a-464d-a329-f02206bf63b0",
"metadata": {},
"outputs": [],
"source": [
"class PatchMasking:\n",
" def __init__(self, patch_size, mask_ratio):\n",
" self.patch_size = patch_size\n",
" self.mask_ratio = mask_ratio\n",
"\n",
" def __call__(self, x):\n",
" batch_size, C, H, W = x.shape\n",
" num_patches = (H // self.patch_size) * (W // self.patch_size)\n",
" num_masked = int(num_patches * self.mask_ratio)\n",
" \n",
" # 为每个样本生成独立的mask\n",
" masks = []\n",
" for _ in range(batch_size):\n",
" mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)\n",
" mask[:num_masked] = 1\n",
" mask = mask[torch.randperm(num_patches)]\n",
" mask = mask.view(H // self.patch_size, W // self.patch_size)\n",
" mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)\n",
" masks.append(mask)\n",
" \n",
" # 将所有mask堆叠成一个批量张量\n",
" masks = torch.stack(masks, dim=0)\n",
" masks = torch.unsqueeze(masks, dim=1)\n",
" \n",
" # 应用mask到输入x上\n",
" masked_x = x * (1 - masks.float())\n",
" return masked_x, masks"
]
},
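{
"cell_type": "markdown",
"id": "1f8b3e6d-4c2a-4d9e-b0f7-8a5c2e9d1b55",
"metadata": {},
"source": [
"A quick check of `PatchMasking` on random data: a 96×96 input with 16×16 patches has 36 patches, so `mask_ratio=0.2` masks `int(36 * 0.2) = 7` of them (~19% of pixels), and every masked pixel in the output is exactly zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e9d5c7b-6a3f-4b1c-8d0e-9f4a6b8c3d66",
"metadata": {},
"outputs": [],
"source": [
"x_demo = torch.randn(2, 1, 96, 96)\n",
"masked_x, demo_masks = PatchMasking(patch_size=16, mask_ratio=0.2)(x_demo)\n",
"print(masked_x.shape, demo_masks.shape)    # torch.Size([2, 1, 96, 96]) twice\n",
"print(demo_masks.float().mean().item())    # ~0.194 = 7/36 of pixels masked\n",
"assert torch.all(masked_x[demo_masks] == 0)  # masked patches are zeroed"
]
},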
{
"cell_type": "code",
"execution_count": 12,
"id": "0db0d920-8de2-4bad-9b99-67eed152644d",
"metadata": {},
"outputs": [],
"source": [
"class Mlp(nn.Module):\n",
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
" super().__init__()\n",
" out_features = out_features or in_features\n",
" hidden_features = hidden_features or in_features\n",
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
"\n",
" self.act = act_layer()\n",
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
" self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = self.act(x)\n",
" x = self.drop(x)\n",
" x = self.fc2(x)\n",
" x = self.drop(x)\n",
" return x"
]
},
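{
"cell_type": "markdown",
"id": "4b7c9e2f-8d1a-4f6b-a5c3-0e2d7f9a4b77",
"metadata": {},
"source": [
"Note that `Mlp` is defined here but not used by the final `MAEModel` below. Because it is built from 1×1 convolutions, it preserves the spatial layout and only mixes channels, as this small check shows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d0e2f4a-9b3c-4a7d-8e1f-2c5b7d9e4f88",
"metadata": {},
"outputs": [],
"source": [
"t = torch.randn(2, 128, 12, 12)\n",
"print(Mlp(in_features=128, hidden_features=256)(t).shape)  # torch.Size([2, 128, 12, 12])"
]
},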
{
"cell_type": "code",
"execution_count": 6,
"id": "cb27d3a7-77ed-4110-96bd-bcc4880964d2",
"metadata": {},
"outputs": [],
"source": [
"class ViTEncoder(nn.Module):\n",
" def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256):\n",
" super(ViTEncoder, self).__init__()\n",
" self.patch_size = patch_size\n",
" self.dim = dim\n",
" self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)\n",
" \n",
" # 定义 Transformer 编码器层\n",
" encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim)\n",
" self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)\n",
"\n",
" def forward(self, x):\n",
" x = self.patch_embedding(x)\n",
" x = x.flatten(2).transpose(1, 2) # 形状变为 (batch_size, num_patches, dim)\n",
" x = self.transformer_encoder(x)\n",
" return x\n",
"\n",
"class ConvDecoder(nn.Module):\n",
" def __init__(self, dim=128, patch_size=8, img_size=96):\n",
" super(ConvDecoder, self).__init__()\n",
" self.dim = dim\n",
" self.patch_size = patch_size\n",
" self.img_size = img_size\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(dim, 128, kernel_size=patch_size, stride=patch_size),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(128, 1, kernel_size=3, stride=1, padding=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" # x = x.transpose(1, 2).view(-1, self.dim, self.img_size // self.patch_size, self.img_size // self.patch_size)\n",
" x = self.decoder(x)\n",
" return x\n",
"\n",
"class MAEModel(nn.Module):\n",
" def __init__(self, encoder, decoder):\n",
" super(MAEModel, self).__init__()\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "783e62af-7f6a-40bd-a423-be63fe98a655",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
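{
"cell_type": "markdown",
"id": "8f1a3c5e-7d9b-4e2a-b6c4-1d3f5a7b9c99",
"metadata": {},
"source": [
"A worked example of `masked_mse_loss`: with predictions of 1, targets of 0, and a single masked pixel, the loss is exactly 1.0; the three unmasked pixels contribute nothing."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a2b4c6d-8e0f-4a1b-9c2d-3e4f5a6b7c00",
"metadata": {},
"outputs": [],
"source": [
"preds_demo = torch.ones(1, 1, 2, 2)\n",
"target_demo = torch.zeros(1, 1, 2, 2)\n",
"mask_demo = torch.tensor([[[[1., 0.], [0., 0.]]]])\n",
"print(masked_mse_loss(preds_demo, target_demo, mask_demo))  # tensor(1.)"
]
},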
{
"cell_type": "code",
"execution_count": 8,
"id": "baeffdf0-cdc2-44c4-972a-e2e671635d6a",
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):\n",
" model.to(device)\n",
" for epoch in range(epochs):\n",
" model.train()\n",
" train_loss = 0\n",
" for data in train_loader:\n",
" data = data.to(device)\n",
" optimizer.zero_grad()\n",
" masked_data, mask = PatchMasking(patch_size=16, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" loss = masked_mse_loss(output, data, mask)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss.item()\n",
" train_loss /= len(train_loader)\n",
"\n",
" model.eval()\n",
" val_loss = 0\n",
" with torch.no_grad():\n",
" for data in val_loader:\n",
" data = data.to(device)\n",
" masked_data, mask = PatchMasking(patch_size=16, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" loss = masked_mse_loss(output, data, mask)\n",
" val_loss += loss.item()\n",
" val_loss /= len(val_loader)\n",
"\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')"
]
},
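{
"cell_type": "markdown",
"id": "b1c3d5e7-f9a0-4b2c-8d4e-5f6a7b8c9d11",
"metadata": {},
"source": [
"A single illustrative optimization step on random data (a sketch to verify the pieces fit together; real training below uses `train_loader`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2d4e6f8-a0b1-4c3d-9e5f-6a7b8c9d0e22",
"metadata": {},
"outputs": [],
"source": [
"demo_model = MAEModel(ViTEncoder(), ConvDecoder())\n",
"demo_opt = optim.Adam(demo_model.parameters(), lr=1e-3)\n",
"batch = torch.rand(4, 1, 96, 96)\n",
"masked_batch, batch_mask = PatchMasking(patch_size=16, mask_ratio=0.2)(batch)\n",
"demo_loss = masked_mse_loss(demo_model(masked_batch), batch, batch_mask)\n",
"demo_opt.zero_grad()\n",
"demo_loss.backward()\n",
"demo_opt.step()\n",
"print(float(demo_loss))"
]
},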
{
"cell_type": "code",
"execution_count": 9,
"id": "bb524f86-aa7d-44ee-b13e-b9ba4e5b3a0b",
"metadata": {},
"outputs": [],
"source": [
"train_dir = './out_mat/96/train/'\n",
"train_dataset = GrayScaleDataset(train_dir)\n",
"\n",
"val_dir = './out_mat/96/valid/'\n",
"val_dataset = GrayScaleDataset(val_dir)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
"val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "7d6d07a4-31f1-4350-a487-b583db979381",
"metadata": {},
"outputs": [],
"source": [
"encoder = ViTEncoder()\n",
"decoder = ConvDecoder()\n",
"model = MAEModel(encoder, decoder)\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
]
},
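{
"cell_type": "markdown",
"id": "d3e5f7a9-b1c2-4d4e-af6a-7b8c9d0e1f33",
"metadata": {},
"source": [
"Trainable parameter count of the assembled model (informational):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4f6a8b0-c2d3-4e5f-b07b-8c9d0e1f2a44",
"metadata": {},
"outputs": [],
"source": [
"sum(p.numel() for p in model.parameters() if p.requires_grad)"
]
},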
{
"cell_type": "code",
"execution_count": null,
"id": "8ee33651-f5f0-4b92-96e9-a84e32725b44",
"metadata": {},
"outputs": [],
"source": [
"# Shape check: push one batch through the encoder and decoder\n",
"for i in train_loader:\n",
"    a = encoder(i)\n",
"    print(a.shape)  # (batch_size, 144, 128): 144 = (96/8)**2 patches of dim 128\n",
"    b = decoder(a)\n",
"    print(b.shape)  # (batch_size, 1, 96, 96)\n",
"    break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09b04e16-3257-4890-b736-a6c7274561e0",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"train_model(model, train_loader, val_loader, epochs=100, criterion=criterion, optimizer=optimizer, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b0c5cf4b-aca2-4781-8b47-bf2a46269635",
"metadata": {},
"outputs": [],
"source": [
"test_set = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')\n",
"test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "56653f37-899a-47d6-8d50-e456b4ad1835",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "f1ecbd05-7aa3-43ae-8bc2-aa44d19689b9",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = 2 * np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
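{
"cell_type": "markdown",
"id": "f5a7b9c1-d3e4-4f6a-8b8c-9d0e1f2a3b55",
"metadata": {},
"source": [
"A quick check of `cal_ioa`: a perfect prediction scores 1.0, and reversing the series drops the score. For `y_true = [1, 2, 3]` and `y_pred = [3, 2, 1]` (both with mean 2), the numerator is 8 and the denominator is 2 × 8 = 16, so IoA = 0.5."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6b8c0d2-e4f5-4a6b-9c8d-0e1f2a3b4c66",
"metadata": {},
"outputs": [],
"source": [
"y_demo = np.array([1.0, 2.0, 3.0])\n",
"print(cal_ioa(y_demo, y_demo))                     # 1.0\n",
"print(cal_ioa(y_demo, np.array([3.0, 2.0, 1.0])))  # 0.5"
]
},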
{
"cell_type": "code",
"execution_count": 22,
"id": "e840b789-bf68-4b4d-a8d3-c5362c310349",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * MAX_VALUE\n",
" rev_recon = reconstructed * MAX_VALUE\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" eva_list.append([mae, rmse, mape, r2, ioa])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "41fa754d-1eee-43a2-9e39-a0254719be30",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>149.000000</td>\n",
" <td>149.000000</td>\n",
" <td>149.000000</td>\n",
" <td>149.000000</td>\n",
" <td>149.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>7.068207</td>\n",
" <td>9.016465</td>\n",
" <td>0.814727</td>\n",
" <td>-0.952793</td>\n",
" <td>0.564749</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.659118</td>\n",
" <td>0.774556</td>\n",
" <td>0.054147</td>\n",
" <td>0.162851</td>\n",
" <td>0.033048</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>5.609327</td>\n",
" <td>7.113544</td>\n",
" <td>0.599120</td>\n",
" <td>-1.402735</td>\n",
" <td>0.461420</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>6.613351</td>\n",
" <td>8.499699</td>\n",
" <td>0.782008</td>\n",
" <td>-1.049951</td>\n",
" <td>0.544980</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>7.086443</td>\n",
" <td>9.045812</td>\n",
" <td>0.811261</td>\n",
" <td>-0.938765</td>\n",
" <td>0.567080</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>7.495309</td>\n",
" <td>9.530408</td>\n",
" <td>0.848900</td>\n",
" <td>-0.849266</td>\n",
" <td>0.586134</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>8.663801</td>\n",
" <td>10.995004</td>\n",
" <td>0.984343</td>\n",
" <td>-0.591799</td>\n",
" <td>0.630479</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa\n",
"count 149.000000 149.000000 149.000000 149.000000 149.000000\n",
"mean 7.068207 9.016465 0.814727 -0.952793 0.564749\n",
"std 0.659118 0.774556 0.054147 0.162851 0.033048\n",
"min 5.609327 7.113544 0.599120 -1.402735 0.461420\n",
"25% 6.613351 8.499699 0.782008 -1.049951 0.544980\n",
"50% 7.086443 9.045812 0.811261 -0.938765 0.567080\n",
"75% 7.495309 9.530408 0.848900 -0.849266 0.586134\n",
"max 8.663801 10.995004 0.984343 -0.591799 0.630479"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}