{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "e247809e-7610-487b-88e0-9b4947e92c6b", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, Dataset, random_split\n", "import numpy as np\n", "import pandas as pd\n", "import os\n", "from PIL import Image\n", "\n", "MAX_VALUE = 107.49169921875" ] }, { "cell_type": "code", "execution_count": 2, "id": "4c4182a2-0284-4a82-a494-cbe4051ff7bd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device" ] }, { "cell_type": "code", "execution_count": 3, "id": "b7402720-de05-45e9-b076-04780a513fc3", "metadata": {}, "outputs": [], "source": [ "class GrayScaleDataset(Dataset):\n", " def __init__(self, data_dir):\n", " self.data_dir = data_dir\n", " self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]\n", "\n", " def __len__(self):\n", " return len(self.file_list)\n", "\n", " def __getitem__(self, idx):\n", " file_path = os.path.join(self.data_dir, self.file_list[idx])\n", " data = np.load(file_path)[:,:,0] / MAX_VALUE\n", " return torch.tensor(data, dtype=torch.float32).unsqueeze(0)\n", " " ] }, { "cell_type": "code", "execution_count": 4, "id": "3ecd7bd0-15a0-4420-95e1-066e4d023cd3", "metadata": {}, "outputs": [], "source": [ "class NO2Dataset(Dataset):\n", " \n", " def __init__(self, image_dir, mask_dir):\n", " \n", " self.image_dir = image_dir\n", " self.mask_dir = mask_dir\n", " self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n", " self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n", " \n", " def __len__(self):\n", " \n", " return len(self.image_filenames)\n", " \n", " def __getitem__(self, idx):\n", " \n", " image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n", " mask_idx = idx % len(self.mask_filenames)\n", " mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])\n", "\n", " # 加载图像数据 (.npy 文件)\n", " image = np.load(image_path).astype(np.float32)[:,:,:1] / MAX_VALUE # 形状为 (96, 96, 1)\n", "\n", " # 加载掩码数据 (.jpg 文件)\n", " mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n", "\n", " # 将掩码数据中非0值设为1,0值保持不变\n", " mask = np.where(mask != 0, 1.0, 0.0)\n", "\n", " # 保持掩码数据形状为 (96, 96, 1)\n", " mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n", "\n", " # 应用掩码\n", " masked_image = image.copy()\n", " masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n", "\n", " # cGAN的输入和目标\n", " X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n", " y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n", "\n", " # 转换形状为 (channels, height, width)\n", " X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n", " y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n", " mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n", "\n", " return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)" ] }, { "cell_type": "code", "execution_count": 5, "id": "9ec25dc1-3728-4b0b-8403-ccad10355999", "metadata": {}, "outputs": [], "source": [ "class PatchMasking:\n", " def __init__(self, patch_size, mask_ratio):\n", " self.patch_size = patch_size\n", " self.mask_ratio = mask_ratio\n", "\n", " def __call__(self, x):\n", " batch_size, C, H, W = x.shape\n", " num_patches = (H // self.patch_size) * (W // self.patch_size)\n", " num_masked = int(num_patches * self.mask_ratio)\n", " \n", " # 为每个样本生成独立的mask\n", " masks = []\n", " for _ in range(batch_size):\n", " mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)\n", " mask[:num_masked] = 1\n", " mask = mask[torch.randperm(num_patches)]\n", " mask = mask.view(H // self.patch_size, W // self.patch_size)\n", " mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)\n", " masks.append(mask)\n", " \n", " # 将所有mask堆叠成一个批量张量\n", " masks = torch.stack(masks, dim=0)\n", " masks = torch.unsqueeze(masks, dim=1)\n", " \n", " # 应用mask到输入x上\n", " masked_x = x * (1 - masks.float())\n", " return masked_x, masks" ] }, { "cell_type": "code", "execution_count": 6, "id": "a1f70780-9e31-4917-9785-768140e5610e", "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", " def __init__(self, input_dim, output_dim):\n", " super(MLP, self).__init__()\n", " self.fc1 = nn.Linear(input_dim, output_dim)\n", " self.act = nn.GELU() # 使用 GELU 激活函数\n", " self.fc2 = nn.Linear(output_dim, input_dim)\n", "\n", " def forward(self, x):\n", " return self.fc2(self.act(self.fc1(x)))\n", "\n", "class Attention(nn.Module):\n", " def __init__(self, dim, heads):\n", " super(Attention, self).__init__()\n", " self.heads = heads\n", " self.dim = dim\n", " self.scale = dim ** -0.5\n", "\n", " self.qkv = nn.Linear(dim, dim * 3)\n", " self.attn_drop = nn.Dropout(0.1)\n", " self.proj = nn.Linear(dim, dim)\n", " self.proj_drop = nn.Dropout(0.1)\n", "\n", " def forward(self, x):\n", " B, N, C = x.shape\n", " qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) # (3, B, heads, N, head_dim)\n", " q, k, v = qkv[0], qkv[1], qkv[2]\n", "\n", " attn = (q @ k.transpose(-2, -1)) * self.scale\n", " attn = attn.softmax(dim=-1)\n", " attn = self.attn_drop(attn)\n", "\n", " out = (attn @ v).transpose(1, 2).reshape(B, N, C)\n", " return self.proj_drop(self.proj(out))\n", "\n", "class LayerNorm(nn.Module):\n", " def __init__(self, dim, eps=1e-6):\n", " super(LayerNorm, self).__init__()\n", " self.ln = nn.LayerNorm(dim, eps=eps)\n", "\n", " def forward(self, x):\n", " return self.ln(x)\n", "\n", "class Dropout(nn.Module):\n", " def __init__(self, p=0.1):\n", " super(Dropout, self).__init__()\n", " self.dropout = nn.Dropout(p)\n", "\n", " def forward(self, x):\n", " return self.dropout(x)\n", "\n", "class ViTEncoder(nn.Module):\n", " def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256, dropout=0.1):\n", " super(ViTEncoder, self).__init__()\n", " self.patch_size = patch_size\n", " self.dim = dim\n", " self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)\n", "\n", " self.attention_layers = nn.ModuleList([\n", " nn.ModuleList([\n", " LayerNorm(dim), # Layer Normalization\n", " Attention(dim, heads),\n", " Dropout(dropout), # Dropout\n", " LayerNorm(dim), # Layer Normalization\n", " MLP(dim, mlp_dim),\n", " Dropout(dropout) # Dropout\n", " ]) for _ in range(depth)\n", " ])\n", "\n", " def forward(self, x):\n", " x = self.patch_embedding(x) # 形状变为 (batch_size, dim, num_patches_h, num_patches_w)\n", " x = x.flatten(2).transpose(1, 2) # 形状变为 (batch_size, num_patches, dim)\n", "\n", " for norm1, attn, drop1, norm2, mlp, drop2 in self.attention_layers:\n", " x = x + drop1(attn(norm1(x))) # 残差连接\n", " x = x + drop2(mlp(norm2(x))) # 残差连接\n", " return x\n", "\n", "\n", "class ConvDecoder(nn.Module):\n", " def __init__(self, dim=128, patch_size=8, img_size=96):\n", " super(ConvDecoder, self).__init__()\n", " self.dim = dim\n", " self.patch_size = patch_size\n", " self.img_size = img_size\n", " self.decoder = nn.Sequential(\n", " nn.ConvTranspose2d(self.dim, 64, kernel_size=3, stride=2, padding=1, output_padding=1),\n", " nn.ReLU(),\n", " \n", " nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n", " nn.ReLU(),\n", " \n", " nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n", " nn.Sigmoid()\n", " )\n", "\n", " def forward(self, x):\n", " x = x.transpose(1, 2).view(-1, self.dim, self.img_size // self.patch_size, self.img_size // self.patch_size)\n", " x = self.decoder(x)\n", " return x\n", "\n", "class MAEModel(nn.Module):\n", " def __init__(self, encoder, decoder):\n", " super(MAEModel, self).__init__()\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " # self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n", "\n", " def forward(self, x):\n", " encoded = self.encoder(x)\n", " decoded = self.decoder(encoded)\n", " return decoded" ] }, { "cell_type": "code", "execution_count": 7, "id": "4a1427a1-bf38-483e-b92b-07631078c78a", "metadata": {}, "outputs": [], "source": [ "def masked_mse_loss(preds, target, mask):\n", " loss = (preds - target) ** 2\n", " loss = loss.mean(dim=-1) # 对每个像素点求平均\n", " loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n", " return loss" ] }, { "cell_type": "code", "execution_count": 8, "id": "29deee2c-5771-498a-b01b-fde5e0f387ba", "metadata": {}, "outputs": [], "source": [ "def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):\n", " best_model = model\n", " best_loss = 100\n", " model.to(device)\n", " for epoch in range(epochs):\n", " model.train()\n", " train_loss = 0\n", " for data in train_loader:\n", " data = data.to(device)\n", " optimizer.zero_grad()\n", " masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n", " output = model(masked_data)\n", " loss = masked_mse_loss(output, data, mask)\n", " loss.backward()\n", " optimizer.step()\n", " train_loss += loss.item()\n", " train_loss /= len(train_loader)\n", "\n", " model.eval()\n", " val_loss = 0\n", " with torch.no_grad():\n", " for data in val_loader:\n", " data = data.to(device)\n", " masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n", " output = model(masked_data)\n", " loss = masked_mse_loss(output, data, mask)\n", " val_loss += loss.item()\n", " val_loss /= len(val_loader)\n", " if val_loss < best_loss:\n", " best_loss = val_loss\n", " best_model = model\n", "\n", " print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')" ] }, { "cell_type": "code", "execution_count": 9, "id": "bb524f86-aa7d-44ee-b13e-b9ba4e5b3a0b", "metadata": {}, "outputs": [], "source": [ "train_dir = './out_mat/96/train/'\n", "train_dataset = GrayScaleDataset(train_dir)\n", "\n", "val_dir = './out_mat/96/valid/'\n", "val_dataset = GrayScaleDataset(val_dir)\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)\n", "val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)" ] }, { "cell_type": "code", "execution_count": 10, "id": "0d7e2f83-c113-4c62-91eb-d4ea3192530c", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 43, "id": "7d6d07a4-31f1-4350-a487-b583db979381", "metadata": {}, "outputs": [], "source": [ "encoder = ViTEncoder()\n", "decoder = ConvDecoder()\n", "model = MAEModel(encoder, decoder)\n", "criterion = nn.MSELoss()\n", "optimizer = optim.Adam(model.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": null, "id": "09b04e16-3257-4890-b736-a6c7274561e0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, Train Loss: 0.9251, Val Loss: 0.0869\n", "Epoch 2, Train Loss: 0.0734, Val Loss: 0.0506\n", "Epoch 3, Train Loss: 0.0494, Val Loss: 0.0489\n", "Epoch 4, Train Loss: 0.0432, Val Loss: 0.0462\n", "Epoch 5, Train Loss: 0.0390, Val Loss: 0.0400\n", "Epoch 6, Train Loss: 0.0351, Val Loss: 0.0356\n" ] } ], "source": [ "train_model(model, train_loader, val_loader, epochs=50, criterion=criterion, optimizer=optimizer, device=device)" ] }, { "cell_type": "code", "execution_count": 18, "id": "b0c5cf4b-aca2-4781-8b47-bf2a46269635", "metadata": {}, "outputs": [], "source": [ "test_set = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')\n", "test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)" ] }, { "cell_type": "code", "execution_count": 19, "id": "56653f37-899a-47d6-8d50-e456b4ad1835", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error" ] }, { "cell_type": "code", "execution_count": 20, "id": "2504300a-ac91-453a-9bfb-ab89f56d4ff6", "metadata": {}, "outputs": [], "source": [ "def cal_ioa(y_true, y_pred):\n", " # 计算平均值\n", " mean_observed = np.mean(y_true)\n", " mean_predicted = np.mean(y_pred)\n", "\n", " # 计算IoA\n", " numerator = np.sum((y_true - y_pred) ** 2)\n", " denominator = 2 * np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n", " IoA = 1 - (numerator / denominator)\n", "\n", " return IoA" ] }, { "cell_type": "code", "execution_count": 24, "id": "d55be844-0873-4d9a-8160-22603de32a81", "metadata": {}, "outputs": [], "source": [ "test_set2 = GrayScaleDataset('./out_mat/96/test/')\n", "test_loader2 = DataLoader(test_set2, batch_size=64, shuffle=False, num_workers=4)" ] }, { "cell_type": "code", "execution_count": 34, "id": "f83dbc37-8543-45bc-ba59-ca88d4ba2a66", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([64, 96, 96])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rev_data.shape" ] }, { "cell_type": "code", "execution_count": 35, "id": "2655f7f4-9d88-49fd-9346-87a621320183", "metadata": {}, "outputs": [], "source": [ "eva_list = list()\n", "device = 'cpu'\n", "model = model.to(device)\n", "with torch.no_grad():\n", " for data in test_loader2:\n", " data = data.to(device)\n", " masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n", " output = model(masked_data)\n", " rev_data = data * MAX_VALUE\n", " rev_recon = output * MAX_VALUE\n", " data_label = rev_data * mask\n", " data_label = data_label[mask==1]\n", " recon_no2 = rev_recon * mask\n", " recon_no2 = recon_no2[mask==1]\n", " y_true = rev_data.flatten()\n", " y_pred = rev_recon.flatten()\n", " mae = mean_absolute_error(y_true, y_pred)\n", " rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n", " mape = mean_absolute_percentage_error(y_true, y_pred)\n", " r2 = r2_score(y_true, y_pred)\n", " ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n", " eva_list.append([mae, rmse, mape, r2, ioa])" ] }, { "cell_type": "code", "execution_count": 36, "id": "0327b51c-d714-4fe0-a044-d8be3ff180e0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
maermsemaper2ioa
count75.00000075.00000075.00000075.00000075.000000
mean1.2080131.6006440.1427200.9419830.981683
std0.0562350.0817910.0034350.0044490.002309
min1.0915171.4463890.1348490.9118330.965708
25%1.1703051.5550510.1405190.9404250.981100
50%1.2047281.5932610.1429810.9426510.982003
75%1.2427621.6463110.1451850.9441180.982809
max1.4207212.0379030.1505660.9496630.984610
\n", "
" ], "text/plain": [ " mae rmse mape r2 ioa\n", "count 75.000000 75.000000 75.000000 75.000000 75.000000\n", "mean 1.208013 1.600644 0.142720 0.941983 0.981683\n", "std 0.056235 0.081791 0.003435 0.004449 0.002309\n", "min 1.091517 1.446389 0.134849 0.911833 0.965708\n", "25% 1.170305 1.555051 0.140519 0.940425 0.981100\n", "50% 1.204728 1.593261 0.142981 0.942651 0.982003\n", "75% 1.242762 1.646311 0.145185 0.944118 0.982809\n", "max 1.420721 2.037903 0.150566 0.949663 0.984610" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()" ] }, { "cell_type": "code", "execution_count": 38, "id": "b8669b7e-6974-418a-87fc-074734f9a1a3", "metadata": {}, "outputs": [], "source": [ "eva_list_frame = list()\n", "device = 'cpu'\n", "model = model.to(device)\n", "with torch.no_grad():\n", " for data in test_loader2:\n", " data = data.to(device)\n", " masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n", " output = model(masked_data)\n", " rev_data = data * MAX_VALUE\n", " rev_recon = output * MAX_VALUE\n", " # todo: 这里需要只评估修补出来的模块\n", " for i, sample in enumerate(rev_data):\n", " used_mask = mask[i]\n", " data_label = sample[0] * used_mask\n", " recon_no2 = rev_recon[i][0] * used_mask\n", " data_label = data_label[used_mask==1]\n", " recon_no2 = recon_no2[used_mask==1]\n", " mae = mean_absolute_error(data_label, recon_no2)\n", " rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n", " mape = mean_absolute_percentage_error(data_label, recon_no2)\n", " r2 = r2_score(data_label, recon_no2)\n", " ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n", " r = np.corrcoef(data_label, recon_no2)[0, 1]\n", " eva_list_frame.append([mae, rmse, mape, r2, ioa, r])" ] }, { "cell_type": "code", "execution_count": 39, "id": "c9079fc5-6ab3-465e-9067-6cad8f69c5a8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
maermsemaper2ioar
count4739.0000004739.0000004739.0000004739.0000004739.0000004739.000000
mean1.7143152.3501890.2159740.6094700.9435600.823401
std0.6973440.9403450.0778930.1314960.0222610.069394
min0.6360490.8217230.0999990.0031940.8022370.405363
25%1.1216171.5766690.1709740.5330810.9316530.783616
50%1.4597202.1323160.1994190.6237690.9469520.831403
75%2.3347613.1193930.2345170.6985170.9589430.872422
max4.4062588.4701091.2426360.8951990.9869010.965110
\n", "
" ], "text/plain": [ " mae rmse mape r2 ioa \\\n", "count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n", "mean 1.714315 2.350189 0.215974 0.609470 0.943560 \n", "std 0.697344 0.940345 0.077893 0.131496 0.022261 \n", "min 0.636049 0.821723 0.099999 0.003194 0.802237 \n", "25% 1.121617 1.576669 0.170974 0.533081 0.931653 \n", "50% 1.459720 2.132316 0.199419 0.623769 0.946952 \n", "75% 2.334761 3.119393 0.234517 0.698517 0.958943 \n", "max 4.406258 8.470109 1.242636 0.895199 0.986901 \n", "\n", " r \n", "count 4739.000000 \n", "mean 0.823401 \n", "std 0.069394 \n", "min 0.405363 \n", "25% 0.783616 \n", "50% 0.831403 \n", "75% 0.872422 \n", "max 0.965110 " ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eva_frame_df = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape')\n", "eva_frame_df.describe()" ] }, { "cell_type": "code", "execution_count": null, "id": "e840b789-bf68-4b4d-a8d3-c5362c310349", "metadata": {}, "outputs": [ { "ename": "ValueError", "evalue": "too many values to unpack (expected 3)", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[25], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_idx, (X, y, mask) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(test_loader2):\n\u001b[1;32m 6\u001b[0m X, y, mask \u001b[38;5;241m=\u001b[39m X\u001b[38;5;241m.\u001b[39mto(device), y\u001b[38;5;241m.\u001b[39mto(device), mask\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 7\u001b[0m mask_rev \u001b[38;5;241m=\u001b[39m (torch\u001b[38;5;241m.\u001b[39msqueeze(mask, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m==\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;66;03m# mask取反获得修复区域\u001b[39;00m\n", "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 3)" ] } ], "source": [ "eva_list = list()\n", "device = 'cpu'\n", "model = model.to(device)\n", "with torch.no_grad():\n", " for batch_idx, (X, y, mask) in enumerate(test_loader2):\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n", " reconstructed = model(X)\n", " rev_data = torch.squeeze(y * MAX_VALUE, dim=1)\n", " rev_recon = torch.squeeze(reconstructed * MAX_VALUE, dim=1)\n", " # todo: 这里需要只评估修补出来的模块\n", " data_label = rev_data * mask_rev\n", " data_label = data_label[mask_rev==1]\n", " recon_no2 = rev_recon * mask_rev\n", " recon_no2 = recon_no2[mask_rev==1]\n", " y_true = rev_data.flatten()\n", " y_pred = rev_recon.flatten()\n", " mae = mean_absolute_error(y_true, y_pred)\n", " rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n", " mape = mean_absolute_percentage_error(y_true, y_pred)\n", " r2 = r2_score(y_true, y_pred)\n", " ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n", " eva_list.append([mae, rmse, mape, r2, ioa])" ] }, { "cell_type": "code", "execution_count": 23, "id": "41fa754d-1eee-43a2-9e39-a0254719be30", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
maermsemaper2ioa
count149.000000149.000000149.000000149.000000149.000000
mean2.2356624.0423490.2384940.6260600.572341
std0.1927090.3574750.0074050.0428900.042652
min1.7865673.1671430.2247960.5221570.460707
25%2.0841173.7792760.2329740.5977740.547144
50%2.2260624.0754650.2374290.6275880.570579
75%2.3614114.2845230.2438660.6562260.601233
max2.7513774.9174070.2582300.7409430.666083
\n", "
" ], "text/plain": [ " mae rmse mape r2 ioa\n", "count 149.000000 149.000000 149.000000 149.000000 149.000000\n", "mean 2.235662 4.042349 0.238494 0.626060 0.572341\n", "std 0.192709 0.357475 0.007405 0.042890 0.042652\n", "min 1.786567 3.167143 0.224796 0.522157 0.460707\n", "25% 2.084117 3.779276 0.232974 0.597774 0.547144\n", "50% 2.226062 4.075465 0.237429 0.627588 0.570579\n", "75% 2.361411 4.284523 0.243866 0.656226 0.601233\n", "max 2.751377 4.917407 0.258230 0.740943 0.666083" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()" ] }, { "cell_type": "code", "execution_count": 23, "id": "46cffa4a-37bc-4e13-9723-fc6cb244c95c", "metadata": {}, "outputs": [], "source": [ "eva_list_frame = list()\n", "device = 'cpu'\n", "model = model.to(device)\n", "best_mape = 1\n", "best_img = None\n", "best_mask = None\n", "best_recov = None\n", "with torch.no_grad():\n", " for batch_idx, (X, y, mask) in enumerate(test_loader):\n", " X, y, mask = X.to(device), y.to(device), mask.to(device)\n", " mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n", " reconstructed = model(X)\n", " rev_data = y * MAX_VALUE\n", " rev_recon = reconstructed * MAX_VALUE\n", " # todo: 这里需要只评估修补出来的模块\n", " for i, sample in enumerate(rev_data):\n", " used_mask = mask_rev[i]\n", " data_label = sample[0] * used_mask\n", " recon_no2 = rev_recon[i][0] * used_mask\n", " data_label = data_label[used_mask==1]\n", " recon_no2 = recon_no2[used_mask==1]\n", " mae = mean_absolute_error(data_label, recon_no2)\n", " rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n", " mape = mean_absolute_percentage_error(data_label, recon_no2)\n", " r2 = r2_score(data_label, recon_no2)\n", " ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n", " r = np.corrcoef(data_label, recon_no2)[0, 1]\n", " eva_list_frame.append([mae, rmse, mape, r2, ioa, r])" ] }, { "cell_type": "code", "execution_count": null, "id": "b6f6a897-2f48-4958-8725-f566430c61e1", "metadata": {}, "outputs": [], "source": [ "eva_frame_df = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape')" ] }, { "cell_type": "code", "execution_count": 28, "id": "6d1920e7-b92f-414e-8273-0b4666587904", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
maermsemaper2ioar
count4739.0000004739.0000004739.0000004739.0000004739.0000004739.000000
mean5.9200176.8642450.603656-2.7430170.2285800.225978
std3.5346483.8450340.2246792.0497530.3706220.227965
min1.4773801.8493920.271934-22.827546-1.899284-0.626938
25%2.9757003.6005210.502338-3.6317020.0428750.088760
50%4.1690985.0558900.558942-2.2335300.3095920.253954
75%8.6167989.8090690.632651-1.2876020.5099370.389390
max18.84077520.3710253.6898530.0242940.8353390.782481
\n", "
" ], "text/plain": [ " mae rmse mape r2 ioa \\\n", "count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n", "mean 5.920017 6.864245 0.603656 -2.743017 0.228580 \n", "std 3.534648 3.845034 0.224679 2.049753 0.370622 \n", "min 1.477380 1.849392 0.271934 -22.827546 -1.899284 \n", "25% 2.975700 3.600521 0.502338 -3.631702 0.042875 \n", "50% 4.169098 5.055890 0.558942 -2.233530 0.309592 \n", "75% 8.616798 9.809069 0.632651 -1.287602 0.509937 \n", "max 18.840775 20.371025 3.689853 0.024294 0.835339 \n", "\n", " r \n", "count 4739.000000 \n", "mean 0.225978 \n", "std 0.227965 \n", "min -0.626938 \n", "25% 0.088760 \n", "50% 0.253954 \n", "75% 0.389390 \n", "max 0.782481 " ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eva_frame_df.describe()" ] }, { "cell_type": "code", "execution_count": 31, "id": "d1696b4f-1520-4201-b855-63f517022ec3", "metadata": {}, "outputs": [], "source": [ "# 可视化特定特征的函数\n", "def visualize_rst(input_feature,masked_feature, recov_region, output_feature, title):\n", " plt.figure(figsize=(12, 6))\n", " plt.subplot(1, 4, 1)\n", " plt.imshow(input_feature, cmap='RdYlGn_r')\n", " plt.gca().axis('off') # 获取当前坐标轴并关闭\n", " plt.subplot(1, 4, 2)\n", " plt.imshow(masked_feature, cmap='gray')\n", " plt.gca().axis('off') # 获取当前坐标轴并关闭\n", " plt.subplot(1, 4, 3)\n", " plt.imshow(recov_region, cmap='RdYlGn_r')\n", " plt.gca().axis('off') # 获取当前坐标轴并关闭\n", " plt.subplot(1, 4, 4)\n", " plt.imshow(output_feature, cmap='RdYlGn_r')\n", " plt.gca().axis('off') # 获取当前坐标轴并关闭\n", " plt.savefig('./figures/result/vitmae_20_samples.png', bbox_inches='tight')" ] }, { "cell_type": "code", "execution_count": 92, "id": "56291a37-cc49-428f-a8db-99bdd7a1f062", "metadata": {}, "outputs": [], "source": [ "torch.save(model, './models/MAE/vit.pt')" ] }, { "cell_type": "code", "execution_count": 41, "id": "bb8eccaa-7409-4cce-9119-70aed5ee496e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'1114', '1952', '2568', '3523', '602'}" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])\n", "find_ex" ] }, { "cell_type": "code", "execution_count": 42, "id": "b9c6cdba-e563-42e2-885d-d1df320dac02", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for j in find_ex:\n", " ori = np.load(f'./test_img/{j}-real.npy')[0]\n", " mask = np.load(f'./test_img/{j}-mask.npy')\n", " mask_rev = 1 - mask\n", " img_in = ori * mask_rev / MAX_VALUE\n", " img_out = model(torch.tensor(img_in.reshape(1, 1, 96, 96), dtype=torch.float32)).detach().cpu().numpy()[0][0] * MAX_VALUE\n", " out = ori * mask_rev + img_out * mask\n", " plt.imshow(out, cmap='RdYlGn_r')\n", " plt.gca().axis('off')\n", " plt.savefig(f'./test_img/out_fig/{j}-mae_vit_out.png', bbox_inches='tight')\n", " plt.clf()" ] }, { "cell_type": "code", "execution_count": null, "id": "e3f24f2a-bc47-409d-8e46-bfa62851701b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" } }, "nbformat": 4, "nbformat_minor": 5 }