1277 lines
44 KiB
Plaintext
1277 lines
44 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"id": "e247809e-7610-487b-88e0-9b4947e92c6b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import torch\n",
|
|||
|
"import torch.nn as nn\n",
|
|||
|
"import torch.optim as optim\n",
|
|||
|
"from torch.utils.data import DataLoader, Dataset, random_split\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import os\n",
|
|||
|
"from PIL import Image\n",
|
|||
|
"\n",
|
|||
|
"MAX_VALUE = 107.49169921875"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"id": "4c4182a2-0284-4a82-a494-cbe4051ff7bd",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"device(type='cuda')"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|||
|
"device"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"id": "b7402720-de05-45e9-b076-04780a513fc3",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class GrayScaleDataset(Dataset):\n",
|
|||
|
" def __init__(self, data_dir):\n",
|
|||
|
" self.data_dir = data_dir\n",
|
|||
|
" self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]\n",
|
|||
|
"\n",
|
|||
|
" def __len__(self):\n",
|
|||
|
" return len(self.file_list)\n",
|
|||
|
"\n",
|
|||
|
" def __getitem__(self, idx):\n",
|
|||
|
" file_path = os.path.join(self.data_dir, self.file_list[idx])\n",
|
|||
|
" data = np.load(file_path)[:,:,0] / MAX_VALUE\n",
|
|||
|
" return torch.tensor(data, dtype=torch.float32).unsqueeze(0)\n",
|
|||
|
" "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"id": "3ecd7bd0-15a0-4420-95e1-066e4d023cd3",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class NO2Dataset(Dataset):\n",
|
|||
|
" \n",
|
|||
|
" def __init__(self, image_dir, mask_dir):\n",
|
|||
|
" \n",
|
|||
|
" self.image_dir = image_dir\n",
|
|||
|
" self.mask_dir = mask_dir\n",
|
|||
|
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
|
|||
|
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
|
|||
|
" \n",
|
|||
|
" def __len__(self):\n",
|
|||
|
" \n",
|
|||
|
" return len(self.image_filenames)\n",
|
|||
|
" \n",
|
|||
|
" def __getitem__(self, idx):\n",
|
|||
|
" \n",
|
|||
|
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
|
|||
|
" mask_idx = idx % len(self.mask_filenames)\n",
|
|||
|
" mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])\n",
|
|||
|
"\n",
|
|||
|
" # 加载图像数据 (.npy 文件)\n",
|
|||
|
" image = np.load(image_path).astype(np.float32)[:,:,:1] / MAX_VALUE # 形状为 (96, 96, 1)\n",
|
|||
|
"\n",
|
|||
|
" # 加载掩码数据 (.jpg 文件)\n",
|
|||
|
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
|
|||
|
"\n",
|
|||
|
" # 将掩码数据中非0值设为1,0值保持不变\n",
|
|||
|
" mask = np.where(mask != 0, 1.0, 0.0)\n",
|
|||
|
"\n",
|
|||
|
" # 保持掩码数据形状为 (96, 96, 1)\n",
|
|||
|
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
|
|||
|
"\n",
|
|||
|
" # 应用掩码\n",
|
|||
|
" masked_image = image.copy()\n",
|
|||
|
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
|
|||
|
"\n",
|
|||
|
" # cGAN的输入和目标\n",
|
|||
|
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
|
|||
|
" y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n",
|
|||
|
"\n",
|
|||
|
" # 转换形状为 (channels, height, width)\n",
|
|||
|
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
|||
|
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
|||
|
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
|||
|
"\n",
|
|||
|
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"id": "9ec25dc1-3728-4b0b-8403-ccad10355999",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class PatchMasking:\n",
|
|||
|
" def __init__(self, patch_size, mask_ratio):\n",
|
|||
|
" self.patch_size = patch_size\n",
|
|||
|
" self.mask_ratio = mask_ratio\n",
|
|||
|
"\n",
|
|||
|
" def __call__(self, x):\n",
|
|||
|
" batch_size, C, H, W = x.shape\n",
|
|||
|
" num_patches = (H // self.patch_size) * (W // self.patch_size)\n",
|
|||
|
" num_masked = int(num_patches * self.mask_ratio)\n",
|
|||
|
" \n",
|
|||
|
" # 为每个样本生成独立的mask\n",
|
|||
|
" masks = []\n",
|
|||
|
" for _ in range(batch_size):\n",
|
|||
|
" mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)\n",
|
|||
|
" mask[:num_masked] = 1\n",
|
|||
|
" mask = mask[torch.randperm(num_patches)]\n",
|
|||
|
" mask = mask.view(H // self.patch_size, W // self.patch_size)\n",
|
|||
|
" mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)\n",
|
|||
|
" masks.append(mask)\n",
|
|||
|
" \n",
|
|||
|
" # 将所有mask堆叠成一个批量张量\n",
|
|||
|
" masks = torch.stack(masks, dim=0)\n",
|
|||
|
" masks = torch.unsqueeze(masks, dim=1)\n",
|
|||
|
" \n",
|
|||
|
" # 应用mask到输入x上\n",
|
|||
|
" masked_x = x * (1 - masks.float())\n",
|
|||
|
" return masked_x, masks"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"id": "a1f70780-9e31-4917-9785-768140e5610e",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class MLP(nn.Module):\n",
|
|||
|
" def __init__(self, input_dim, output_dim):\n",
|
|||
|
" super(MLP, self).__init__()\n",
|
|||
|
" self.fc1 = nn.Linear(input_dim, output_dim)\n",
|
|||
|
" self.act = nn.GELU() # 使用 GELU 激活函数\n",
|
|||
|
" self.fc2 = nn.Linear(output_dim, input_dim)\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" return self.fc2(self.act(self.fc1(x)))\n",
|
|||
|
"\n",
|
|||
|
"class Attention(nn.Module):\n",
|
|||
|
" def __init__(self, dim, heads):\n",
|
|||
|
" super(Attention, self).__init__()\n",
|
|||
|
" self.heads = heads\n",
|
|||
|
" self.dim = dim\n",
|
|||
|
" self.scale = dim ** -0.5\n",
|
|||
|
"\n",
|
|||
|
" self.qkv = nn.Linear(dim, dim * 3)\n",
|
|||
|
" self.attn_drop = nn.Dropout(0.1)\n",
|
|||
|
" self.proj = nn.Linear(dim, dim)\n",
|
|||
|
" self.proj_drop = nn.Dropout(0.1)\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" B, N, C = x.shape\n",
|
|||
|
" qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) # (3, B, heads, N, head_dim)\n",
|
|||
|
" q, k, v = qkv[0], qkv[1], qkv[2]\n",
|
|||
|
"\n",
|
|||
|
" attn = (q @ k.transpose(-2, -1)) * self.scale\n",
|
|||
|
" attn = attn.softmax(dim=-1)\n",
|
|||
|
" attn = self.attn_drop(attn)\n",
|
|||
|
"\n",
|
|||
|
" out = (attn @ v).transpose(1, 2).reshape(B, N, C)\n",
|
|||
|
" return self.proj_drop(self.proj(out))\n",
|
|||
|
"\n",
|
|||
|
"class LayerNorm(nn.Module):\n",
|
|||
|
" def __init__(self, dim, eps=1e-6):\n",
|
|||
|
" super(LayerNorm, self).__init__()\n",
|
|||
|
" self.ln = nn.LayerNorm(dim, eps=eps)\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" return self.ln(x)\n",
|
|||
|
"\n",
|
|||
|
"class Dropout(nn.Module):\n",
|
|||
|
" def __init__(self, p=0.1):\n",
|
|||
|
" super(Dropout, self).__init__()\n",
|
|||
|
" self.dropout = nn.Dropout(p)\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" return self.dropout(x)\n",
|
|||
|
"\n",
|
|||
|
"class ViTEncoder(nn.Module):\n",
|
|||
|
" def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256, dropout=0.1):\n",
|
|||
|
" super(ViTEncoder, self).__init__()\n",
|
|||
|
" self.patch_size = patch_size\n",
|
|||
|
" self.dim = dim\n",
|
|||
|
" self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)\n",
|
|||
|
"\n",
|
|||
|
" self.attention_layers = nn.ModuleList([\n",
|
|||
|
" nn.ModuleList([\n",
|
|||
|
" LayerNorm(dim), # Layer Normalization\n",
|
|||
|
" Attention(dim, heads),\n",
|
|||
|
" Dropout(dropout), # Dropout\n",
|
|||
|
" LayerNorm(dim), # Layer Normalization\n",
|
|||
|
" MLP(dim, mlp_dim),\n",
|
|||
|
" Dropout(dropout) # Dropout\n",
|
|||
|
" ]) for _ in range(depth)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" x = self.patch_embedding(x) # 形状变为 (batch_size, dim, num_patches_h, num_patches_w)\n",
|
|||
|
" x = x.flatten(2).transpose(1, 2) # 形状变为 (batch_size, num_patches, dim)\n",
|
|||
|
"\n",
|
|||
|
" for norm1, attn, drop1, norm2, mlp, drop2 in self.attention_layers:\n",
|
|||
|
" x = x + drop1(attn(norm1(x))) # 残差连接\n",
|
|||
|
" x = x + drop2(mlp(norm2(x))) # 残差连接\n",
|
|||
|
" return x\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"class ConvDecoder(nn.Module):\n",
|
|||
|
" def __init__(self, dim=128, patch_size=8, img_size=96):\n",
|
|||
|
" super(ConvDecoder, self).__init__()\n",
|
|||
|
" self.dim = dim\n",
|
|||
|
" self.patch_size = patch_size\n",
|
|||
|
" self.img_size = img_size\n",
|
|||
|
" self.decoder = nn.Sequential(\n",
|
|||
|
" nn.ConvTranspose2d(self.dim, 64, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
|
|||
|
" nn.ReLU(),\n",
|
|||
|
" \n",
|
|||
|
" nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
|
|||
|
" nn.ReLU(),\n",
|
|||
|
" \n",
|
|||
|
" nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n",
|
|||
|
" nn.Sigmoid()\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" x = x.transpose(1, 2).view(-1, self.dim, self.img_size // self.patch_size, self.img_size // self.patch_size)\n",
|
|||
|
" x = self.decoder(x)\n",
|
|||
|
" return x\n",
|
|||
|
"\n",
|
|||
|
"class MAEModel(nn.Module):\n",
|
|||
|
" def __init__(self, encoder, decoder):\n",
|
|||
|
" super(MAEModel, self).__init__()\n",
|
|||
|
" self.encoder = encoder\n",
|
|||
|
" self.decoder = decoder\n",
|
|||
|
" # self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" encoded = self.encoder(x)\n",
|
|||
|
" decoded = self.decoder(encoded)\n",
|
|||
|
" return decoded"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"id": "4a1427a1-bf38-483e-b92b-07631078c78a",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def masked_mse_loss(preds, target, mask):\n",
|
|||
|
" loss = (preds - target) ** 2\n",
|
|||
|
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
|
|||
|
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
|
|||
|
" return loss"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"id": "29deee2c-5771-498a-b01b-fde5e0f387ba",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):\n",
|
|||
|
" best_model = model\n",
|
|||
|
" best_loss = 100\n",
|
|||
|
" model.to(device)\n",
|
|||
|
" for epoch in range(epochs):\n",
|
|||
|
" model.train()\n",
|
|||
|
" train_loss = 0\n",
|
|||
|
" for data in train_loader:\n",
|
|||
|
" data = data.to(device)\n",
|
|||
|
" optimizer.zero_grad()\n",
|
|||
|
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
|
|||
|
" output = model(masked_data)\n",
|
|||
|
" loss = masked_mse_loss(output, data, mask)\n",
|
|||
|
" loss.backward()\n",
|
|||
|
" optimizer.step()\n",
|
|||
|
" train_loss += loss.item()\n",
|
|||
|
" train_loss /= len(train_loader)\n",
|
|||
|
"\n",
|
|||
|
" model.eval()\n",
|
|||
|
" val_loss = 0\n",
|
|||
|
" with torch.no_grad():\n",
|
|||
|
" for data in val_loader:\n",
|
|||
|
" data = data.to(device)\n",
|
|||
|
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
|
|||
|
" output = model(masked_data)\n",
|
|||
|
" loss = masked_mse_loss(output, data, mask)\n",
|
|||
|
" val_loss += loss.item()\n",
|
|||
|
" val_loss /= len(val_loader)\n",
|
|||
|
" if val_loss < best_loss:\n",
|
|||
|
" best_loss = val_loss\n",
|
|||
|
" best_model = model\n",
|
|||
|
"\n",
|
|||
|
" print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"id": "bb524f86-aa7d-44ee-b13e-b9ba4e5b3a0b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"train_dir = './out_mat/96/train/'\n",
|
|||
|
"train_dataset = GrayScaleDataset(train_dir)\n",
|
|||
|
"\n",
|
|||
|
"val_dir = './out_mat/96/valid/'\n",
|
|||
|
"val_dataset = GrayScaleDataset(val_dir)\n",
|
|||
|
"\n",
|
|||
|
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)\n",
|
|||
|
"val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"id": "0d7e2f83-c113-4c62-91eb-d4ea3192530c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import matplotlib.pyplot as plt"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 43,
|
|||
|
"id": "7d6d07a4-31f1-4350-a487-b583db979381",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"encoder = ViTEncoder()\n",
|
|||
|
"decoder = ConvDecoder()\n",
|
|||
|
"model = MAEModel(encoder, decoder)\n",
|
|||
|
"criterion = nn.MSELoss()\n",
|
|||
|
"optimizer = optim.Adam(model.parameters(), lr=1e-4)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "09b04e16-3257-4890-b736-a6c7274561e0",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Epoch 1, Train Loss: 0.9251, Val Loss: 0.0869\n",
|
|||
|
"Epoch 2, Train Loss: 0.0734, Val Loss: 0.0506\n",
|
|||
|
"Epoch 3, Train Loss: 0.0494, Val Loss: 0.0489\n",
|
|||
|
"Epoch 4, Train Loss: 0.0432, Val Loss: 0.0462\n",
|
|||
|
"Epoch 5, Train Loss: 0.0390, Val Loss: 0.0400\n",
|
|||
|
"Epoch 6, Train Loss: 0.0351, Val Loss: 0.0356\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"train_model(model, train_loader, val_loader, epochs=50, criterion=criterion, optimizer=optimizer, device=device)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"id": "b0c5cf4b-aca2-4781-8b47-bf2a46269635",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"test_set = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')\n",
|
|||
|
"test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"id": "56653f37-899a-47d6-8d50-e456b4ad1835",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 20,
|
|||
|
"id": "2504300a-ac91-453a-9bfb-ab89f56d4ff6",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def cal_ioa(y_true, y_pred):\n",
|
|||
|
" # 计算平均值\n",
|
|||
|
" mean_observed = np.mean(y_true)\n",
|
|||
|
" mean_predicted = np.mean(y_pred)\n",
|
|||
|
"\n",
|
|||
|
" # 计算IoA\n",
|
|||
|
" numerator = np.sum((y_true - y_pred) ** 2)\n",
|
|||
|
" denominator = 2 * np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
|
|||
|
" IoA = 1 - (numerator / denominator)\n",
|
|||
|
"\n",
|
|||
|
" return IoA"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 24,
|
|||
|
"id": "d55be844-0873-4d9a-8160-22603de32a81",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"test_set2 = GrayScaleDataset('./out_mat/96/test/')\n",
|
|||
|
"test_loader2 = DataLoader(test_set2, batch_size=64, shuffle=False, num_workers=4)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 34,
|
|||
|
"id": "f83dbc37-8543-45bc-ba59-ca88d4ba2a66",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"torch.Size([64, 96, 96])"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 34,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"rev_data.shape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 35,
|
|||
|
"id": "2655f7f4-9d88-49fd-9346-87a621320183",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"eva_list = list()\n",
|
|||
|
"device = 'cpu'\n",
|
|||
|
"model = model.to(device)\n",
|
|||
|
"with torch.no_grad():\n",
|
|||
|
" for data in test_loader2:\n",
|
|||
|
" data = data.to(device)\n",
|
|||
|
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
|
|||
|
" output = model(masked_data)\n",
|
|||
|
" rev_data = data * MAX_VALUE\n",
|
|||
|
" rev_recon = output * MAX_VALUE\n",
|
|||
|
" data_label = rev_data * mask\n",
|
|||
|
" data_label = data_label[mask==1]\n",
|
|||
|
" recon_no2 = rev_recon * mask\n",
|
|||
|
" recon_no2 = recon_no2[mask==1]\n",
|
|||
|
" y_true = rev_data.flatten()\n",
|
|||
|
" y_pred = rev_recon.flatten()\n",
|
|||
|
" mae = mean_absolute_error(y_true, y_pred)\n",
|
|||
|
" rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
|
|||
|
" mape = mean_absolute_percentage_error(y_true, y_pred)\n",
|
|||
|
" r2 = r2_score(y_true, y_pred)\n",
|
|||
|
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
|
|||
|
" eva_list.append([mae, rmse, mape, r2, ioa])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 36,
|
|||
|
"id": "0327b51c-d714-4fe0-a044-d8be3ff180e0",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" <td>75.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>1.208013</td>\n",
|
|||
|
" <td>1.600644</td>\n",
|
|||
|
" <td>0.142720</td>\n",
|
|||
|
" <td>0.941983</td>\n",
|
|||
|
" <td>0.981683</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>0.056235</td>\n",
|
|||
|
" <td>0.081791</td>\n",
|
|||
|
" <td>0.003435</td>\n",
|
|||
|
" <td>0.004449</td>\n",
|
|||
|
" <td>0.002309</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>1.091517</td>\n",
|
|||
|
" <td>1.446389</td>\n",
|
|||
|
" <td>0.134849</td>\n",
|
|||
|
" <td>0.911833</td>\n",
|
|||
|
" <td>0.965708</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>1.170305</td>\n",
|
|||
|
" <td>1.555051</td>\n",
|
|||
|
" <td>0.140519</td>\n",
|
|||
|
" <td>0.940425</td>\n",
|
|||
|
" <td>0.981100</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>1.204728</td>\n",
|
|||
|
" <td>1.593261</td>\n",
|
|||
|
" <td>0.142981</td>\n",
|
|||
|
" <td>0.942651</td>\n",
|
|||
|
" <td>0.982003</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>1.242762</td>\n",
|
|||
|
" <td>1.646311</td>\n",
|
|||
|
" <td>0.145185</td>\n",
|
|||
|
" <td>0.944118</td>\n",
|
|||
|
" <td>0.982809</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>1.420721</td>\n",
|
|||
|
" <td>2.037903</td>\n",
|
|||
|
" <td>0.150566</td>\n",
|
|||
|
" <td>0.949663</td>\n",
|
|||
|
" <td>0.984610</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa\n",
|
|||
|
"count 75.000000 75.000000 75.000000 75.000000 75.000000\n",
|
|||
|
"mean 1.208013 1.600644 0.142720 0.941983 0.981683\n",
|
|||
|
"std 0.056235 0.081791 0.003435 0.004449 0.002309\n",
|
|||
|
"min 1.091517 1.446389 0.134849 0.911833 0.965708\n",
|
|||
|
"25% 1.170305 1.555051 0.140519 0.940425 0.981100\n",
|
|||
|
"50% 1.204728 1.593261 0.142981 0.942651 0.982003\n",
|
|||
|
"75% 1.242762 1.646311 0.145185 0.944118 0.982809\n",
|
|||
|
"max 1.420721 2.037903 0.150566 0.949663 0.984610"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 36,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 38,
|
|||
|
"id": "b8669b7e-6974-418a-87fc-074734f9a1a3",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"eva_list_frame = list()\n",
|
|||
|
"device = 'cpu'\n",
|
|||
|
"model = model.to(device)\n",
|
|||
|
"with torch.no_grad():\n",
|
|||
|
" for data in test_loader2:\n",
|
|||
|
" data = data.to(device)\n",
|
|||
|
" masked_data, mask = PatchMasking(patch_size=8, mask_ratio=0.2)(data)\n",
|
|||
|
" output = model(masked_data)\n",
|
|||
|
" rev_data = data * MAX_VALUE\n",
|
|||
|
" rev_recon = output * MAX_VALUE\n",
|
|||
|
" # todo: 这里需要只评估修补出来的模块\n",
|
|||
|
" for i, sample in enumerate(rev_data):\n",
|
|||
|
" used_mask = mask[i]\n",
|
|||
|
" data_label = sample[0] * used_mask\n",
|
|||
|
" recon_no2 = rev_recon[i][0] * used_mask\n",
|
|||
|
" data_label = data_label[used_mask==1]\n",
|
|||
|
" recon_no2 = recon_no2[used_mask==1]\n",
|
|||
|
" mae = mean_absolute_error(data_label, recon_no2)\n",
|
|||
|
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
|
|||
|
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
|
|||
|
" r2 = r2_score(data_label, recon_no2)\n",
|
|||
|
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
|
|||
|
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
|
|||
|
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 39,
|
|||
|
"id": "c9079fc5-6ab3-465e-9067-6cad8f69c5a8",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" <th>r</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>1.714315</td>\n",
|
|||
|
" <td>2.350189</td>\n",
|
|||
|
" <td>0.215974</td>\n",
|
|||
|
" <td>0.609470</td>\n",
|
|||
|
" <td>0.943560</td>\n",
|
|||
|
" <td>0.823401</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>0.697344</td>\n",
|
|||
|
" <td>0.940345</td>\n",
|
|||
|
" <td>0.077893</td>\n",
|
|||
|
" <td>0.131496</td>\n",
|
|||
|
" <td>0.022261</td>\n",
|
|||
|
" <td>0.069394</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>0.636049</td>\n",
|
|||
|
" <td>0.821723</td>\n",
|
|||
|
" <td>0.099999</td>\n",
|
|||
|
" <td>0.003194</td>\n",
|
|||
|
" <td>0.802237</td>\n",
|
|||
|
" <td>0.405363</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>1.121617</td>\n",
|
|||
|
" <td>1.576669</td>\n",
|
|||
|
" <td>0.170974</td>\n",
|
|||
|
" <td>0.533081</td>\n",
|
|||
|
" <td>0.931653</td>\n",
|
|||
|
" <td>0.783616</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>1.459720</td>\n",
|
|||
|
" <td>2.132316</td>\n",
|
|||
|
" <td>0.199419</td>\n",
|
|||
|
" <td>0.623769</td>\n",
|
|||
|
" <td>0.946952</td>\n",
|
|||
|
" <td>0.831403</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>2.334761</td>\n",
|
|||
|
" <td>3.119393</td>\n",
|
|||
|
" <td>0.234517</td>\n",
|
|||
|
" <td>0.698517</td>\n",
|
|||
|
" <td>0.958943</td>\n",
|
|||
|
" <td>0.872422</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>4.406258</td>\n",
|
|||
|
" <td>8.470109</td>\n",
|
|||
|
" <td>1.242636</td>\n",
|
|||
|
" <td>0.895199</td>\n",
|
|||
|
" <td>0.986901</td>\n",
|
|||
|
" <td>0.965110</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa \\\n",
|
|||
|
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
|
|||
|
"mean 1.714315 2.350189 0.215974 0.609470 0.943560 \n",
|
|||
|
"std 0.697344 0.940345 0.077893 0.131496 0.022261 \n",
|
|||
|
"min 0.636049 0.821723 0.099999 0.003194 0.802237 \n",
|
|||
|
"25% 1.121617 1.576669 0.170974 0.533081 0.931653 \n",
|
|||
|
"50% 1.459720 2.132316 0.199419 0.623769 0.946952 \n",
|
|||
|
"75% 2.334761 3.119393 0.234517 0.698517 0.958943 \n",
|
|||
|
"max 4.406258 8.470109 1.242636 0.895199 0.986901 \n",
|
|||
|
"\n",
|
|||
|
" r \n",
|
|||
|
"count 4739.000000 \n",
|
|||
|
"mean 0.823401 \n",
|
|||
|
"std 0.069394 \n",
|
|||
|
"min 0.405363 \n",
|
|||
|
"25% 0.783616 \n",
|
|||
|
"50% 0.831403 \n",
|
|||
|
"75% 0.872422 \n",
|
|||
|
"max 0.965110 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 39,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"eva_frame_df = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape')\n",
|
|||
|
"eva_frame_df.describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "e840b789-bf68-4b4d-a8d3-c5362c310349",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"ename": "ValueError",
|
|||
|
"evalue": "too many values to unpack (expected 3)",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
|||
|
"Cell \u001b[0;32mIn[25], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_idx, (X, y, mask) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(test_loader2):\n\u001b[1;32m 6\u001b[0m X, y, mask \u001b[38;5;241m=\u001b[39m X\u001b[38;5;241m.\u001b[39mto(device), y\u001b[38;5;241m.\u001b[39mto(device), mask\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 7\u001b[0m mask_rev \u001b[38;5;241m=\u001b[39m (torch\u001b[38;5;241m.\u001b[39msqueeze(mask, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m==\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;66;03m# mask取反获得修复区域\u001b[39;00m\n",
|
|||
|
"\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 3)"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"eva_list = list()\n",
|
|||
|
"device = 'cpu'\n",
|
|||
|
"model = model.to(device)\n",
|
|||
|
"with torch.no_grad():\n",
|
|||
|
" for batch_idx, (X, y, mask) in enumerate(test_loader2):\n",
|
|||
|
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
|||
|
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
|
|||
|
" reconstructed = model(X)\n",
|
|||
|
" rev_data = torch.squeeze(y * MAX_VALUE, dim=1)\n",
|
|||
|
" rev_recon = torch.squeeze(reconstructed * MAX_VALUE, dim=1)\n",
|
|||
|
" # todo: 这里需要只评估修补出来的模块\n",
|
|||
|
" data_label = rev_data * mask_rev\n",
|
|||
|
" data_label = data_label[mask_rev==1]\n",
|
|||
|
" recon_no2 = rev_recon * mask_rev\n",
|
|||
|
" recon_no2 = recon_no2[mask_rev==1]\n",
|
|||
|
" y_true = rev_data.flatten()\n",
|
|||
|
" y_pred = rev_recon.flatten()\n",
|
|||
|
" mae = mean_absolute_error(y_true, y_pred)\n",
|
|||
|
" rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
|
|||
|
" mape = mean_absolute_percentage_error(y_true, y_pred)\n",
|
|||
|
" r2 = r2_score(y_true, y_pred)\n",
|
|||
|
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
|
|||
|
" eva_list.append([mae, rmse, mape, r2, ioa])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"id": "41fa754d-1eee-43a2-9e39-a0254719be30",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>149.000000</td>\n",
|
|||
|
" <td>149.000000</td>\n",
|
|||
|
" <td>149.000000</td>\n",
|
|||
|
" <td>149.000000</td>\n",
|
|||
|
" <td>149.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>2.235662</td>\n",
|
|||
|
" <td>4.042349</td>\n",
|
|||
|
" <td>0.238494</td>\n",
|
|||
|
" <td>0.626060</td>\n",
|
|||
|
" <td>0.572341</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>0.192709</td>\n",
|
|||
|
" <td>0.357475</td>\n",
|
|||
|
" <td>0.007405</td>\n",
|
|||
|
" <td>0.042890</td>\n",
|
|||
|
" <td>0.042652</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>1.786567</td>\n",
|
|||
|
" <td>3.167143</td>\n",
|
|||
|
" <td>0.224796</td>\n",
|
|||
|
" <td>0.522157</td>\n",
|
|||
|
" <td>0.460707</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>2.084117</td>\n",
|
|||
|
" <td>3.779276</td>\n",
|
|||
|
" <td>0.232974</td>\n",
|
|||
|
" <td>0.597774</td>\n",
|
|||
|
" <td>0.547144</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>2.226062</td>\n",
|
|||
|
" <td>4.075465</td>\n",
|
|||
|
" <td>0.237429</td>\n",
|
|||
|
" <td>0.627588</td>\n",
|
|||
|
" <td>0.570579</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>2.361411</td>\n",
|
|||
|
" <td>4.284523</td>\n",
|
|||
|
" <td>0.243866</td>\n",
|
|||
|
" <td>0.656226</td>\n",
|
|||
|
" <td>0.601233</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>2.751377</td>\n",
|
|||
|
" <td>4.917407</td>\n",
|
|||
|
" <td>0.258230</td>\n",
|
|||
|
" <td>0.740943</td>\n",
|
|||
|
" <td>0.666083</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa\n",
|
|||
|
"count 149.000000 149.000000 149.000000 149.000000 149.000000\n",
|
|||
|
"mean 2.235662 4.042349 0.238494 0.626060 0.572341\n",
|
|||
|
"std 0.192709 0.357475 0.007405 0.042890 0.042652\n",
|
|||
|
"min 1.786567 3.167143 0.224796 0.522157 0.460707\n",
|
|||
|
"25% 2.084117 3.779276 0.232974 0.597774 0.547144\n",
|
|||
|
"50% 2.226062 4.075465 0.237429 0.627588 0.570579\n",
|
|||
|
"75% 2.361411 4.284523 0.243866 0.656226 0.601233\n",
|
|||
|
"max 2.751377 4.917407 0.258230 0.740943 0.666083"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"id": "46cffa4a-37bc-4e13-9723-fc6cb244c95c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"eva_list_frame = list()\n",
|
|||
|
"device = 'cpu'\n",
|
|||
|
"model = model.to(device)\n",
|
|||
|
"best_mape = 1\n",
|
|||
|
"best_img = None\n",
|
|||
|
"best_mask = None\n",
|
|||
|
"best_recov = None\n",
|
|||
|
"with torch.no_grad():\n",
|
|||
|
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
|
|||
|
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
|||
|
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
|
|||
|
" reconstructed = model(X)\n",
|
|||
|
" rev_data = y * MAX_VALUE\n",
|
|||
|
" rev_recon = reconstructed * MAX_VALUE\n",
|
|||
|
" # todo: 这里需要只评估修补出来的模块\n",
|
|||
|
" for i, sample in enumerate(rev_data):\n",
|
|||
|
" used_mask = mask_rev[i]\n",
|
|||
|
" data_label = sample[0] * used_mask\n",
|
|||
|
" recon_no2 = rev_recon[i][0] * used_mask\n",
|
|||
|
" data_label = data_label[used_mask==1]\n",
|
|||
|
" recon_no2 = recon_no2[used_mask==1]\n",
|
|||
|
" mae = mean_absolute_error(data_label, recon_no2)\n",
|
|||
|
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
|
|||
|
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
|
|||
|
" r2 = r2_score(data_label, recon_no2)\n",
|
|||
|
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
|
|||
|
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
|
|||
|
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "b6f6a897-2f48-4958-8725-f566430c61e1",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"eva_frame_df = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"id": "6d1920e7-b92f-414e-8273-0b4666587904",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" <th>r</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" <td>4739.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>5.920017</td>\n",
|
|||
|
" <td>6.864245</td>\n",
|
|||
|
" <td>0.603656</td>\n",
|
|||
|
" <td>-2.743017</td>\n",
|
|||
|
" <td>0.228580</td>\n",
|
|||
|
" <td>0.225978</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>3.534648</td>\n",
|
|||
|
" <td>3.845034</td>\n",
|
|||
|
" <td>0.224679</td>\n",
|
|||
|
" <td>2.049753</td>\n",
|
|||
|
" <td>0.370622</td>\n",
|
|||
|
" <td>0.227965</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>1.477380</td>\n",
|
|||
|
" <td>1.849392</td>\n",
|
|||
|
" <td>0.271934</td>\n",
|
|||
|
" <td>-22.827546</td>\n",
|
|||
|
" <td>-1.899284</td>\n",
|
|||
|
" <td>-0.626938</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>2.975700</td>\n",
|
|||
|
" <td>3.600521</td>\n",
|
|||
|
" <td>0.502338</td>\n",
|
|||
|
" <td>-3.631702</td>\n",
|
|||
|
" <td>0.042875</td>\n",
|
|||
|
" <td>0.088760</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>4.169098</td>\n",
|
|||
|
" <td>5.055890</td>\n",
|
|||
|
" <td>0.558942</td>\n",
|
|||
|
" <td>-2.233530</td>\n",
|
|||
|
" <td>0.309592</td>\n",
|
|||
|
" <td>0.253954</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>8.616798</td>\n",
|
|||
|
" <td>9.809069</td>\n",
|
|||
|
" <td>0.632651</td>\n",
|
|||
|
" <td>-1.287602</td>\n",
|
|||
|
" <td>0.509937</td>\n",
|
|||
|
" <td>0.389390</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>18.840775</td>\n",
|
|||
|
" <td>20.371025</td>\n",
|
|||
|
" <td>3.689853</td>\n",
|
|||
|
" <td>0.024294</td>\n",
|
|||
|
" <td>0.835339</td>\n",
|
|||
|
" <td>0.782481</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa \\\n",
|
|||
|
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
|
|||
|
"mean 5.920017 6.864245 0.603656 -2.743017 0.228580 \n",
|
|||
|
"std 3.534648 3.845034 0.224679 2.049753 0.370622 \n",
|
|||
|
"min 1.477380 1.849392 0.271934 -22.827546 -1.899284 \n",
|
|||
|
"25% 2.975700 3.600521 0.502338 -3.631702 0.042875 \n",
|
|||
|
"50% 4.169098 5.055890 0.558942 -2.233530 0.309592 \n",
|
|||
|
"75% 8.616798 9.809069 0.632651 -1.287602 0.509937 \n",
|
|||
|
"max 18.840775 20.371025 3.689853 0.024294 0.835339 \n",
|
|||
|
"\n",
|
|||
|
" r \n",
|
|||
|
"count 4739.000000 \n",
|
|||
|
"mean 0.225978 \n",
|
|||
|
"std 0.227965 \n",
|
|||
|
"min -0.626938 \n",
|
|||
|
"25% 0.088760 \n",
|
|||
|
"50% 0.253954 \n",
|
|||
|
"75% 0.389390 \n",
|
|||
|
"max 0.782481 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 28,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"eva_frame_df.describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"id": "d1696b4f-1520-4201-b855-63f517022ec3",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 可视化特定特征的函数\n",
|
|||
|
"def visualize_rst(input_feature,masked_feature, recov_region, output_feature, title):\n",
|
|||
|
" plt.figure(figsize=(12, 6))\n",
|
|||
|
" plt.subplot(1, 4, 1)\n",
|
|||
|
" plt.imshow(input_feature, cmap='RdYlGn_r')\n",
|
|||
|
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
|
|||
|
" plt.subplot(1, 4, 2)\n",
|
|||
|
" plt.imshow(masked_feature, cmap='gray')\n",
|
|||
|
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
|
|||
|
" plt.subplot(1, 4, 3)\n",
|
|||
|
" plt.imshow(recov_region, cmap='RdYlGn_r')\n",
|
|||
|
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
|
|||
|
" plt.subplot(1, 4, 4)\n",
|
|||
|
" plt.imshow(output_feature, cmap='RdYlGn_r')\n",
|
|||
|
" plt.gca().axis('off') # 获取当前坐标轴并关闭\n",
|
|||
|
" plt.savefig('./figures/result/vitmae_20_samples.png', bbox_inches='tight')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 92,
|
|||
|
"id": "56291a37-cc49-428f-a8db-99bdd7a1f062",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"torch.save(model, './models/MAE/vit.pt')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 41,
|
|||
|
"id": "bb8eccaa-7409-4cce-9119-70aed5ee496e",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'1114', '1952', '2568', '3523', '602'}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 41,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])\n",
|
|||
|
"find_ex"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 42,
|
|||
|
"id": "b9c6cdba-e563-42e2-885d-d1df320dac02",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 0 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"for j in find_ex:\n",
|
|||
|
" ori = np.load(f'./test_img/{j}-real.npy')[0]\n",
|
|||
|
" mask = np.load(f'./test_img/{j}-mask.npy')\n",
|
|||
|
" mask_rev = 1 - mask\n",
|
|||
|
" img_in = ori * mask_rev / MAX_VALUE\n",
|
|||
|
" img_out = model(torch.tensor(img_in.reshape(1, 1, 96, 96), dtype=torch.float32)).detach().cpu().numpy()[0][0] * MAX_VALUE\n",
|
|||
|
" out = ori * mask_rev + img_out * mask\n",
|
|||
|
" plt.imshow(out, cmap='RdYlGn_r')\n",
|
|||
|
" plt.gca().axis('off')\n",
|
|||
|
" plt.savefig(f'./test_img/out_fig/{j}-mae_vit_out.png', bbox_inches='tight')\n",
|
|||
|
" plt.clf()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "e3f24f2a-bc47-409d-8e46-bfa62851701b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3 (ipykernel)",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.16"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 5
|
|||
|
}
|