{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e247809e-7610-487b-88e0-9b4947e92c6b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"from PIL import Image\n",
"\n",
"MAX_VALUE = 107.49169921875"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4c4182a2-0284-4a82-a494-cbe4051ff7bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b7402720-de05-45e9-b076-04780a513fc3",
"metadata": {},
"outputs": [],
"source": [
"class GrayScaleDataset(Dataset):\n",
" def __init__(self, data_dir):\n",
" self.data_dir = data_dir\n",
" self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]\n",
"\n",
" def __len__(self):\n",
" return len(self.file_list)\n",
"\n",
" def __getitem__(self, idx):\n",
" file_path = os.path.join(self.data_dir, self.file_list[idx])\n",
" data = np.load(file_path)[:,:,0] / MAX_VALUE\n",
" return torch.tensor(data, dtype=torch.float32).unsqueeze(0)\n",
" "
]
},
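{
"cell_type": "code",
"execution_count": null,
"id": "added-grayscale-dataset-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (not part of the original pipeline): write one synthetic\n",
"# (96, 96, 1) array to a temporary directory and confirm that GrayScaleDataset returns\n",
"# a (1, 96, 96) float32 tensor scaled into [0, 1]. The file name and data are made up.\n",
"import tempfile\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp_dir:\n",
"    dummy = (np.random.rand(96, 96, 1) * MAX_VALUE).astype(np.float32)\n",
"    np.save(os.path.join(tmp_dir, 'dummy_sample.npy'), dummy)\n",
"    sample = GrayScaleDataset(tmp_dir)[0]\n",
"    print(sample.shape, float(sample.min()), float(sample.max()))  # torch.Size([1, 96, 96]), values within [0, 1]"
]
},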
{
"cell_type": "code",
"execution_count": 4,
"id": "3ecd7bd0-15a0-4420-95e1-066e4d023cd3",
"metadata": {},
"outputs": [],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = idx % len(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / MAX_VALUE # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为1,0值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)"
]
},
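{
"cell_type": "code",
"execution_count": null,
"id": "added-no2-dataset-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative shape check for NO2Dataset (synthetic files in temporary directories):\n",
"# each item should come back as X, y and mask tensors shaped (1, 96, 96), with the mask\n",
"# binarised to 0/1 and X equal to y wherever the mask is 1.\n",
"import tempfile\n",
"\n",
"with tempfile.TemporaryDirectory() as img_dir, tempfile.TemporaryDirectory() as msk_dir:\n",
"    np.save(os.path.join(img_dir, 'dummy.npy'), (np.random.rand(96, 96, 1) * MAX_VALUE).astype(np.float32))\n",
"    hole = (np.random.rand(96, 96) > 0.2).astype(np.uint8) * 255  # white = keep, black = hole\n",
"    Image.fromarray(hole, mode='L').save(os.path.join(msk_dir, 'dummy.jpg'))\n",
"    X, y, m = NO2Dataset(img_dir, msk_dir)[0]\n",
"    print(X.shape, y.shape, m.shape, torch.unique(m))"
]
},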
{
"cell_type": "code",
"execution_count": 5,
"id": "9ec25dc1-3728-4b0b-8403-ccad10355999",
"metadata": {},
"outputs": [],
"source": [
"class PatchMasking:\n",
" def __init__(self, patch_size, mask_ratio):\n",
" self.patch_size = patch_size\n",
" self.mask_ratio = mask_ratio\n",
"\n",
" def __call__(self, x):\n",
" batch_size, C, H, W = x.shape\n",
" num_patches = (H // self.patch_size) * (W // self.patch_size)\n",
" num_masked = int(num_patches * self.mask_ratio)\n",
" \n",
" # 为每个样本生成独立的mask\n",
" masks = []\n",
" for _ in range(batch_size):\n",
" mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)\n",
" mask[:num_masked] = 1\n",
" mask = mask[torch.randperm(num_patches)]\n",
" mask = mask.view(H // self.patch_size, W // self.patch_size)\n",
" mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)\n",
" masks.append(mask)\n",
" \n",
" # 将所有mask堆叠成一个批量张量\n",
" masks = torch.stack(masks, dim=0)\n",
" masks = torch.unsqueeze(masks, dim=1)\n",
" \n",
" # 应用mask到输入x上\n",
" masked_x = x * (1 - masks.float())\n",
" return masked_x, masks"
]
},
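{
"cell_type": "code",
"execution_count": null,
"id": "added-patchmasking-check",
"metadata": {},
"outputs": [],
"source": [
"# Quick behavioural check of PatchMasking on random data (illustrative only): with\n",
"# patch_size=8 and mask_ratio=0.2 on 96x96 inputs there are 144 patches, of which\n",
"# int(144 * 0.2) = 28 are zeroed per sample, i.e. roughly 19.4% of the pixels.\n",
"x_demo = torch.rand(4, 1, 96, 96)\n",
"masked_demo, mask_demo = PatchMasking(patch_size=8, mask_ratio=0.2)(x_demo)\n",
"print(masked_demo.shape, mask_demo.shape)           # both (4, 1, 96, 96)\n",
"print(mask_demo.float().mean().item())              # about 0.194 (28 / 144)\n",
"print(bool((masked_demo[mask_demo] == 0).all()))    # masked patches are zeroed out"
]
},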
{
"cell_type": "code",
"execution_count": 6,
"id": "a1f70780-9e31-4917-9785-768140e5610e",
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" def __init__(self, input_dim, output_dim):\n",
" super(MLP, self).__init__()\n",
" self.fc1 = nn.Linear(input_dim, output_dim)\n",
" self.act = nn.GELU() # 使用 GELU 激活函数\n",
" self.fc2 = nn.Linear(output_dim, input_dim)\n",
"\n",
" def forward(self, x):\n",
" return self.fc2(self.act(self.fc1(x)))\n",
"\n",
"class Attention(nn.Module):\n",
" def __init__(self, dim, heads):\n",
" super(Attention, self).__init__()\n",
" self.heads = heads\n",
" self.dim = dim\n",
" self.scale = dim ** -0.5\n",
"\n",
" self.qkv = nn.Linear(dim, dim * 3)\n",
" self.attn_drop = nn.Dropout(0.1)\n",
" self.proj = nn.Linear(dim, dim)\n",
" self.proj_drop = nn.Dropout(0.1)\n",
"\n",
" def forward(self, x):\n",
" B, N, C = x.shape\n",
" qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) # (3, B, heads, N, head_dim)\n",
" q, k, v = qkv[0], qkv[1], qkv[2]\n",
"\n",
" attn = (q @ k.transpose(-2, -1)) * self.scale\n",
" attn = attn.softmax(dim=-1)\n",
" attn = self.attn_drop(attn)\n",
"\n",
" out = (attn @ v).transpose(1, 2).reshape(B, N, C)\n",
" return self.proj_drop(self.proj(out))\n",
"\n",
"class LayerNorm(nn.Module):\n",
" def __init__(self, dim, eps=1e-6):\n",
" super(LayerNorm, self).__init__()\n",
" self.ln = nn.LayerNorm(dim, eps=eps)\n",
"\n",
" def forward(self, x):\n",
" return self.ln(x)\n",
"\n",
"class Dropout(nn.Module):\n",
" def __init__(self, p=0.1):\n",
" super(Dropout, self).__init__()\n",
" self.dropout = nn.Dropout(p)\n",
"\n",
" def forward(self, x):\n",
" return self.dropout(x)\n",
"\n",
"class ViTEncoder(nn.Module):\n",
" def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256, dropout=0.1):\n",
" super(ViTEncoder, self).__init__()\n",
" self.patch_size = patch_size\n",
" self.dim = dim\n",
" self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)\n",
"\n",
" self.attention_layers = nn.ModuleList([\n",
" nn.ModuleList([\n",
" LayerNorm(dim), # Layer Normalization\n",
" Attention(dim, heads),\n",
" Dropout(dropout), # Dropout\n",
" LayerNorm(dim), # Layer Normalization\n",
" MLP(dim, mlp_dim),\n",
" Dropout(dropout) # Dropout\n",
" ]) for _ in range(depth)\n",
" ])\n",
"\n",
" def forward(self, x):\n",
" x = self.patch_embedding(x) # 形状变为 (batch_size, dim, num_patches_h, num_patches_w)\n",
" x = x.flatten(2).transpose(1, 2) # 形状变为 (batch_size, num_patches, dim)\n",
"\n",
" for norm1, attn, drop1, norm2, mlp, drop2 in self.attention_layers:\n",
" x = x + drop1(attn(norm1(x))) # 残差连接\n",
" x = x + drop2(mlp(norm2(x))) # 残差连接\n",
" return x\n",
"\n",
"\n",
"class ConvDecoder(nn.Module):\n",
" def __init__(self, dim=128, patch_size=8, img_size=96):\n",
" super(ConvDecoder, self).__init__()\n",
" self.dim = dim\n",
" self.patch_size = patch_size\n",
" self.img_size = img_size\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(self.dim, 64, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = x.transpose(1, 2).view(-1, self.dim, self.img_size // self.patch_size, self.img_size // self.patch_size)\n",
" x = self.decoder(x)\n",
" return x\n",
"\n",
"class MAEModel(nn.Module):\n",
" def __init__(self, encoder, decoder):\n",
" super(MAEModel, self).__init__()\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" # self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded"
]
},
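{
"cell_type": "code",
"execution_count": null,
"id": "added-mae-shape-check",
"metadata": {},
"outputs": [],
"source": [
"# Shape sanity check for the encoder/decoder pair (illustrative, random input): the\n",
"# ViT encoder maps (B, 1, 96, 96) to (B, 144, 128) patch tokens, and the transposed-\n",
"# convolution decoder upsamples the 12x12 token grid back to a (B, 1, 96, 96) image.\n",
"with torch.no_grad():\n",
"    demo_batch = torch.rand(2, 1, 96, 96)\n",
"    demo_enc = ViTEncoder()\n",
"    demo_dec = ConvDecoder()\n",
"    tokens = demo_enc(demo_batch)\n",
"    recon = MAEModel(demo_enc, demo_dec)(demo_batch)\n",
"print(tokens.shape, recon.shape)  # torch.Size([2, 144, 128]) torch.Size([2, 1, 96, 96])"
]
},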
{
"cell_type": "code",
"execution_count": 7,
"id": "4a1427a1-bf38-483e-b92b-07631078c78a",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
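{
"cell_type": "code",
"execution_count": null,
"id": "added-masked-loss-check",
"metadata": {},
"outputs": [],
"source": [
"# Tiny numeric check of masked_mse_loss (illustrative): an all-zero prediction against\n",
"# an all-one target gives a masked MSE of exactly 1, no matter how many pixels the\n",
"# mask selects, because only the masked pixels enter the average.\n",
"demo_pred = torch.zeros(1, 1, 4, 4)\n",
"demo_target = torch.ones(1, 1, 4, 4)\n",
"demo_mask = torch.zeros(1, 1, 4, 4)\n",
"demo_mask[..., :2, :] = 1  # pretend the top half was hidden and must be reconstructed\n",
"print(masked_mse_loss(demo_pred, demo_target, demo_mask).item())  # 1.0"
]
},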
{
"cell_type": "code",
"execution_count": 8,
"id": "29deee2c-5771-498a-b01b-fde5e0f387ba",
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):\n",
" best_model = model\n",
" best_loss = 100\n",
" model.to(device)\n",
" for epoch in range(epochs):\n",
" model.train()\n",
" train_loss = 0\n",
" for data in train_loader:\n",
" data = data.to(device)\n",
" optimizer.zero_grad()\n",
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" loss = masked_mse_loss(output, data, mask)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss.item()\n",
" train_loss /= len(train_loader)\n",
"\n",
" model.eval()\n",
" val_loss = 0\n",
" with torch.no_grad():\n",
" for data in val_loader:\n",
" data = data.to(device)\n",
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" loss = masked_mse_loss(output, data, mask)\n",
" val_loss += loss.item()\n",
" val_loss /= len(val_loader)\n",
" if val_loss < best_loss:\n",
" best_loss = val_loss\n",
" best_model = model\n",
"\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "bb524f86-aa7d-44ee-b13e-b9ba4e5b3a0b",
"metadata": {},
"outputs": [],
"source": [
"train_dir = './out_mat/96/train/'\n",
"train_dataset = GrayScaleDataset(train_dir)\n",
"\n",
"val_dir = './out_mat/96/valid/'\n",
"val_dataset = GrayScaleDataset(val_dir)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)\n",
"val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0d7e2f83-c113-4c62-91eb-d4ea3192530c",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "7d6d07a4-31f1-4350-a487-b583db979381",
"metadata": {},
"outputs": [],
"source": [
"encoder = ViTEncoder()\n",
"decoder = ConvDecoder()\n",
"model = MAEModel(encoder, decoder)\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09b04e16-3257-4890-b736-a6c7274561e0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Train Loss: 0.9251, Val Loss: 0.0869\n",
"Epoch 2, Train Loss: 0.0734, Val Loss: 0.0506\n",
"Epoch 3, Train Loss: 0.0494, Val Loss: 0.0489\n",
"Epoch 4, Train Loss: 0.0432, Val Loss: 0.0462\n",
"Epoch 5, Train Loss: 0.0390, Val Loss: 0.0400\n",
"Epoch 6, Train Loss: 0.0351, Val Loss: 0.0356\n"
]
}
],
"source": [
"train_model(model, train_loader, val_loader, epochs=50, criterion=criterion, optimizer=optimizer, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "b0c5cf4b-aca2-4781-8b47-bf2a46269635",
"metadata": {},
"outputs": [],
"source": [
"test_set = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')\n",
"test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "56653f37-899a-47d6-8d50-e456b4ad1835",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "2504300a-ac91-453a-9bfb-ab89f56d4ff6",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = 2 * np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
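{
"cell_type": "code",
"execution_count": null,
"id": "added-ioa-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of cal_ioa: a perfect prediction yields IoA = 1, and the score\n",
"# drops as the predictions move away from the observations. Values here are made up.\n",
"obs = np.array([1.0, 2.0, 3.0, 4.0])\n",
"print(cal_ioa(obs, obs))                                        # 1.0\n",
"print(cal_ioa(obs, obs + np.array([0.5, -0.5, 0.5, -0.5])))     # slightly below 1.0"
]
},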
{
"cell_type": "code",
"execution_count": 24,
"id": "d55be844-0873-4d9a-8160-22603de32a81",
"metadata": {},
"outputs": [],
"source": [
"test_set2 = GrayScaleDataset('./out_mat/96/test/')\n",
"test_loader2 = DataLoader(test_set2, batch_size=64, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "f83dbc37-8543-45bc-ba59-ca88d4ba2a66",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 96, 96])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rev_data.shape"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "2655f7f4-9d88-49fd-9346-87a621320183",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for data in test_loader2:\n",
" data = data.to(device)\n",
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" rev_data = data * MAX_VALUE\n",
" rev_recon = output * MAX_VALUE\n",
" data_label = rev_data * mask\n",
" data_label = data_label[mask==1]\n",
" recon_no2 = rev_recon * mask\n",
" recon_no2 = recon_no2[mask==1]\n",
" y_true = rev_data.flatten()\n",
" y_pred = rev_recon.flatten()\n",
" mae = mean_absolute_error(y_true, y_pred)\n",
" rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
" mape = mean_absolute_percentage_error(y_true, y_pred)\n",
" r2 = r2_score(y_true, y_pred)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" eva_list.append([mae, rmse, mape, r2, ioa])"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "0327b51c-d714-4fe0-a044-d8be3ff180e0",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" mae | \n",
" rmse | \n",
" mape | \n",
" r2 | \n",
" ioa | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 75.000000 | \n",
" 75.000000 | \n",
" 75.000000 | \n",
" 75.000000 | \n",
" 75.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 1.208013 | \n",
" 1.600644 | \n",
" 0.142720 | \n",
" 0.941983 | \n",
" 0.981683 | \n",
"
\n",
" \n",
" std | \n",
" 0.056235 | \n",
" 0.081791 | \n",
" 0.003435 | \n",
" 0.004449 | \n",
" 0.002309 | \n",
"
\n",
" \n",
" min | \n",
" 1.091517 | \n",
" 1.446389 | \n",
" 0.134849 | \n",
" 0.911833 | \n",
" 0.965708 | \n",
"
\n",
" \n",
" 25% | \n",
" 1.170305 | \n",
" 1.555051 | \n",
" 0.140519 | \n",
" 0.940425 | \n",
" 0.981100 | \n",
"
\n",
" \n",
" 50% | \n",
" 1.204728 | \n",
" 1.593261 | \n",
" 0.142981 | \n",
" 0.942651 | \n",
" 0.982003 | \n",
"
\n",
" \n",
" 75% | \n",
" 1.242762 | \n",
" 1.646311 | \n",
" 0.145185 | \n",
" 0.944118 | \n",
" 0.982809 | \n",
"
\n",
" \n",
" max | \n",
" 1.420721 | \n",
" 2.037903 | \n",
" 0.150566 | \n",
" 0.949663 | \n",
" 0.984610 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" mae rmse mape r2 ioa\n",
"count 75.000000 75.000000 75.000000 75.000000 75.000000\n",
"mean 1.208013 1.600644 0.142720 0.941983 0.981683\n",
"std 0.056235 0.081791 0.003435 0.004449 0.002309\n",
"min 1.091517 1.446389 0.134849 0.911833 0.965708\n",
"25% 1.170305 1.555051 0.140519 0.940425 0.981100\n",
"50% 1.204728 1.593261 0.142981 0.942651 0.982003\n",
"75% 1.242762 1.646311 0.145185 0.944118 0.982809\n",
"max 1.420721 2.037903 0.150566 0.949663 0.984610"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "b8669b7e-6974-418a-87fc-074734f9a1a3",
"metadata": {},
"outputs": [],
"source": [
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for data in test_loader2:\n",
" data = data.to(device)\n",
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" rev_data = data * MAX_VALUE\n",
" rev_recon = output * MAX_VALUE\n",
" # todo: 这里需要只评估修补出来的模块\n",
" for i, sample in enumerate(rev_data):\n",
" used_mask = mask[i]\n",
" data_label = sample[0] * used_mask\n",
" recon_no2 = rev_recon[i][0] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "c9079fc5-6ab3-465e-9067-6cad8f69c5a8",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" mae | \n",
" rmse | \n",
" mape | \n",
" r2 | \n",
" ioa | \n",
" r | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 1.714315 | \n",
" 2.350189 | \n",
" 0.215974 | \n",
" 0.609470 | \n",
" 0.943560 | \n",
" 0.823401 | \n",
"
\n",
" \n",
" std | \n",
" 0.697344 | \n",
" 0.940345 | \n",
" 0.077893 | \n",
" 0.131496 | \n",
" 0.022261 | \n",
" 0.069394 | \n",
"
\n",
" \n",
" min | \n",
" 0.636049 | \n",
" 0.821723 | \n",
" 0.099999 | \n",
" 0.003194 | \n",
" 0.802237 | \n",
" 0.405363 | \n",
"
\n",
" \n",
" 25% | \n",
" 1.121617 | \n",
" 1.576669 | \n",
" 0.170974 | \n",
" 0.533081 | \n",
" 0.931653 | \n",
" 0.783616 | \n",
"
\n",
" \n",
" 50% | \n",
" 1.459720 | \n",
" 2.132316 | \n",
" 0.199419 | \n",
" 0.623769 | \n",
" 0.946952 | \n",
" 0.831403 | \n",
"
\n",
" \n",
" 75% | \n",
" 2.334761 | \n",
" 3.119393 | \n",
" 0.234517 | \n",
" 0.698517 | \n",
" 0.958943 | \n",
" 0.872422 | \n",
"
\n",
" \n",
" max | \n",
" 4.406258 | \n",
" 8.470109 | \n",
" 1.242636 | \n",
" 0.895199 | \n",
" 0.986901 | \n",
" 0.965110 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
"mean 1.714315 2.350189 0.215974 0.609470 0.943560 \n",
"std 0.697344 0.940345 0.077893 0.131496 0.022261 \n",
"min 0.636049 0.821723 0.099999 0.003194 0.802237 \n",
"25% 1.121617 1.576669 0.170974 0.533081 0.931653 \n",
"50% 1.459720 2.132316 0.199419 0.623769 0.946952 \n",
"75% 2.334761 3.119393 0.234517 0.698517 0.958943 \n",
"max 4.406258 8.470109 1.242636 0.895199 0.986901 \n",
"\n",
" r \n",
"count 4739.000000 \n",
"mean 0.823401 \n",
"std 0.069394 \n",
"min 0.405363 \n",
"25% 0.783616 \n",
"50% 0.831403 \n",
"75% 0.872422 \n",
"max 0.965110 "
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eva_frame_df = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape')\n",
"eva_frame_df.describe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e840b789-bf68-4b4d-a8d3-c5362c310349",
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "too many values to unpack (expected 3)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[25], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_idx, (X, y, mask) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(test_loader2):\n\u001b[1;32m 6\u001b[0m X, y, mask \u001b[38;5;241m=\u001b[39m X\u001b[38;5;241m.\u001b[39mto(device), y\u001b[38;5;241m.\u001b[39mto(device), mask\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 7\u001b[0m mask_rev \u001b[38;5;241m=\u001b[39m (torch\u001b[38;5;241m.\u001b[39msqueeze(mask, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m==\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;66;03m# mask取反获得修复区域\u001b[39;00m\n",
"\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 3)"
]
}
],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader2):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = torch.squeeze(y * MAX_VALUE, dim=1)\n",
" rev_recon = torch.squeeze(reconstructed * MAX_VALUE, dim=1)\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = rev_data * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = rev_recon * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" y_true = rev_data.flatten()\n",
" y_pred = rev_recon.flatten()\n",
" mae = mean_absolute_error(y_true, y_pred)\n",
" rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
" mape = mean_absolute_percentage_error(y_true, y_pred)\n",
" r2 = r2_score(y_true, y_pred)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" eva_list.append([mae, rmse, mape, r2, ioa])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "41fa754d-1eee-43a2-9e39-a0254719be30",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" mae | \n",
" rmse | \n",
" mape | \n",
" r2 | \n",
" ioa | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 149.000000 | \n",
" 149.000000 | \n",
" 149.000000 | \n",
" 149.000000 | \n",
" 149.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 2.235662 | \n",
" 4.042349 | \n",
" 0.238494 | \n",
" 0.626060 | \n",
" 0.572341 | \n",
"
\n",
" \n",
" std | \n",
" 0.192709 | \n",
" 0.357475 | \n",
" 0.007405 | \n",
" 0.042890 | \n",
" 0.042652 | \n",
"
\n",
" \n",
" min | \n",
" 1.786567 | \n",
" 3.167143 | \n",
" 0.224796 | \n",
" 0.522157 | \n",
" 0.460707 | \n",
"
\n",
" \n",
" 25% | \n",
" 2.084117 | \n",
" 3.779276 | \n",
" 0.232974 | \n",
" 0.597774 | \n",
" 0.547144 | \n",
"
\n",
" \n",
" 50% | \n",
" 2.226062 | \n",
" 4.075465 | \n",
" 0.237429 | \n",
" 0.627588 | \n",
" 0.570579 | \n",
"
\n",
" \n",
" 75% | \n",
" 2.361411 | \n",
" 4.284523 | \n",
" 0.243866 | \n",
" 0.656226 | \n",
" 0.601233 | \n",
"
\n",
" \n",
" max | \n",
" 2.751377 | \n",
" 4.917407 | \n",
" 0.258230 | \n",
" 0.740943 | \n",
" 0.666083 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" mae rmse mape r2 ioa\n",
"count 149.000000 149.000000 149.000000 149.000000 149.000000\n",
"mean 2.235662 4.042349 0.238494 0.626060 0.572341\n",
"std 0.192709 0.357475 0.007405 0.042890 0.042652\n",
"min 1.786567 3.167143 0.224796 0.522157 0.460707\n",
"25% 2.084117 3.779276 0.232974 0.597774 0.547144\n",
"50% 2.226062 4.075465 0.237429 0.627588 0.570579\n",
"75% 2.361411 4.284523 0.243866 0.656226 0.601233\n",
"max 2.751377 4.917407 0.258230 0.740943 0.666083"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "46cffa4a-37bc-4e13-9723-fc6cb244c95c",
"metadata": {},
"outputs": [],
"source": [
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"best_mape = 1\n",
"best_img = None\n",
"best_mask = None\n",
"best_recov = None\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * MAX_VALUE\n",
" rev_recon = reconstructed * MAX_VALUE\n",
" # todo: 这里需要只评估修补出来的模块\n",
" for i, sample in enumerate(rev_data):\n",
" used_mask = mask_rev[i]\n",
" data_label = sample[0] * used_mask\n",
" recon_no2 = rev_recon[i][0] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6f6a897-2f48-4958-8725-f566430c61e1",
"metadata": {},
"outputs": [],
"source": [
"eva_frame_df = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape')"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "6d1920e7-b92f-414e-8273-0b4666587904",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" mae | \n",
" rmse | \n",
" mape | \n",
" r2 | \n",
" ioa | \n",
" r | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
" 4739.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 5.920017 | \n",
" 6.864245 | \n",
" 0.603656 | \n",
" -2.743017 | \n",
" 0.228580 | \n",
" 0.225978 | \n",
"
\n",
" \n",
" std | \n",
" 3.534648 | \n",
" 3.845034 | \n",
" 0.224679 | \n",
" 2.049753 | \n",
" 0.370622 | \n",
" 0.227965 | \n",
"
\n",
" \n",
" min | \n",
" 1.477380 | \n",
" 1.849392 | \n",
" 0.271934 | \n",
" -22.827546 | \n",
" -1.899284 | \n",
" -0.626938 | \n",
"
\n",
" \n",
" 25% | \n",
" 2.975700 | \n",
" 3.600521 | \n",
" 0.502338 | \n",
" -3.631702 | \n",
" 0.042875 | \n",
" 0.088760 | \n",
"
\n",
" \n",
" 50% | \n",
" 4.169098 | \n",
" 5.055890 | \n",
" 0.558942 | \n",
" -2.233530 | \n",
" 0.309592 | \n",
" 0.253954 | \n",
"
\n",
" \n",
" 75% | \n",
" 8.616798 | \n",
" 9.809069 | \n",
" 0.632651 | \n",
" -1.287602 | \n",
" 0.509937 | \n",
" 0.389390 | \n",
"
\n",
" \n",
" max | \n",
" 18.840775 | \n",
" 20.371025 | \n",
" 3.689853 | \n",
" 0.024294 | \n",
" 0.835339 | \n",
" 0.782481 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
"mean 5.920017 6.864245 0.603656 -2.743017 0.228580 \n",
"std 3.534648 3.845034 0.224679 2.049753 0.370622 \n",
"min 1.477380 1.849392 0.271934 -22.827546 -1.899284 \n",
"25% 2.975700 3.600521 0.502338 -3.631702 0.042875 \n",
"50% 4.169098 5.055890 0.558942 -2.233530 0.309592 \n",
"75% 8.616798 9.809069 0.632651 -1.287602 0.509937 \n",
"max 18.840775 20.371025 3.689853 0.024294 0.835339 \n",
"\n",
" r \n",
"count 4739.000000 \n",
"mean 0.225978 \n",
"std 0.227965 \n",
"min -0.626938 \n",
"25% 0.088760 \n",
"50% 0.253954 \n",
"75% 0.389390 \n",
"max 0.782481 "
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eva_frame_df.describe()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "d1696b4f-1520-4201-b855-63f517022ec3",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_rst(input_feature,masked_feature, recov_region, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 4, 1)\n",
" plt.imshow(input_feature, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 2)\n",
" plt.imshow(masked_feature, cmap='gray')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 3)\n",
" plt.imshow(recov_region, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.subplot(1, 4, 4)\n",
" plt.imshow(output_feature, cmap='RdYlGn_r')\n",
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
" plt.savefig('./figures/result/vitmae_20_samples.png', bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": 92,
"id": "56291a37-cc49-428f-a8db-99bdd7a1f062",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model, './models/MAE/vit.pt')"
]
},
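{
"cell_type": "code",
"execution_count": null,
"id": "added-checkpoint-reload",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative reload of the checkpoint saved above. torch.save(model, ...) pickles the\n",
"# full nn.Module, so ViTEncoder / ConvDecoder / MAEModel must be defined (as they are in\n",
"# this notebook) before torch.load can unpickle it; recent PyTorch releases may also\n",
"# require weights_only=False for full-module checkpoints.\n",
"reloaded = torch.load('./models/MAE/vit.pt', map_location='cpu')\n",
"reloaded.eval()\n",
"print(sum(p.numel() for p in reloaded.parameters()), 'parameters')"
]
},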
{
"cell_type": "code",
"execution_count": 41,
"id": "bb8eccaa-7409-4cce-9119-70aed5ee496e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'1114', '1952', '2568', '3523', '602'}"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])\n",
"find_ex"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "b9c6cdba-e563-42e2-885d-d1df320dac02",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for j in find_ex:\n",
" ori = np.load(f'./test_img/{j}-real.npy')[0]\n",
" mask = np.load(f'./test_img/{j}-mask.npy')\n",
" mask_rev = 1 - mask\n",
" img_in = ori * mask_rev / MAX_VALUE\n",
" img_out = model(torch.tensor(img_in.reshape(1, 1, 96, 96), dtype=torch.float32)).detach().cpu().numpy()[0][0] * MAX_VALUE\n",
" out = ori * mask_rev + img_out * mask\n",
" plt.imshow(out, cmap='RdYlGn_r')\n",
" plt.gca().axis('off')\n",
" plt.savefig(f'./test_img/out_fig/{j}-mae_vit_out.png', bbox_inches='tight')\n",
" plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3f24f2a-bc47-409d-8e46-bfa62851701b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}