MAE_ATMO/torch_MAE_1d_encoder.ipynb

983 lines
72 KiB
Plaintext
Raw Permalink Normal View History

2024-11-21 14:02:33 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"from PIL import Image\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "15b9ced8-7282-4f97-a079-f31bf9405145",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fc9d487f810>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(0)\n",
"torch.random.manual_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7f83e6c7-8207-41b3-908b-6b1fad78ecd5",
"metadata": {},
"outputs": [],
"source": [
"max_pixel_value = 107.49169921875"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c66f2b9f-fcad-4237-abb2-d7f918d74116",
"metadata": {},
"outputs": [],
"source": [
"class NO2Dataset(Dataset):\n",
"    \"\"\"Pairs NO2 .npy frames with randomly drawn binary .jpg masks.\"\"\"\n",
"\n",
"    def __init__(self, image_dir, mask_dir):\n",
"\n",
"        self.image_dir = image_dir\n",
"        self.mask_dir = mask_dir\n",
"        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')]  # load only .npy files\n",
"        self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')]  # load only .jpg files\n",
"\n",
"    def __len__(self):\n",
"\n",
"        return len(self.image_filenames)\n",
"\n",
"    def __getitem__(self, idx):\n",
"\n",
"        image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
"        # NOTE(review): np.random is not reseeded per DataLoader worker by default,\n",
"        # so forked workers may draw identical mask sequences -- confirm.\n",
"        mask_idx = np.random.choice(self.mask_filenames)\n",
"        mask_path = os.path.join(self.mask_dir, mask_idx)\n",
"\n",
"        # Load image data (.npy), keep the first channel, scale to [0, 1]; shape (96, 96, 1)\n",
"        image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value\n",
"\n",
"        # Load mask data (.jpg) as grayscale\n",
"        mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
"        # Binarize: non-zero values become 1, zero values stay 0\n",
"        mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
"        # Keep the mask shape as (96, 96, 1)\n",
"        mask = mask[:, :, np.newaxis]\n",
"\n",
"        # Apply the mask to hide part of the NO2 data\n",
"        masked_image = image.copy()\n",
"        masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze()\n",
"\n",
"        # Model input and target\n",
"        X = masked_image[:, :, :1]  # shape (96, 96, 1)\n",
"        y = image[:, :, 0:1]  # target NO2 data, shape (96, 96, 1)\n",
"\n",
"        # Convert to (channels, height, width)\n",
"        X = np.transpose(X, (2, 0, 1))  # (1, 96, 96)\n",
"        y = np.transpose(y, (2, 0, 1))  # (1, 96, 96)\n",
"        mask = np.transpose(mask, (2, 0, 1))  # (1, 96, 96)\n",
"\n",
"        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# Dataset directories\n",
"image_dir = './out_mat/96/train/'\n",
"mask_dir = './out_mat/96/mask/20/'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e3354304-f6de-44bf-adbf-bbff557a8c93",
"metadata": {},
"outputs": [],
"source": [
"train_set = NO2Dataset(image_dir, mask_dir)\n",
"train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)\n",
"val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n",
"val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
"test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Input\")\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Masked\")\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Recovery\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "645114e8-65a4-4867-b3fe-23395288e855",
"metadata": {},
"outputs": [],
"source": [
"class Conv(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
" super(Conv, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
"metadata": {},
"outputs": [],
"source": [
"class ConvBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
" bias=False):\n",
" super(ConvBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
" norm_layer(out_channels),\n",
" nn.ReLU()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "31ecf247-e98b-4977-a145-782914a042bd",
"metadata": {},
"outputs": [],
"source": [
"class SeparableBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
" super(SeparableBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
" padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
" # 分离卷积,仅调整空间信息\n",
" norm_layer(in_channels), # 对输入通道进行归一化\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), # 这里进行升维操作\n",
" nn.ReLU6()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
"metadata": {},
"outputs": [],
"source": [
"class ResidualBlock(nn.Module):\n",
" def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
" super(ResidualBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(out_channels)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(out_channels)\n",
"\n",
" # 如果输入和输出通道不一致,进行降采样操作\n",
" self.downsample = downsample\n",
" if in_channels != out_channels or stride != 1:\n",
" self.downsample = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
" nn.BatchNorm2d(out_channels)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" identity = x\n",
" if self.downsample is not None:\n",
" identity = self.downsample(x)\n",
"\n",
" out = self.conv1(x)\n",
" out = self.bn1(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv2(out)\n",
" out = self.bn2(out)\n",
"\n",
" out += identity\n",
" out = self.relu(out)\n",
" return out\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
"metadata": {},
"outputs": [],
"source": [
"class Mlp(nn.Module):\n",
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
" super().__init__()\n",
" out_features = out_features or in_features\n",
" hidden_features = hidden_features or in_features\n",
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
"\n",
" self.act = act_layer()\n",
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
" self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = self.act(x)\n",
" x = self.drop(x)\n",
" x = self.fc2(x)\n",
" x = self.drop(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
"metadata": {},
"outputs": [],
"source": [
"class MultiHeadAttentionBlock(nn.Module):\n",
" def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
" super(MultiHeadAttentionBlock, self).__init__()\n",
" self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
" self.norm = nn.LayerNorm(embed_dim)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x):\n",
" # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
" B, C, H, W = x.shape\n",
" x = x.view(B, C, H * W).permute(2, 0, 1) # (B, C, H, W) -> (HW, B, C)\n",
"\n",
" # Apply multihead attention\n",
" attn_output, _ = self.attention(x, x, x)\n",
"\n",
" # Apply normalization and dropout\n",
" attn_output = self.norm(attn_output)\n",
" attn_output = self.dropout(attn_output)\n",
"\n",
" # Reshape back to (B, C, H, W)\n",
" attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
"\n",
" return attn_output"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
"metadata": {},
"outputs": [],
"source": [
"class SpatialAttentionBlock(nn.Module):\n",
" def __init__(self):\n",
" super(SpatialAttentionBlock, self).__init__()\n",
" self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
"\n",
" def forward(self, x): #(B, 64, H, W)\n",
" avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)\n",
" max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)\n",
" out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)\n",
" out = torch.sigmoid(self.conv(out))#(B, 1, H, W)\n",
" return x * out #(B, C, H, W)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
"metadata": {},
"outputs": [],
"source": [
"class DecoderAttentionBlock(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super(DecoderAttentionBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
" self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
" self.spatial_attention = SpatialAttentionBlock()\n",
"\n",
" def forward(self, x):\n",
" # 通道注意力\n",
" b, c, h, w = x.size()\n",
" avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
" max_pool = F.adaptive_max_pool2d(x, 1)\n",
"\n",
" avg_out = self.conv1(avg_pool)\n",
" max_out = self.conv1(max_pool)\n",
"\n",
" out = avg_out + max_out\n",
" out = torch.sigmoid(self.conv2(out))\n",
"\n",
" # 添加空间注意力\n",
" out = x * out\n",
" out = self.spatial_attention(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
" def __init__(self, in_channels, reduced_dim):\n",
" super(SEBlock, self).__init__()\n",
" self.se = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return x * self.se(x)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "a382ed1b-cc88-4f03-95c2-843981ee81f1",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
"    # Mean squared error computed only over pixels where mask == 1.\n",
"    # preds/target/mask share shape (B, 1, H, W), so the per-pixel squared\n",
"    # error is multiplied by the mask directly. The previous mean(dim=-1)\n",
"    # collapsed the width axis before multiplying by the 4-D mask, which\n",
"    # broadcast to (B, B, H, W) and mixed error terms across samples.\n",
"    loss = (preds - target) ** 2\n",
"    loss = (loss * mask).sum() / mask.sum()  # average over masked pixels only\n",
"    return loss"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 定义Masked Autoencoder模型\n",
"class MaskedAutoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(MaskedAutoencoder, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" Conv(1, 32, kernel_size=3, stride=2),\n",
" nn.ReLU(),\n",
" SEBlock(32,32),\n",
" ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
" ResidualBlock(64,64),\n",
" SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
" MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
" SEBlock(128, 128)\n",
" )\n",
" self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" # DecoderAttentionBlock(32),\n",
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" # DecoderAttentionBlock(16),\n",
" nn.ReLU(),\n",
" \n",
" nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded\n",
"\n",
"# 实例化模型、损失函数和优化器\n",
"model = MaskedAutoencoder()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
"metadata": {},
"outputs": [],
"source": [
"# 训练函数\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
" model.train()\n",
" running_loss = 0.0\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" optimizer.zero_grad()\n",
" reconstructed = model(X)\n",
" # loss = criterion(reconstructed, y)\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
"metadata": {},
"outputs": [],
"source": [
"# 评估函数\n",
"def evaluate(model, device, data_loader, criterion):\n",
" model.eval()\n",
" running_loss = 0.0\n",
" with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" reconstructed = model(X)\n",
" if batch_idx == 8:\n",
" rand_ind = np.random.randint(0, len(y))\n",
" # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
" # loss = criterion(reconstructed, y)\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "6094b6c8-8211-4557-9944-7eef977ea9ec",
"metadata": {},
"outputs": [],
"source": [
"def masked_mae_loss(preds, target, mask):\n",
"    # Mean absolute error computed only over pixels where mask == 1.\n",
"    # The previous version squared the residual despite the 'mae' name\n",
"    # (making it a duplicate of masked_mse_loss) and reduced the width\n",
"    # axis before masking, which broadcast against the 4-D mask.\n",
"    loss = (preds - target).abs()\n",
"    loss = (loss * mask).sum() / mask.sum()  # average over masked pixels only\n",
"    return loss"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "781f558e-d41c-4721-94fd-564cd6c2b347",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# 数据准备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "743d1000-561e-4444-8b49-88346c14f28b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Train Loss: 0.013549078723781131, Val Loss: 0.014539383435204847\n",
"Epoch 2, Train Loss: 0.013641111095966192, Val Loss: 0.014635173200782555\n",
"Epoch 3, Train Loss: 0.013503858572290988, Val Loss: 0.01476309893291388\n",
"Epoch 4, Train Loss: 0.013455510417970887, Val Loss: 0.014315864057349624\n",
"Epoch 5, Train Loss: 0.01339626228704193, Val Loss: 0.01442837900023407\n",
"Epoch 6, Train Loss: 0.013295360569035608, Val Loss: 0.015184532503472336\n",
"Epoch 12, Train Loss: 0.012901031857793125, Val Loss: 0.013935101566030018\n",
"Epoch 13, Train Loss: 0.01295265725158761, Val Loss: 0.013862666924164366\n",
"Epoch 14, Train Loss: 0.013010161795149865, Val Loss: 0.013880979492148357\n",
"Epoch 15, Train Loss: 0.012936625905940977, Val Loss: 0.013813913021403463\n",
"Epoch 16, Train Loss: 0.01287072714926167, Val Loss: 0.01403502803017844\n",
"Epoch 17, Train Loss: 0.012832806871214695, Val Loss: 0.014388528165977393\n",
"Epoch 18, Train Loss: 0.012794200125992583, Val Loss: 0.01383661480147892\n",
"Epoch 19, Train Loss: 0.01294981115208003, Val Loss: 0.01408140508652623\n",
"Epoch 20, Train Loss: 0.012662894464583631, Val Loss: 0.01359965718949019\n",
"Test Loss: 0.007365767304242279\n"
]
}
],
"source": [
"model = model.to(device)\n",
"\n",
"num_epochs = 20\n",
"train_losses = list()\n",
"val_losses = list()\n",
"for epoch in range(num_epochs):\n",
" train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
" train_losses.append(train_loss)\n",
" val_loss = evaluate(model, device, val_loader, criterion)\n",
" val_losses.append(val_loss)\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# 测试模型\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fc8e0717100>"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj0AAAGdCAYAAAD5ZcJyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABwtUlEQVR4nO3dd3gU9f728femE0gBQhJ6J7SQUEMHBaki2EBEAftRrFj5PSqWo1g5qKCo5wh2sACiIIg0EVCkhE7ohJJCTSV9nj+GbBJIQjZtN8n9uq5cGXZnZz7LGnPzrRbDMAxEREREKjknexcgIiIiUh4UekRERKRKUOgRERGRKkGhR0RERKoEhR4RERGpEhR6REREpEpQ6BEREZEqQaFHREREqgQXexfgSLKysjh16hReXl5YLBZ7lyMiIiJFYBgGCQkJ1KtXDyengttzFHpyOXXqFA0bNrR3GSIiIlIMx48fp0GDBgU+r9CTi5eXF2D+pXl7e9u5GhERESmK+Ph4GjZsaP09XhCFnlyyu7S8vb0VekRERCqYqw1N0UBmERERqRIUekRERKRKUOgRERGRKkFjekREpFIzDIOMjAwyMzPtXYoUk7OzMy4uLiVeTkahR0REKq20tDSioqJITk62dylSQp6entStWxc3N7diX0OhR0REKqWsrCyOHDmCs7Mz9erVw83NTQvPVkCGYZCWlsbp06c5cuQILVu2LHQBwsIo9IiISKWUlpZGVlYWDRs2xNPT097lSAlUq1YNV1dXjh07RlpaGh4eHsW6jgYyi4hIpVbcVgFxLKXxOeq/BBEREakSFHpERESkSlDoERERqcSaNGnCjBkzSuVaa9aswWKxcOHChVK5XnnTQGYREREH079/f0JDQ0slrPzzzz9Ur1695EVVAmrpESmJE1tg29f2rkJEqpjsBReLok6dOpq9dolCj0hxGQZ8Nx5+eghObLZ3NSJyFYZhkJyWYZcvwzCKXOfEiRNZu3Yt7733HhaLBYvFwty5c7FYLPz666907twZd3d3/vzzTw4dOsTIkSMJCAigRo0adO3ald9//z3P9S7v3rJYLPz3v//lxhtvxNPTk5YtW7J48eJi/73++OOPtGvXDnd3d5o0acK7776b5/kPP/yQli1b4uHhQUBAALfccov1uR9++IHg4GCqVatG7dq1GThwIElJScWu5WrUvSVSXOePQPwJ8zhqOzToYt96RKRQF9Mzafvicrvce88rg/F0K9qv3Pfee4/9+/fTvn17XnnlFQB2794NwHPPPcc777xDs2bNqFmzJsePH2fYsGG89tpruLu788UXXzBixAgiIiJo1KhRgfd4+eWXeeutt3j77bf54IMPGDduHMeOHaNWrVo2va8tW7YwevRoXnrpJcaMGcOGDRt46KGHqF27NhMnTmTz5s08+uijfPnll/Ts2ZNz586xbt06AKKiohg7dixvvfUWN954IwkJCaxbt86mgGgrhR6R4or8K+f49D771SEilYqPjw9ubm54enoSGBgIwL595v9jXnnlFa677jrrubVq1SIkJMT651dffZWFCxeyePFiHn744QLvMXHiRMaOHQvA66+/zvvvv8+mTZsYMmSITbVOnz6dAQMG8MILLwDQqlUr9uzZw9tvv83EiROJjIykevXqXH/99Xh5edG4cWM6duwImKEnIyODm266icaNGwMQHBxs0/1tpdAjUlyRG3OOFXpEHF41V2f2vDLYbvcuDV265G1RTkxM5KWXXmLJkiXWEHHx4kUiIyMLvU6HDh2sx9WrV8fb25vY2Fib69m7dy8jR47M81ivXr2YMWMGmZmZXHfddTRu3JhmzZoxZMgQhgwZYu1WCwkJYcCAAQQHBzN48GAGDRrELbfcQs2aNW2uo6g0pkekuI7lCj2xCj0ijs5iseDp5mKXr9La8+vyWVhPPfUUCxcu5PXXX2fdunWEh4cTHBxMWlpaoddxdXW94u8mKyurVG
rMzcvLi61bt/Ltt99St25dXnzxRUJCQrhw4QLOzs6sWLGCX3/9lbZt2/LBBx8QFBTEkSNHSr2ObAo9IsWRdAbOHsj151hIPme/ekSkUnFzcyMzM/Oq561fv56JEydy4403EhwcTGBgIEePHi37Ai9p06YN69evv6KmVq1a4exstm65uLgwcOBA3nrrLXbs2MHRo0dZtWoVYIatXr168fLLL7Nt2zbc3NxYuHBhmdWr7i2R4sgez1OnDaQlQVwkxO6FJr3sW5eIVApNmjTh77//5ujRo9SoUaPAVpiWLVuyYMECRowYgcVi4YUXXiiTFpuCPPnkk3Tt2pVXX32VMWPGsHHjRmbOnMmHH34IwC+//MLhw4fp27cvNWvWZOnSpWRlZREUFMTff//NypUrGTRoEP7+/vz999+cPn2aNm3alFm9aukRKY7s8TyNukOdIPNY43pEpJQ89dRTODs707ZtW+rUqVPgGJ3p06dTs2ZNevbsyYgRIxg8eDCdOnUqtzo7derEd999x7x582jfvj0vvvgir7zyChMnTgTA19eXBQsWcO2119KmTRtmz57Nt99+S7t27fD29uaPP/5g2LBhtGrViueff553332XoUOHllm9FqMs54ZVMPHx8fj4+BAXF4e3t7e9yxFH9um1cHIL3PgJxOyEDR9At/th2Nv2rkxELklJSeHIkSM0bdoUDw8Pe5cjJVTY51nU39/q3hKxVVqSuS4PQOMekHVpVdTYvfarSURErkrdWyK2OrnFDDre9cGnIfi3Nh9X95aIVHD/+te/qFGjRr5f//rXv+xdXomppUfEVsdyjeexWMDv0piepNOQdBaq17ZfbSIiJfDKK6/w1FNP5ftcZRj2odAjYivrIOYe5nf3GuDTyJzBdXofVNcMLhGpmPz9/fH397d3GWVG3VsitsjMgBP/mMfZoQdydXFpXI+IiKNS6BGxRcxOSEsEdx/wz7WWRJ1LoUcrM4uIOCyFHhFbZC9K2CgMnHLtpVNHg5lFRBydQo+ILXIvSpibZnCJiDg8hR6RojKMXC09PfI+d/kMLhERcTgKPSJFde4wJMaAsxvUu2yZd/ca4NvIPFZrj4jYWZMmTZgxY0aRzrVYLCxatKhM63EUCj0iRZXdylOvE7jms6R9Hc3gEhFxZAo9IkVV0HiebJrBJSLi0IoVembNmkWTJk3w8PAgLCyMTZs2FXr+999/T+vWrfHw8CA4OJilS5fmeX7BggUMGjSI2rVrY7FYCA8Pv+Ia/fv3x2Kx5Pm6fEnsyMhIhg8fjqenJ/7+/jz99NNkZGQU5y2KXOnyRQkvlz2FXd1bIo7JMMy98+zxZcPe3p988gn16tUjKysrz+MjR47k7rvv5tChQ4wcOZKAgABq1KhB165d+f3330vtr2nnzp1ce+21VKtWjdq1a3P//feTmJhofX7NmjV069aN6tWr4+vrS69evTh27BgA27dv55prrsHLywtvb286d+7M5s2bS622krJ5Reb58+czefJkZs+eTVhYGDNmzGDw4MFERETku4rjhg0bGDt2LNOmTeP666/nm2++YdSoUWzdupX27dsDkJSURO/evRk9ejT33Xdfgfe+7777eOWVV6x/9vT0tB5nZmYyfPhwAgMD2bBhA1FRUYwfPx5XV1def/11W9+mSF6Jp+HsQfO4UVj+59S5NJhZoUfEMaUnw+v17HPv/zsFbtWLdOqtt97KI488wurVqxkwYAAA586dY9myZSxdupTExESGDRvGa6+9hru7O1988QUjRowgIiKCRo0alajMpKQkBg8eTI8ePfjnn3+IjY3l3nvv5eGHH2bu3LlkZGQwatQo7rvvPr799lvS0tLYtGkTFosFgHHjxtGxY0c++ugjnJ2dCQ8Px9XVtUQ1lSabQ8/06dO57777uOuuuwCYPXs2S5Ys4bPPPuO555674vz33nuPIUOG8PTTTwPw6quvsmLFCmbOnMns2bMBuPPOOwE4evRooff29P
QkMDAw3+d+++039uzZw++//05AQAChoaG8+uqrPPvss7z00ku4ubnZ+lZFchy/NJ7Hvy1Uq5n/OdqDS0RKQc2aNRk
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tr_ind = list(range(len(train_losses)))\n",
"val_ind = list(range(len(val_losses)))\n",
"plt.plot(train_losses, label='train_loss')\n",
"plt.plot(val_losses, label='val_loss')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "1f48acd7-70e8-46db-9148-6a2df3153f08",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "313fa420-c856-4db1-80ae-b543e1fb73ef",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"# Evaluate on CPU: the model is moved there, and sklearn metrics need host\n",
"# arrays. Previously `device` was still 'cuda' from the training cell, so\n",
"# model(X) would hit a CPU-weights/CUDA-inputs device mismatch.\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        mask_rev = (torch.squeeze(mask, dim=1)==0) * 1  # invert mask: 1 marks the in-painted region\n",
"        reconstructed = model(X)\n",
"        # Undo the [0, 1] scaling to report metrics in physical units\n",
"        rev_data = y * max_pixel_value\n",
"        rev_recon = reconstructed * max_pixel_value\n",
"        # Score only the reconstructed (previously masked-out) pixels\n",
"        data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
"        data_label = data_label[mask_rev==1]\n",
"        recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
"        recon_no2 = recon_no2[mask_rev==1]\n",
"        mae = mean_absolute_error(data_label, recon_no2)\n",
"        rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
"        mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
"        r2 = r2_score(data_label, recon_no2)\n",
"        eva_list.append([mae, rmse, mape, r2])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "5c6d5e5a-90f6-4e9a-882f-c2f160b0cb15",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.296906</td>\n",
" <td>2.022362</td>\n",
" <td>0.167694</td>\n",
" <td>0.904339</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.075761</td>\n",
" <td>0.137041</td>\n",
" <td>0.013171</td>\n",
" <td>0.010395</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>1.121284</td>\n",
" <td>1.716275</td>\n",
" <td>0.143667</td>\n",
" <td>0.875878</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>1.238378</td>\n",
" <td>1.917907</td>\n",
" <td>0.156429</td>\n",
" <td>0.898060</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.287193</td>\n",
" <td>2.011828</td>\n",
" <td>0.166679</td>\n",
" <td>0.904941</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.353045</td>\n",
" <td>2.102409</td>\n",
" <td>0.176996</td>\n",
" <td>0.911137</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>1.446046</td>\n",
" <td>2.414532</td>\n",
" <td>0.202142</td>\n",
" <td>0.924070</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2\n",
"count 75.000000 75.000000 75.000000 75.000000\n",
"mean 1.296906 2.022362 0.167694 0.904339\n",
"std 0.075761 0.137041 0.013171 0.010395\n",
"min 1.121284 1.716275 0.143667 0.875878\n",
"25% 1.238378 1.917907 0.156429 0.898060\n",
"50% 1.287193 2.011828 0.166679 0.904941\n",
"75% 1.353045 2.102409 0.176996 0.911137\n",
"max 1.446046 2.414532 0.202142 0.924070"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "b4250d45-b430-40a0-ace7-f59d3451aebd",
"metadata": {},
"outputs": [],
"source": [
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" for i, sample in enumerate(rev_data):\n",
" used_mask = mask_rev[i]\n",
" data_label = sample[0] * used_mask\n",
" recon_no2 = rev_recon[i][0] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "039d0041-4573-4645-aeb0-686eabfe8b6f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.306817</td>\n",
" <td>1.845819</td>\n",
" <td>0.166876</td>\n",
" <td>0.670519</td>\n",
" <td>0.886646</td>\n",
" <td>0.836323</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.623645</td>\n",
" <td>0.902619</td>\n",
" <td>0.107025</td>\n",
" <td>0.240752</td>\n",
" <td>0.111142</td>\n",
" <td>0.121726</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.432991</td>\n",
" <td>0.568319</td>\n",
" <td>0.050612</td>\n",
" <td>-1.539424</td>\n",
" <td>-0.267569</td>\n",
" <td>0.022258</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.835579</td>\n",
" <td>1.172322</td>\n",
" <td>0.113302</td>\n",
" <td>0.583713</td>\n",
" <td>0.864756</td>\n",
" <td>0.794922</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.161710</td>\n",
" <td>1.658195</td>\n",
" <td>0.143386</td>\n",
" <td>0.735860</td>\n",
" <td>0.921341</td>\n",
" <td>0.869860</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.617382</td>\n",
" <td>2.299731</td>\n",
" <td>0.185039</td>\n",
" <td>0.827242</td>\n",
" <td>0.951285</td>\n",
" <td>0.916741</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>5.338230</td>\n",
" <td>9.936951</td>\n",
" <td>1.929986</td>\n",
" <td>0.983208</td>\n",
" <td>0.995767</td>\n",
" <td>0.992588</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
"mean 1.306817 1.845819 0.166876 0.670519 0.886646 \n",
"std 0.623645 0.902619 0.107025 0.240752 0.111142 \n",
"min 0.432991 0.568319 0.050612 -1.539424 -0.267569 \n",
"25% 0.835579 1.172322 0.113302 0.583713 0.864756 \n",
"50% 1.161710 1.658195 0.143386 0.735860 0.921341 \n",
"75% 1.617382 2.299731 0.185039 0.827242 0.951285 \n",
"max 5.338230 9.936951 1.929986 0.983208 0.995767 \n",
"\n",
" r \n",
"count 4739.000000 \n",
"mean 0.836323 \n",
"std 0.121726 \n",
"min 0.022258 \n",
"25% 0.794922 \n",
"50% 0.869860 \n",
"75% 0.916741 \n",
"max 0.992588 "
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83c7e465-bbd0-4c56-8cb4-9d1122fe695f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}