1170 lines
77 KiB
Plaintext
1170 lines
77 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 35,
|
|||
|
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"import torch\n",
|
|||
|
"import torch.nn as nn\n",
|
|||
|
"import torch.nn.functional as F\n",
|
|||
|
"import torch.optim as optim\n",
|
|||
|
"from torch.utils.data import DataLoader, Dataset, random_split\n",
|
|||
|
"from PIL import Image\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import cv2"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"id": "b8a8cedd-536d-4a48-a1af-7c40489ef0f8",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<torch._C.Generator at 0x7f7419087810>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"np.random.seed(42)\n",
|
|||
|
"torch.random.manual_seed(42)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"id": "c28cc123-71be-47ff-b78f-3a4d5592df39",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Maximum pixel value in the dataset: 92.64960479736328\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 定义函数来找到最大值\n",
|
|||
|
"def find_max_pixel_value(image_dir):\n",
|
|||
|
" max_pixel_value = 0.0\n",
|
|||
|
" for filename in os.listdir(image_dir):\n",
|
|||
|
" if filename.endswith('.npy'):\n",
|
|||
|
" image_path = os.path.join(image_dir, filename)\n",
|
|||
|
" image = np.load(image_path).astype(np.float32)\n",
|
|||
|
" max_pixel_value = max(max_pixel_value, image.max())\n",
|
|||
|
" return max_pixel_value\n",
|
|||
|
"\n",
"# Normalisation constant, computed once from the training images.\n",
"image_dir = './2022data/new_train_2021/train/'\n",
"max_pixel_value = find_max_pixel_value(image_dir)\n",
"\n",
"print(\"Maximum pixel value in the dataset: {}\".format(max_pixel_value))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"id": "dbfe80ce-4394-449c-a9a4-22ed15b2b8f2",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"checkpoint before Generator is OK\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class NO2Dataset(Dataset):\n",
|
|||
|
" \n",
|
|||
|
" def __init__(self, image_dir, mask_dir):\n",
|
|||
|
" \n",
|
|||
|
" self.image_dir = image_dir\n",
|
|||
|
" self.mask_dir = mask_dir\n",
|
|||
|
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
|
|||
|
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
|
|||
|
" \n",
|
|||
|
" def __len__(self):\n",
|
|||
|
" \n",
|
|||
|
" return len(self.image_filenames)\n",
|
|||
|
" \n",
|
|||
|
" def __getitem__(self, idx):\n",
|
|||
|
" \n",
|
|||
|
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
|
|||
|
" mask_idx = np.random.choice(self.mask_filenames)\n",
|
|||
|
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
|
|||
|
"\n",
|
|||
|
" # 加载图像数据 (.npy 文件)\n",
|
|||
|
" image = np.expand_dims(np.load(image_path).astype(np.float32), axis=2) / max_pixel_value # 形状为 (96, 96, 1)\n",
|
|||
|
"\n",
|
|||
|
" # 加载掩码数据 (.jpg 文件)\n",
|
|||
|
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
|
|||
|
"\n",
|
|||
|
" # 将掩码数据中非0值设为1,0值保持不变\n",
|
|||
|
" mask = np.where(mask != 0, 1.0, 0.0)\n",
|
|||
|
"\n",
|
|||
|
" # 保持掩码数据形状为 (96, 96, 1)\n",
|
|||
|
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
|
|||
|
"\n",
|
|||
|
" # 应用掩码\n",
|
|||
|
" masked_image = image.copy()\n",
|
|||
|
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
|
|||
|
"\n",
|
|||
|
" # cGAN的输入和目标\n",
|
|||
|
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
|
|||
|
" y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n",
|
|||
|
"\n",
|
|||
|
" # 转换形状为 (channels, height, width)\n",
|
|||
|
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
|||
|
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
|||
|
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
|||
|
"\n",
|
|||
|
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
|
|||
|
"\n",
|
|||
|
"# 实例化数据集和数据加载器\n",
|
|||
|
"image_dir = './2022data/new_train_2021/train/'\n",
|
|||
|
"mask_dir = './2022data/new_train_2021/mask/20/'\n",
|
|||
|
"\n",
|
|||
|
"print(f\"checkpoint before Generator is OK\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"id": "41da7319-9795-441d-bde8-8cf390365099",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Datasets: train/valid/test images share one mask directory. A mask is drawn at\n",
"# random for every item, so validation/test losses vary between evaluations.\n",
"train_set = NO2Dataset(image_dir, mask_dir)\n",
"val_set = NO2Dataset('./2022data/new_train_2021/valid/', mask_dir)\n",
"test_set = NO2Dataset('./2022data/new_train_2021/test/', mask_dir)\n",
"\n",
"# Only the training loader shuffles.\n",
"train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)\n",
"val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"id": "70797703-1619-4be7-b965-5506b3d1e775",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Plot ground truth, masked input and reconstruction side by side.\n",
"def visualize_feature(input_feature, masked_feature, output_feature, title):\n",
"    \"\"\"Show ground truth, masked input and reconstruction in one row.\n",
"\n",
"    Only channel 0 of each tensor is drawn (presumably (1, H, W) tensors from\n",
"    NO2Dataset / the model -- TODO confirm for other callers). All three panels\n",
"    share one colour range so intensities can be compared directly; with\n",
"    per-panel autoscaling the same colour would mean different values.\n",
"\n",
"    Args:\n",
"        input_feature: ground-truth tensor, titled '<title> Input'.\n",
"        masked_feature: masked model input, titled '<title> Masked'.\n",
"        output_feature: model reconstruction, titled '<title> Recovery'.\n",
"        title: prefix for the three panel titles.\n",
"    \"\"\"\n",
"    panels = [\n",
"        (input_feature, ' Input'),\n",
"        (masked_feature, ' Masked'),\n",
"        (output_feature, ' Recovery'),\n",
"    ]\n",
"    # detach() so tensors that still track gradients can be converted to NumPy.\n",
"    images = [feature[0].detach().cpu().numpy() for feature, _ in panels]\n",
"    vmin = min(float(np.nanmin(img)) for img in images)\n",
"    vmax = max(float(np.nanmax(img)) for img in images)\n",
"\n",
"    _, axes = plt.subplots(1, 3, figsize=(12, 6))\n",
"    for ax, img, (_, suffix) in zip(axes, images, panels):\n",
"        ax.imshow(img, cmap='RdYlGn_r', vmin=vmin, vmax=vmax)\n",
"        ax.set_title(title + suffix)\n",
"    plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"id": "645114e8-65a4-4867-b3fe-23395288e855",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class Conv(nn.Sequential):\n",
|
|||
|
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
|
|||
|
" super(Conv, self).__init__(\n",
|
|||
|
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
|
|||
|
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class ConvBNReLU(nn.Sequential):\n",
|
|||
|
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
|
|||
|
" bias=False):\n",
|
|||
|
" super(ConvBNReLU, self).__init__(\n",
|
|||
|
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
|
|||
|
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
|
|||
|
" norm_layer(out_channels),\n",
|
|||
|
" nn.ReLU()\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"id": "31ecf247-e98b-4977-a145-782914a042bd",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class SeparableBNReLU(nn.Sequential):\n",
|
|||
|
" def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
|
|||
|
" super(SeparableBNReLU, self).__init__(\n",
|
|||
|
" nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
|
|||
|
" padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
|
|||
|
" # 分离卷积,仅调整空间信息\n",
|
|||
|
" norm_layer(in_channels), # 对输入通道进行归一化\n",
|
|||
|
" nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), # 这里进行升维操作\n",
|
|||
|
" nn.ReLU6()\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class ResidualBlock(nn.Module):\n",
|
|||
|
" def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
|
|||
|
" super(ResidualBlock, self).__init__()\n",
|
|||
|
" self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
|
|||
|
" self.bn1 = nn.BatchNorm2d(out_channels)\n",
|
|||
|
" self.relu = nn.ReLU(inplace=True)\n",
|
|||
|
" self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
|
|||
|
" self.bn2 = nn.BatchNorm2d(out_channels)\n",
|
|||
|
"\n",
|
|||
|
" # 如果输入和输出通道不一致,进行降采样操作\n",
|
|||
|
" self.downsample = downsample\n",
|
|||
|
" if in_channels != out_channels or stride != 1:\n",
|
|||
|
" self.downsample = nn.Sequential(\n",
|
|||
|
" nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
|
|||
|
" nn.BatchNorm2d(out_channels)\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" identity = x\n",
|
|||
|
" if self.downsample is not None:\n",
|
|||
|
" identity = self.downsample(x)\n",
|
|||
|
"\n",
|
|||
|
" out = self.conv1(x)\n",
|
|||
|
" out = self.bn1(out)\n",
|
|||
|
" out = self.relu(out)\n",
|
|||
|
"\n",
|
|||
|
" out = self.conv2(out)\n",
|
|||
|
" out = self.bn2(out)\n",
|
|||
|
"\n",
|
|||
|
" out += identity\n",
|
|||
|
" out = self.relu(out)\n",
|
|||
|
" return out\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class Mlp(nn.Module):\n",
|
|||
|
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
|
|||
|
" super().__init__()\n",
|
|||
|
" out_features = out_features or in_features\n",
|
|||
|
" hidden_features = hidden_features or in_features\n",
|
|||
|
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
|
|||
|
"\n",
|
|||
|
" self.act = act_layer()\n",
|
|||
|
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
|
|||
|
" self.drop = nn.Dropout(drop, inplace=True)\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" x = self.fc1(x)\n",
|
|||
|
" x = self.act(x)\n",
|
|||
|
" x = self.drop(x)\n",
|
|||
|
" x = self.fc2(x)\n",
|
|||
|
" x = self.drop(x)\n",
|
|||
|
" return x"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class MultiHeadAttentionBlock(nn.Module):\n",
|
|||
|
" def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
|
|||
|
" super(MultiHeadAttentionBlock, self).__init__()\n",
|
|||
|
" self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
|
|||
|
" self.norm = nn.LayerNorm(embed_dim)\n",
|
|||
|
" self.dropout = nn.Dropout(dropout)\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
|
|||
|
" B, C, H, W = x.shape\n",
|
|||
|
" x = x.view(B, C, H * W).permute(2, 0, 1) # (B, C, H, W) -> (HW, B, C)\n",
|
|||
|
"\n",
|
|||
|
" # Apply multihead attention\n",
|
|||
|
" attn_output, _ = self.attention(x, x, x)\n",
|
|||
|
"\n",
|
|||
|
" # Apply normalization and dropout\n",
|
|||
|
" attn_output = self.norm(attn_output)\n",
|
|||
|
" attn_output = self.dropout(attn_output)\n",
|
|||
|
"\n",
|
|||
|
" # Reshape back to (B, C, H, W)\n",
|
|||
|
" attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
|
|||
|
"\n",
|
|||
|
" return attn_output"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class SpatialAttentionBlock(nn.Module):\n",
|
|||
|
" def __init__(self):\n",
|
|||
|
" super(SpatialAttentionBlock, self).__init__()\n",
|
|||
|
" self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x): #(B, 64, H, W)\n",
|
|||
|
" avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)\n",
|
|||
|
" max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)\n",
|
|||
|
" out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)\n",
|
|||
|
" out = torch.sigmoid(self.conv(out))#(B, 1, H, W)\n",
|
|||
|
" return x * out #(B, C, H, W)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class DecoderAttentionBlock(nn.Module):\n",
"    \"\"\"Channel attention followed by spatial attention (SpatialAttentionBlock).\n",
"\n",
"    Channel weights are sigmoid(conv2(conv1(avgpool(x)) + conv1(maxpool(x)))),\n",
"    where avgpool/maxpool squeeze H and W to 1, conv1 maps in_channels to\n",
"    in_channels // 2 and conv2 maps back. x is rescaled per channel and the\n",
"    result is passed through the spatial attention block.\n",
"\n",
"    NOTE(review): there is no non-linearity between conv1 and conv2, so the two\n",
"    1x1 convs compose to a single affine map. CBAM-style blocks usually put a\n",
"    ReLU there; left unchanged so existing weights keep their meaning.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels):\n",
"        super(DecoderAttentionBlock, self).__init__()\n",
"        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
"        self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
"        self.spatial_attention = SpatialAttentionBlock()\n",
"\n",
"    def forward(self, x):\n",
"        # Expect a batched (B, C, H, W) feature map.\n",
"        if x.dim() != 4:\n",
"            raise ValueError(f'expected a (B, C, H, W) tensor, got shape {tuple(x.shape)}')\n",
"\n",
"        # Channel attention: one weight per channel from avg- and max-pooled stats.\n",
"        avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
"        max_pool = F.adaptive_max_pool2d(x, 1)\n",
"        out = self.conv1(avg_pool) + self.conv1(max_pool)\n",
"        out = torch.sigmoid(self.conv2(out))\n",
"\n",
"        # Rescale channels, then apply spatial attention.\n",
"        return self.spatial_attention(x * out)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class SEBlock(nn.Module):\n",
|
|||
|
" def __init__(self, in_channels, reduced_dim):\n",
|
|||
|
" super(SEBlock, self).__init__()\n",
|
|||
|
" self.se = nn.Sequential(\n",
|
|||
|
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
|
|||
|
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
|
|||
|
" nn.ReLU(),\n",
|
|||
|
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
|
|||
|
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
" def forward(self, x):\n",
|
|||
|
" return x * self.se(x)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"id": "c9d176a8-bbf6-4043-ab82-1648a99d772a",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def masked_mse_loss(preds, target, mask):\n",
|
|||
|
" loss = (preds - target) ** 2\n",
|
|||
|
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
|
|||
|
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
|
|||
|
" return loss"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Masked autoencoder for single-channel NO2 fields.\n",
"class MaskedAutoencoder(nn.Module):\n",
"    \"\"\"Conv/attention encoder -> 1x1-conv MLP -> transposed-conv decoder.\n",
"\n",
"    Input is (B, 1, H, W) (the first Conv takes 1 channel). The encoder halves the\n",
"    spatial size three times (three stride-2 layers) while widening to 128\n",
"    channels; the decoder doubles it three times back to 1 channel, so for H and\n",
"    W divisible by 8 the output matches the input size. The final Sigmoid keeps\n",
"    outputs in (0, 1).\n",
"\n",
"    NOTE(review): targets are images divided by the dataset maximum; if any\n",
"    values can be negative the Sigmoid cannot reproduce them -- confirm.\n",
"    Layer positions inside encoder/decoder appear in state_dict keys; keep them stable.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        encoder_layers = [\n",
"            Conv(1, 32, kernel_size=3, stride=2),\n",
"            nn.ReLU(),\n",
"            SEBlock(32, 32),\n",
"            ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
"            ResidualBlock(64, 64),\n",
"            SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
"            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
"            SEBlock(128, 128),\n",
"        ]\n",
"        self.encoder = nn.Sequential(*encoder_layers)\n",
"        self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
"        decoder_layers = [\n",
"            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
"            nn.ReLU(),\n",
"            DecoderAttentionBlock(32),\n",
"            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
"            nn.ReLU(),\n",
"            DecoderAttentionBlock(16),\n",
"            # Index 6: effectively a no-op (its input is a ReLU output rescaled by\n",
"            # sigmoids, hence non-negative) but kept so later indices do not shift.\n",
"            nn.ReLU(),\n",
"            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
"            nn.Sigmoid(),\n",
"        ]\n",
"        self.decoder = nn.Sequential(*decoder_layers)\n",
"\n",
"    def forward(self, x):\n",
"        latent = self.mlp(self.encoder(x))\n",
"        return self.decoder(latent)\n",
|
|||
|
"\n",
"# Model, optimiser and criterion. NOTE: train_epoch/evaluate accept `criterion`\n",
"# but score batches with masked_mse_loss, so this MSELoss is currently unused.\n",
"model = MaskedAutoencoder()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
"criterion = nn.MSELoss()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 训练函数\n",
|
|||
|
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
|
|||
|
" model.train()\n",
|
|||
|
" running_loss = 0.0\n",
|
|||
|
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
|
|||
|
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
|||
|
" optimizer.zero_grad()\n",
|
|||
|
" reconstructed = model(X)\n",
|
|||
|
" loss = masked_mse_loss(reconstructed, y, mask)\n",
|
|||
|
" loss.backward()\n",
|
|||
|
" optimizer.step()\n",
|
|||
|
" running_loss += loss.item()\n",
|
|||
|
" return running_loss / (batch_idx + 1)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 评估函数\n",
|
|||
|
"def evaluate(model, device, data_loader, criterion):\n",
|
|||
|
" model.eval()\n",
|
|||
|
" running_loss = 0.0\n",
|
|||
|
" with torch.no_grad():\n",
|
|||
|
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
|
|||
|
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
|||
|
" reconstructed = model(X)\n",
|
|||
|
" if batch_idx == 8:\n",
|
|||
|
" rand_ind = np.random.randint(0, len(y))\n",
|
|||
|
" # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
|
|||
|
" loss = masked_mse_loss(reconstructed, y, mask)\n",
|
|||
|
" running_loss += loss.item()\n",
|
|||
|
" return running_loss / (batch_idx + 1)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 20,
|
|||
|
"id": "296ba6bd-2239-4948-b278-7edcb29bfd14",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"cuda\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 数据准备\n",
|
|||
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|||
|
"print(device)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 21,
|
|||
|
"id": "743d1000-561e-4444-8b49-88346c14f28b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n",
|
|||
|
" return F.conv2d(input, weight, bias, self.stride,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Epoch 1, Train Loss: 5.541112303206351, Val Loss: 0.8067406771030832\n",
|
|||
|
"Epoch 2, Train Loss: 0.33060450623344884, Val Loss: 0.14416445189333976\n",
|
|||
|
"Epoch 3, Train Loss: 0.14130625223251922, Val Loss: 0.07359389453492265\n",
|
|||
|
"Epoch 4, Train Loss: 0.10518970966866587, Val Loss: 0.054381779930058945\n",
|
|||
|
"Epoch 5, Train Loss: 0.09058622275218148, Val Loss: 0.0465024342720813\n",
|
|||
|
"Epoch 6, Train Loss: 0.08342431517521189, Val Loss: 0.042179942210303974\n",
|
|||
|
"Epoch 7, Train Loss: 0.0774571797686868, Val Loss: 0.03831239916542743\n",
|
|||
|
"Epoch 8, Train Loss: 0.0720803240385555, Val Loss: 0.03571732088606408\n",
|
|||
|
"Epoch 9, Train Loss: 0.06799363104247413, Val Loss: 0.03523132728135332\n",
|
|||
|
"Epoch 10, Train Loss: 0.06398953810597943, Val Loss: 0.03190892792128502\n",
|
|||
|
"Epoch 11, Train Loss: 0.06091008874914639, Val Loss: 0.030253113742838515\n",
|
|||
|
"Epoch 12, Train Loss: 0.058550740303718936, Val Loss: 0.03257738580887622\n",
|
|||
|
"Epoch 13, Train Loss: 0.05582124731094085, Val Loss: 0.027309948182169426\n",
|
|||
|
"Epoch 14, Train Loss: 0.05444160232369879, Val Loss: 0.03076436184346676\n",
|
|||
|
"Epoch 15, Train Loss: 0.053529950248896195, Val Loss: 0.026180010566369018\n",
|
|||
|
"Epoch 16, Train Loss: 0.05092262584375421, Val Loss: 0.02586523879398691\n",
|
|||
|
"Epoch 17, Train Loss: 0.05036925500297265, Val Loss: 0.026086220715908295\n",
|
|||
|
"Epoch 18, Train Loss: 0.04870546900922746, Val Loss: 0.025190358426659665\n",
|
|||
|
"Epoch 19, Train Loss: 0.04829096387533312, Val Loss: 0.024286496195387332\n",
|
|||
|
"Epoch 20, Train Loss: 0.047801207734552105, Val Loss: 0.024341319628218387\n",
|
|||
|
"Epoch 21, Train Loss: 0.0463638727533958, Val Loss: 0.023777516439874122\n",
|
|||
|
"Epoch 22, Train Loss: 0.04561143496505103, Val Loss: 0.02462407554242205\n",
|
|||
|
"Epoch 23, Train Loss: 0.04455085273469444, Val Loss: 0.02330890230517438\n",
|
|||
|
"Epoch 24, Train Loss: 0.04402760396489, Val Loss: 0.023676151687160453\n",
|
|||
|
"Epoch 25, Train Loss: 0.04317896270294808, Val Loss: 0.02370590161769948\n",
|
|||
|
"Epoch 26, Train Loss: 0.042474492900842764, Val Loss: 0.027188481287436284\n",
|
|||
|
"Epoch 27, Train Loss: 0.0410688633324474, Val Loss: 0.022131468387360267\n",
|
|||
|
"Epoch 28, Train Loss: 0.04015502951775504, Val Loss: 0.021191479004126913\n",
|
|||
|
"Epoch 29, Train Loss: 0.039912018183190213, Val Loss: 0.02161621072507919\n",
|
|||
|
"Epoch 30, Train Loss: 0.039861640131930685, Val Loss: 0.02177569658515301\n",
|
|||
|
"Epoch 31, Train Loss: 0.03960100152038016, Val Loss: 0.02101492334870582\n",
|
|||
|
"Epoch 32, Train Loss: 0.03872588457083632, Val Loss: 0.020859015748855916\n",
|
|||
|
"Epoch 33, Train Loss: 0.038754463954045706, Val Loss: 0.02414150171457453\n",
|
|||
|
"Epoch 34, Train Loss: 0.03809461849233394, Val Loss: 0.019819263804783212\n",
|
|||
|
"Epoch 35, Train Loss: 0.03751421304952606, Val Loss: 0.021835624696092404\n",
|
|||
|
"Epoch 36, Train Loss: 0.03734014398118915, Val Loss: 0.022214002510968674\n",
|
|||
|
"Epoch 37, Train Loss: 0.03706552038260442, Val Loss: 0.019966345795608582\n",
|
|||
|
"Epoch 38, Train Loss: 0.036659476251113376, Val Loss: 0.019636615555971227\n",
|
|||
|
"Epoch 39, Train Loss: 0.036727246869586214, Val Loss: 0.01936723065978669\n",
|
|||
|
"Epoch 40, Train Loss: 0.03633688813333666, Val Loss: 0.020286126339689216\n",
|
|||
|
"Epoch 41, Train Loss: 0.035810444339186745, Val Loss: 0.019208339934653425\n",
|
|||
|
"Epoch 42, Train Loss: 0.03550744545714694, Val Loss: 0.01972645398308622\n",
|
|||
|
"Epoch 43, Train Loss: 0.035464368694651444, Val Loss: 0.019827014984602622\n",
|
|||
|
"Epoch 44, Train Loss: 0.03506896948678128, Val Loss: 0.02004316175713184\n",
|
|||
|
"Epoch 45, Train Loss: 0.03495936513298732, Val Loss: 0.019192911129682622\n",
|
|||
|
"Epoch 46, Train Loss: 0.03483127771841038, Val Loss: 0.018953541115401908\n",
|
|||
|
"Epoch 47, Train Loss: 0.03463402198171545, Val Loss: 0.018771914527454275\n",
|
|||
|
"Epoch 48, Train Loss: 0.03408609382302712, Val Loss: 0.018758975068463923\n",
|
|||
|
"Epoch 49, Train Loss: 0.03452993459054502, Val Loss: 0.018336334998937363\n",
|
|||
|
"Epoch 50, Train Loss: 0.034099031547441594, Val Loss: 0.019093293062549956\n",
|
|||
|
"Epoch 51, Train Loss: 0.03445967665947644, Val Loss: 0.018671645683811067\n",
|
|||
|
"Epoch 52, Train Loss: 0.03385696139263544, Val Loss: 0.017988349291238378\n",
|
|||
|
"Epoch 53, Train Loss: 0.03406877570117997, Val Loss: 0.018068110510865425\n",
|
|||
|
"Epoch 54, Train Loss: 0.03348344178721968, Val Loss: 0.018683044398401644\n",
|
|||
|
"Epoch 55, Train Loss: 0.033462831668094196, Val Loss: 0.01905706045316889\n",
|
|||
|
"Epoch 56, Train Loss: 0.033128469962637686, Val Loss: 0.01867042989172834\n",
|
|||
|
"Epoch 57, Train Loss: 0.0332745431941607, Val Loss: 0.019846445504338183\n",
|
|||
|
"Epoch 58, Train Loss: 0.03308211129081812, Val Loss: 0.01826892614840193\n",
|
|||
|
"Epoch 59, Train Loss: 0.03278694228766415, Val Loss: 0.022516568488580115\n",
|
|||
|
"Epoch 60, Train Loss: 0.03246014836659122, Val Loss: 0.01806999350640368\n",
|
|||
|
"Epoch 61, Train Loss: 0.0331528295534814, Val Loss: 0.01772232149588935\n",
|
|||
|
"Epoch 62, Train Loss: 0.03278059815674757, Val Loss: 0.01812060377461479\n",
|
|||
|
"Epoch 63, Train Loss: 0.032278176842141994, Val Loss: 0.01805540711242468\n",
|
|||
|
"Epoch 64, Train Loss: 0.03201383460521874, Val Loss: 0.018378542062449963\n",
|
|||
|
"Epoch 65, Train Loss: 0.03193402631005003, Val Loss: 0.017855498166952994\n",
|
|||
|
"Epoch 66, Train Loss: 0.03141010671326545, Val Loss: 0.01813684691219254\n",
|
|||
|
"Epoch 67, Train Loss: 0.03162443816969528, Val Loss: 0.017312405214823308\n",
|
|||
|
"Epoch 68, Train Loss: 0.03134946423997569, Val Loss: 0.017035803282038964\n",
|
|||
|
"Epoch 69, Train Loss: 0.030821436257884565, Val Loss: 0.017176391457782148\n",
|
|||
|
"Epoch 70, Train Loss: 0.030857550524241103, Val Loss: 0.01778144468652441\n",
|
|||
|
"Epoch 71, Train Loss: 0.03145846045935927, Val Loss: 0.017036813350909567\n",
|
|||
|
"Epoch 72, Train Loss: 0.03082356479425522, Val Loss: 0.01754499076211706\n",
|
|||
|
"Epoch 73, Train Loss: 0.03057446662929997, Val Loss: 0.016873343847692013\n",
|
|||
|
"Epoch 74, Train Loss: 0.030142722530482793, Val Loss: 0.017114325763380275\n",
|
|||
|
"Epoch 75, Train Loss: 0.0297475472960764, Val Loss: 0.017896422284080626\n",
|
|||
|
"Epoch 76, Train Loss: 0.02986417829462912, Val Loss: 0.016979403338058197\n",
|
|||
|
"Epoch 77, Train Loss: 0.030155790255440722, Val Loss: 0.016632370690399027\n",
|
|||
|
"Epoch 78, Train Loss: 0.02987812078698019, Val Loss: 0.017218250702036187\n",
|
|||
|
"Epoch 79, Train Loss: 0.02965712085761855, Val Loss: 0.016456886016307994\n",
|
|||
|
"Epoch 80, Train Loss: 0.029867385275068537, Val Loss: 0.016108868465303107\n",
|
|||
|
"Epoch 81, Train Loss: 0.029616706633726054, Val Loss: 0.016850862830401735\n",
|
|||
|
"Epoch 82, Train Loss: 0.02933939000190535, Val Loss: 0.017380977188177566\n",
|
|||
|
"Epoch 83, Train Loss: 0.028856007063591024, Val Loss: 0.016677292380878266\n",
|
|||
|
"Epoch 84, Train Loss: 0.029245234613793088, Val Loss: 0.016243027404267738\n",
|
|||
|
"Epoch 85, Train Loss: 0.029124773610218438, Val Loss: 0.016707272605693088\n",
|
|||
|
"Epoch 86, Train Loss: 0.02889745979731941, Val Loss: 0.01667517395888237\n",
|
|||
|
"Epoch 87, Train Loss: 0.028780636237522143, Val Loss: 0.015974930111081042\n",
|
|||
|
"Epoch 88, Train Loss: 0.0290858921784479, Val Loss: 0.01647984809143112\n",
|
|||
|
"Epoch 89, Train Loss: 0.028605496862513125, Val Loss: 0.015814711419033244\n",
|
|||
|
"Epoch 90, Train Loss: 0.02866147620092451, Val Loss: 0.01892404787321674\n",
|
|||
|
"Epoch 91, Train Loss: 0.028418820038174107, Val Loss: 0.01616615823846548\n",
|
|||
|
"Epoch 92, Train Loss: 0.028970944983637437, Val Loss: 0.015930495700462066\n",
|
|||
|
"Epoch 93, Train Loss: 0.02812033420796767, Val Loss: 0.015577566691060016\n",
|
|||
|
"Epoch 94, Train Loss: 0.027900781900042276, Val Loss: 0.016411741838810293\n",
|
|||
|
"Epoch 95, Train Loss: 0.028156488249215756, Val Loss: 0.015642933785281282\n",
|
|||
|
"Epoch 96, Train Loss: 0.027669002046495413, Val Loss: 0.01564073005810063\n",
|
|||
|
"Epoch 97, Train Loss: 0.02797757544084988, Val Loss: 0.01616466465465566\n",
|
|||
|
"Epoch 98, Train Loss: 0.027837259815813517, Val Loss: 0.01699387704200567\n",
|
|||
|
"Epoch 99, Train Loss: 0.02773604567291814, Val Loss: 0.015504092572534338\n",
|
|||
|
"Epoch 100, Train Loss: 0.02741758727020746, Val Loss: 0.015247883136443634\n",
|
|||
|
"Epoch 101, Train Loss: 0.02707562789562705, Val Loss: 0.015558899360451293\n",
|
|||
|
"Epoch 102, Train Loss: 0.027159787612832578, Val Loss: 0.015182257392146487\n",
|
|||
|
"Epoch 103, Train Loss: 0.027029822105239625, Val Loss: 0.014660503893615083\n",
|
|||
|
"Epoch 104, Train Loss: 0.02699657593878497, Val Loss: 0.016841756120482658\n",
|
|||
|
"Epoch 105, Train Loss: 0.026641362756051144, Val Loss: 0.015178967544690091\n",
|
|||
|
"Epoch 106, Train Loss: 0.026524744587222385, Val Loss: 0.015554199926555157\n",
|
|||
|
"Epoch 107, Train Loss: 0.026474817848289083, Val Loss: 0.015399079710403656\n",
|
|||
|
"Epoch 108, Train Loss: 0.02636850485990269, Val Loss: 0.014777421396463476\n",
|
|||
|
"Epoch 109, Train Loss: 0.02637453050322413, Val Loss: 0.015275213094626336\n",
|
|||
|
"Epoch 110, Train Loss: 0.02607358055282659, Val Loss: 0.016890957614684357\n",
|
|||
|
"Epoch 111, Train Loss: 0.026133586770709285, Val Loss: 0.015139183485286032\n",
|
|||
|
"Epoch 112, Train Loss: 0.02617257334302924, Val Loss: 0.014704703016484038\n",
|
|||
|
"Epoch 113, Train Loss: 0.026084138217840926, Val Loss: 0.014918764835183925\n",
|
|||
|
"Epoch 114, Train Loss: 0.025832627078512777, Val Loss: 0.01494563212420078\n",
|
|||
|
"Epoch 115, Train Loss: 0.02605823659307837, Val Loss: 0.014487974504207043\n",
|
|||
|
"Epoch 116, Train Loss: 0.025865597622936103, Val Loss: 0.014469134779845147\n",
|
|||
|
"Epoch 117, Train Loss: 0.025718001264166693, Val Loss: 0.013978753100208779\n",
|
|||
|
"Epoch 118, Train Loss: 0.02561279770624233, Val Loss: 0.01455160214545879\n",
|
|||
|
"Epoch 119, Train Loss: 0.025601031165295295, Val Loss: 0.015720585519646075\n",
|
|||
|
"Epoch 120, Train Loss: 0.025754293742806685, Val Loss: 0.013814986822135906\n",
|
|||
|
"Epoch 121, Train Loss: 0.02534578327408231, Val Loss: 0.014853738644655714\n",
|
|||
|
"Epoch 122, Train Loss: 0.02561174121006752, Val Loss: 0.014788057021004088\n",
|
|||
|
"Epoch 123, Train Loss: 0.02533768888859622, Val Loss: 0.014425865988782111\n",
|
|||
|
"Epoch 124, Train Loss: 0.025395122024293847, Val Loss: 0.014166925221364549\n",
|
|||
|
"Epoch 125, Train Loss: 0.025411863934940996, Val Loss: 0.014836331670905681\n",
|
|||
|
"Epoch 126, Train Loss: 0.025214647187420048, Val Loss: 0.01417682920285362\n",
|
|||
|
"Epoch 127, Train Loss: 0.024879908288079025, Val Loss: 0.014164314981787763\n",
|
|||
|
"Epoch 128, Train Loss: 0.02494473186126501, Val Loss: 0.014208773448270685\n",
|
|||
|
"Epoch 129, Train Loss: 0.02468084254381755, Val Loss: 0.013683844337913585\n",
|
|||
|
"Epoch 130, Train Loss: 0.0248352900521066, Val Loss: 0.014833704508999561\n",
|
|||
|
"Epoch 131, Train Loss: 0.024615347561231404, Val Loss: 0.016790931608448637\n",
|
|||
|
"Epoch 132, Train Loss: 0.024628470901806445, Val Loss: 0.013669065913145846\n",
|
|||
|
"Epoch 133, Train Loss: 0.024401855987433486, Val Loss: 0.014544136485362307\n",
|
|||
|
"Epoch 134, Train Loss: 0.02425686465054311, Val Loss: 0.014493834742523254\n",
|
|||
|
"Epoch 135, Train Loss: 0.02475559137552801, Val Loss: 0.013708425725394107\n",
|
|||
|
"Epoch 136, Train Loss: 0.024078373256026818, Val Loss: 0.014549214828838693\n",
|
|||
|
"Epoch 137, Train Loss: 0.024223965633891325, Val Loss: 0.013578887454214249\n",
|
|||
|
"Epoch 138, Train Loss: 0.024396276563010383, Val Loss: 0.013251736344016612\n",
|
|||
|
"Epoch 139, Train Loss: 0.024004749286161586, Val Loss: 0.013333805103568321\n",
|
|||
|
"Epoch 140, Train Loss: 0.02389194364700697, Val Loss: 0.014107016430414737\n",
|
|||
|
"Epoch 141, Train Loss: 0.023637132873005923, Val Loss: 0.013322851898029764\n",
|
|||
|
"Epoch 142, Train Loss: 0.023719912169605582, Val Loss: 0.014070579683051464\n",
|
|||
|
"Epoch 143, Train Loss: 0.02377868151418579, Val Loss: 0.013563529806251222\n",
|
|||
|
"Epoch 144, Train Loss: 0.02362075615619312, Val Loss: 0.014379620492616867\n",
|
|||
|
"Epoch 145, Train Loss: 0.023822628134713236, Val Loss: 0.01308334250240884\n",
|
|||
|
"Epoch 146, Train Loss: 0.02378806389406719, Val Loss: 0.013488665500536878\n",
|
|||
|
"Epoch 147, Train Loss: 0.023415484050821767, Val Loss: 0.01323556466067725\n",
|
|||
|
"Epoch 148, Train Loss: 0.023618425456889434, Val Loss: 0.013999837430867744\n",
|
|||
|
"Epoch 149, Train Loss: 0.023620333203875563, Val Loss: 0.013482759567968388\n",
|
|||
|
"Epoch 150, Train Loss: 0.02325812268969232, Val Loss: 0.012960578988682716\n",
|
|||
|
"Test Loss: 0.028638894522660656\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"model = model.to(device)\n",
"\n",
"num_epochs = 150\n",
"# Losses returned by train_epoch / evaluate for every epoch, kept for the loss plot.\n",
"train_losses = []\n",
"val_losses = []\n",
"\n",
"for epoch in range(num_epochs):\n",
"    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
"    train_losses.append(train_loss)\n",
"\n",
"    val_loss = evaluate(model, device, val_loader, criterion)\n",
"    val_losses.append(val_loss)\n",
"\n",
"    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# Test the model. This scores the weights left after the last epoch (no best-checkpoint restore).\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 29,
|
|||
|
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"<matplotlib.legend.Legend at 0x7f70bbbcb3d0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 29,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABXEklEQVR4nO3deXhU1f0G8Hf2yWSZbGSSQCBh30KCLBFxJxLUUnEFSmWpS12w2tSNtoBKNYDKjyoIra0LVgVtxVoXFCMgYlgDIjsikABZSEIyyUwy6/39cTITBgIzE5LcLO/neeZJcufOnXMCMq/nfs85CkmSJBARERG1YUq5G0BERETkDwMLERERtXkMLERERNTmMbAQERFRm8fAQkRERG0eAwsRERG1eQwsRERE1OYxsBAREVGbp5a7Ac3B7Xbj1KlTCA8Ph0KhkLs5REREFABJklBdXY3ExEQolRcfQ+kQgeXUqVNISkqSuxlERETUBIWFhejWrdtFz+kQgSU8PByA6HBERITMrSEiIqJAmM1mJCUleT/HL6ZDBBbPbaCIiAgGFiIionYmkHIOFt0SERFRm8fAQkRERG0eAwsRERG1eR2ihoWIiDoeSZLgdDrhcrnkbgpdApVKBbVafcnLjjCwEBFRm2O321FUVASr1Sp3U6gZGAwGJCQkQKvVNvkaDCxERNSmuN1uHD16FCqVComJidBqtVwUtJ2SJAl2ux2nT5/G0aNH0adPH78LxF0IAwsREbUpdrsdbrcbSUlJMBgMcjeHLlFISAg0Gg2OHz8Ou90OvV7fpOuw6JaIiNqkpv6fOLU9zfFnyb8NRERE1OYxsBAREVGbx8BCRETUBiUnJ2Px4sXNcq3169dDoVCgsrKyWa4nBxbdEhERNZNrr70W6enpzRI0tm3bhtDQ0EtvVAfBwHIRDpcbOZ8fgFuSMOum/tCpVXI3iYiI2jFJkuByuaBW+//47dKlSyu0qP3gLaGLkCTgjU1H8db3x2BzuuVuDhFRpyVJEqx2pywPSZICauP06dOxYcMG/PWvf4VCoYBCocBbb70FhUKBL774AsOGDYNOp8N3332HI0eO4JZbboHJZEJYWBhGjBiBr7/+2ud6594SUigU+Mc//oFbb70VBoMBffr0wSeffNLk3+l//vMfDBo0CDqdDsnJyXj55Zd9nn/ttdfQp08f6PV6mEwm3HHHHd7n/v3vfyM1NRUhISGIiYlBZmYmLBZLk9sSCI6wXIRK2bBQkcsV2F9YIiJqfrUOFwbO+VKW9973XBYMWv8fl3/9619x6NAhDB48GM899xwAYO/evQCAp59+Gi+99BJ69uyJqKgoFBYW4qabbsLzzz8PnU6HFStWYPz48Th48CC6d+9+wfd49tlnsXDhQrz44ot49dVXMWXKFBw/fhzR0dFB9WnHjh2466678Mwzz2DixIn4/vvv8dBDDyEmJgbTp0/H9u3b8bvf/Q7vvPMOrrjiClRUVGDjxo0AgKKiIkyePBkLFy7ErbfeiurqamzcuDHgYNdUDCwXcVZegdPNwEJERBdmNBqh1WphMBgQHx8PADhw4AAA4LnnnsMNN9zgPTc6OhppaWnen+fNm4fVq1fjk08+wcyZMy/4HtOnT8fkyZMBAC+88AJeeeUVbN26FePGjQuqrYsWLcKYMWMwe/ZsAEDfvn2xb98+vPjii5g+fToKCgoQGhqKX/ziFwgPD0ePHj0wdOhQACKwOJ1O3HbbbejRowcAIDU1Naj3bwoGlotQKBRQKxVwuiW4Wzg5EhHRhYVoVNj3XJZs732phg8f7vNzTU0NnnnmGXz22WfeAFBbW4uCgoKLXmfIkCHe70NDQxEREYHS0tKg27N//37ccsstPsdGjx6NxYsXw+Vy4YYbbkCPHj3Qs2dPjBs3DuPGjfPeikpLS8OYMWOQmpqKrKwsjB07FnfccQeioqKCbkcwWMPih+e2EEdYiIjko1AoYNCqZXk0xz5G58
72efzxx7F69Wq88MIL2LhxI3bt2oXU1FTY7faLXkej0Zz3e3G7m7/GMjw8HPn5+Xj//feRkJCAOXPmIC0tDZWVlVCpVFi7di2++OILDBw4EK+++ir69euHo0ePNns7zsbA4oe6PrCwhoWIiPzRarVwuVx+z9u0aROmT5+OW2+9FampqYiPj8exY8davoH1BgwYgE2bNp3Xpr59+0KlEiNKarUamZmZWLhwIXbv3o1jx47hm2++ASCC0ujRo/Hss89i586d0Gq1WL16dYu2mbeE/GgYYeEsISIiurjk5GRs2bIFx44dQ1hY2AVHP/r06YOPPvoI48ePh0KhwOzZs1tkpORC/vCHP2DEiBGYN28eJk6ciLy8PCxZsgSvvfYaAODTTz/Fzz//jKuvvhpRUVH4/PPP4Xa70a9fP2zZsgW5ubkYO3Ys4uLisGXLFpw+fRoDBgxo0TZzhMUPtUr8ily8JURERH48/vjjUKlUGDhwILp06XLBmpRFixYhKioKV1xxBcaPH4+srCxcdtllrdbOyy67DB988AFWrlyJwYMHY86cOXjuuecwffp0AEBkZCQ++ugjXH/99RgwYACWL1+O999/H4MGDUJERAS+/fZb3HTTTejbty/+/Oc/4+WXX8aNN97Yom1WSC09D6kVmM1mGI1GVFVVISIiolmvPeL5r3G62oYvHr0KAxKa99pERHS+uro6HD16FCkpKdDr9XI3h5rBhf5Mg/n85giLH94aFo6wEBERyYaBxQ+lgrOEiIiobXvggQcQFhbW6OOBBx6Qu3nNgkW3fqhVnhEWFt0SEVHb9Nxzz+Hxxx9v9LnmLpWQCwOLH95ZQpzWTEREbVRcXBzi4uLkbkaL4i0hP7w1LO2/NpmIiKjdYmDxQ6XktGYiIiK5MbD4oebS/ERERLJjYPFDxaX5iYiIZMfA4gdHWIiIiOTHwOKHigvHERFRK0lOTsbixYsDOlehUODjjz9u0fa0JQwsfnjWYeHmh0RERPJhYPHDs9ItR1iIiIjkw8DiB2tYiIjaAEkC7BZ5HgGuw/X3v/8diYmJcJ8zIn/LLbfgN7/5DY4cOYJbbrkFJpMJYWFhGDFiBL7++utm+xX9+OOPuP766xESEoKYmBjcf//9qKmp8T6/fv16jBw5EqGhoYiMjMTo0aNx/PhxAMAPP/yA6667DuHh4YiIiMCwYcOwffv2Zmtbc+BKt3541mFxM7AQEcnHYQVeSJTnvf94CtCG+j3tzjvvxCOPPIJ169ZhzJgxAICKigqsWbMGn3/+OWpqanDTTTfh+eefh06nw4oVKzB+/HgcPHgQ3bt3v6QmWiwWZGVlYdSoUdi2bRtKS0tx7733YubMmXjrrbfgdDoxYcIE3HfffXj//fdht9uxdetWKOrvIkyZMgVDhw7FsmXLoFKpsGvXLmg0mktqU3NjYPGDIyxERBSIqKgo3HjjjXjvvfe8geXf//43YmNjcd1110GpVCItLc17/rx587B69Wp88sknmDlz5iW993vvvYe6ujqsWLECoaEiXC1ZsgTjx4/HggULoNFoUFVVhV/84hfo1asXAGDAgAHe1xcUFOCJJ55A//79AQB9+vS5pPa0BAYWP1Qq1rAQEclOYxAjHXK9d4CmTJmC++67D6+99hp0Oh3effddTJo0CUqlEjU1NXjmmWfw2WefoaioCE6nE7W1tSgoKLjkJu7fvx9paWnesAIAo0ePhtvtxsGDB3H11Vdj+vTpyMrKwg033IDMzEzcddddSEhIAABkZ2fj3nvvxTvvvIPMzEzceeed3mDTVrCGxQ+OsBARtQEKhbgtI8ej/rZJIMaPHw9JkvDZZ5+hsLAQGzduxJQpUwAAjz/+OFavXo0XXngBGzduxK5du5Camgq73d5SvzUfb775JvLy8nDFFVdg1apV6Nu3LzZv3gwAeOaZZ7B3717cfPPN+OabbzBw4ECsXr26VdoVKAYWPxrWYeG0ZiIiuji9Xo/bbrsN7777Lt
5//33069cPl112GQBg06ZNmD59Om699VakpqYiPj4ex44da5b3HTBgAH744QdYLBbvsU2bNkGpVKJfv37eY0OHDsW
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Loss curves. Epoch 1 is left out (as before), presumably so its large initial\n",
"# loss does not squash the y-axis; x values are the real 1-based epoch numbers.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(range(2, len(train_losses) + 1), train_losses[1:], label='train_loss')\n",
"ax.plot(range(2, len(val_losses) + 1), val_losses[1:], label='val_loss')\n",
"ax.set(xlabel='Epoch', ylabel='Loss', title='Training and validation loss')\n",
"ax.legend(loc='best');"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 30,
|
|||
|
"id": "a8467686-0655-4056-8e01-56299eb89d7c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"id": "efc96935-bbe0-4ca9-b11a-931cdcfc3bed",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def cal_ioa(y_true, y_pred):\n",
"    \"\"\"Willmott's index of agreement d between observations and predictions.\n",
"\n",
"        d = 1 - sum((P - O)**2) / sum((|P - mean(O)| + |O - mean(O)|)**2)\n",
"\n",
"    Both deviation terms are measured from the *observed* mean and the\n",
"    denominator has no extra factor (Willmott, 1981). By the triangle\n",
"    inequality d lies in [0, 1], with 1 meaning perfect agreement.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    y_true : array-like\n",
"        Observed values O.\n",
"    y_pred : array-like\n",
"        Predicted values P, same shape as y_true.\n",
"\n",
"    Returns\n",
"    -------\n",
"    float\n",
"        The index of agreement. 1.0 is returned when the denominator is 0,\n",
"        which only happens when the observations are constant and the\n",
"        predictions equal them exactly.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If the inputs are empty or their shapes differ.\n",
"    \"\"\"\n",
"    y_true = np.asarray(y_true, dtype=np.float64)\n",
"    y_pred = np.asarray(y_pred, dtype=np.float64)\n",
"    if y_true.shape != y_pred.shape:\n",
"        raise ValueError('y_true and y_pred must have the same shape')\n",
"    if y_true.size == 0:\n",
"        raise ValueError('cal_ioa needs at least one value')\n",
"\n",
"    mean_observed = y_true.mean()\n",
"    numerator = np.sum((y_pred - y_true) ** 2)\n",
"    # Willmott's 'potential error': both series are measured from the observed mean.\n",
"    denominator = np.sum((np.abs(y_pred - mean_observed) + np.abs(y_true - mean_observed)) ** 2)\n",
"    if denominator == 0:\n",
"        return 1.0\n",
"    return float(1 - numerator / denominator)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 32,
|
|||
|
"id": "dae7427e-548e-4276-a4ea-bc9b279d44e8",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Score the test set one metric row per batch, evaluating only the inpainted\n",
"# pixels (where the input mask is 0).\n",
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"model.eval()  # don't depend on an earlier cell having left the model in eval mode\n",
"with torch.no_grad():\n",
"    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        # Invert the mask to get the inpainted region (1 = pixel to score).\n",
"        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1\n",
"        reconstructed = model(X)\n",
"        # Multiply back by max_pixel_value, presumably undoing a divide-by-max normalisation.\n",
"        rev_data = y * max_pixel_value\n",
"        rev_recon = reconstructed * max_pixel_value\n",
"        # Boolean selection alone restricts to the inpainted pixels; no multiply-by-mask needed.\n",
"        inpainted = mask_rev == 1\n",
"        data_label = torch.squeeze(rev_data, dim=1)[inpainted].numpy()\n",
"        recon_no2 = torch.squeeze(rev_recon, dim=1)[inpainted].numpy()\n",
"        if data_label.size == 0:\n",
"            continue  # nothing inpainted in this batch; sklearn metrics reject empty input\n",
"        mae = mean_absolute_error(data_label, recon_no2)\n",
"        rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
"        # Note: MAPE grows without bound where the true value is close to 0.\n",
"        mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
"        r2 = r2_score(data_label, recon_no2)\n",
"        ioa = cal_ioa(data_label, recon_no2)\n",
"        r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
"        eva_list.append([mae, rmse, mape, r2, ioa, r])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 34,
|
|||
|
"id": "589e6d80-228d-4e8a-968a-e7477c5e0e24",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" <th>r</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>76.000000</td>\n",
|
|||
|
" <td>76.000000</td>\n",
|
|||
|
" <td>76.000000</td>\n",
|
|||
|
" <td>76.000000</td>\n",
|
|||
|
" <td>76.000000</td>\n",
|
|||
|
" <td>76.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>1.839756</td>\n",
|
|||
|
" <td>3.116629</td>\n",
|
|||
|
" <td>0.182020</td>\n",
|
|||
|
" <td>0.887434</td>\n",
|
|||
|
" <td>0.984924</td>\n",
|
|||
|
" <td>0.942480</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>0.141333</td>\n",
|
|||
|
" <td>0.218296</td>\n",
|
|||
|
" <td>0.009663</td>\n",
|
|||
|
" <td>0.010754</td>\n",
|
|||
|
" <td>0.001488</td>\n",
|
|||
|
" <td>0.005507</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>1.518803</td>\n",
|
|||
|
" <td>2.497997</td>\n",
|
|||
|
" <td>0.160656</td>\n",
|
|||
|
" <td>0.862124</td>\n",
|
|||
|
" <td>0.981807</td>\n",
|
|||
|
" <td>0.930738</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>1.725160</td>\n",
|
|||
|
" <td>2.954316</td>\n",
|
|||
|
" <td>0.175752</td>\n",
|
|||
|
" <td>0.880888</td>\n",
|
|||
|
" <td>0.984067</td>\n",
|
|||
|
" <td>0.939023</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>1.819375</td>\n",
|
|||
|
" <td>3.084015</td>\n",
|
|||
|
" <td>0.180615</td>\n",
|
|||
|
" <td>0.887742</td>\n",
|
|||
|
" <td>0.984906</td>\n",
|
|||
|
" <td>0.942333</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>1.973080</td>\n",
|
|||
|
" <td>3.293117</td>\n",
|
|||
|
" <td>0.189375</td>\n",
|
|||
|
" <td>0.895522</td>\n",
|
|||
|
" <td>0.986135</td>\n",
|
|||
|
" <td>0.946845</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>2.156417</td>\n",
|
|||
|
" <td>3.538444</td>\n",
|
|||
|
" <td>0.205928</td>\n",
|
|||
|
" <td>0.909979</td>\n",
|
|||
|
" <td>0.988135</td>\n",
|
|||
|
" <td>0.954369</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa r\n",
|
|||
|
"count 76.000000 76.000000 76.000000 76.000000 76.000000 76.000000\n",
|
|||
|
"mean 1.839756 3.116629 0.182020 0.887434 0.984924 0.942480\n",
|
|||
|
"std 0.141333 0.218296 0.009663 0.010754 0.001488 0.005507\n",
|
|||
|
"min 1.518803 2.497997 0.160656 0.862124 0.981807 0.930738\n",
|
|||
|
"25% 1.725160 2.954316 0.175752 0.880888 0.984067 0.939023\n",
|
|||
|
"50% 1.819375 3.084015 0.180615 0.887742 0.984906 0.942333\n",
|
|||
|
"75% 1.973080 3.293117 0.189375 0.895522 0.986135 0.946845\n",
|
|||
|
"max 2.156417 3.538444 0.205928 0.909979 0.988135 0.954369"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 34,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Summary statistics of the per-batch test metrics collected in eva_list.\n",
"pd.DataFrame(\n",
"    eva_list,\n",
"    columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r'],\n",
").describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 45,
|
|||
|
"id": "6278442c-3ecb-4e92-b901-0f0e0e43d8af",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Same evaluation as the per-batch cell, but one metric row per test sample.\n",
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"model.eval()  # don't depend on an earlier cell having left the model in eval mode\n",
"with torch.no_grad():\n",
"    for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        # Invert the mask to get the inpainted region (1 = pixel to score).\n",
"        mask_rev = (torch.squeeze(mask, dim=1) == 0) * 1\n",
"        reconstructed = model(X)\n",
"        # Multiply back by max_pixel_value, presumably undoing a divide-by-max normalisation.\n",
"        rev_data = y * max_pixel_value\n",
"        rev_recon = reconstructed * max_pixel_value\n",
"        for i, sample in enumerate(rev_data):\n",
"            # Channel 0 of sample i, restricted to its inpainted pixels.\n",
"            used_mask = mask_rev[i] == 1\n",
"            data_label = sample[0][used_mask].numpy()\n",
"            recon_no2 = rev_recon[i][0][used_mask].numpy()\n",
"            if data_label.size == 0:\n",
"                continue  # nothing inpainted in this sample; sklearn metrics reject empty input\n",
"            mae = mean_absolute_error(data_label, recon_no2)\n",
"            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
"            # Note: MAPE grows without bound where the true value is close to 0.\n",
"            mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
"            r2 = r2_score(data_label, recon_no2)\n",
"            ioa = cal_ioa(data_label, recon_no2)\n",
"            r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
"            eva_list.append([mae, rmse, mape, r2, ioa, r])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 46,
|
|||
|
"id": "3d095141-79e2-4f4f-b31f-54fed1996781",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>mae</th>\n",
|
|||
|
" <th>rmse</th>\n",
|
|||
|
" <th>mape</th>\n",
|
|||
|
" <th>r2</th>\n",
|
|||
|
" <th>ioa</th>\n",
|
|||
|
" <th>r</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>count</th>\n",
|
|||
|
" <td>4839.000000</td>\n",
|
|||
|
" <td>4839.000000</td>\n",
|
|||
|
" <td>4839.000000</td>\n",
|
|||
|
" <td>4839.000000</td>\n",
|
|||
|
" <td>4839.000000</td>\n",
|
|||
|
" <td>4839.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mean</th>\n",
|
|||
|
" <td>1.833400</td>\n",
|
|||
|
" <td>2.618025</td>\n",
|
|||
|
" <td>0.181476</td>\n",
|
|||
|
" <td>0.631467</td>\n",
|
|||
|
" <td>0.937835</td>\n",
|
|||
|
" <td>0.813992</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>std</th>\n",
|
|||
|
" <td>1.185956</td>\n",
|
|||
|
" <td>1.683402</td>\n",
|
|||
|
" <td>0.071764</td>\n",
|
|||
|
" <td>0.260356</td>\n",
|
|||
|
" <td>0.053726</td>\n",
|
|||
|
" <td>0.123230</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>min</th>\n",
|
|||
|
" <td>0.248986</td>\n",
|
|||
|
" <td>0.337919</td>\n",
|
|||
|
" <td>0.075559</td>\n",
|
|||
|
" <td>-3.769637</td>\n",
|
|||
|
" <td>0.103451</td>\n",
|
|||
|
" <td>0.020267</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25%</th>\n",
|
|||
|
" <td>0.843516</td>\n",
|
|||
|
" <td>1.179310</td>\n",
|
|||
|
" <td>0.138988</td>\n",
|
|||
|
" <td>0.537750</td>\n",
|
|||
|
" <td>0.924173</td>\n",
|
|||
|
" <td>0.762655</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>50%</th>\n",
|
|||
|
" <td>1.335556</td>\n",
|
|||
|
" <td>1.939605</td>\n",
|
|||
|
" <td>0.165166</td>\n",
|
|||
|
" <td>0.682016</td>\n",
|
|||
|
" <td>0.951186</td>\n",
|
|||
|
" <td>0.841513</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>75%</th>\n",
|
|||
|
" <td>2.759837</td>\n",
|
|||
|
" <td>3.977746</td>\n",
|
|||
|
" <td>0.202966</td>\n",
|
|||
|
" <td>0.792008</td>\n",
|
|||
|
" <td>0.970115</td>\n",
|
|||
|
" <td>0.898809</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>max</th>\n",
|
|||
|
" <td>9.474609</td>\n",
|
|||
|
" <td>10.988250</td>\n",
|
|||
|
" <td>1.344091</td>\n",
|
|||
|
" <td>0.978264</td>\n",
|
|||
|
" <td>0.997251</td>\n",
|
|||
|
" <td>0.989095</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" mae rmse mape r2 ioa \\\n",
|
|||
|
"count 4839.000000 4839.000000 4839.000000 4839.000000 4839.000000 \n",
|
|||
|
"mean 1.833400 2.618025 0.181476 0.631467 0.937835 \n",
|
|||
|
"std 1.185956 1.683402 0.071764 0.260356 0.053726 \n",
|
|||
|
"min 0.248986 0.337919 0.075559 -3.769637 0.103451 \n",
|
|||
|
"25% 0.843516 1.179310 0.138988 0.537750 0.924173 \n",
|
|||
|
"50% 1.335556 1.939605 0.165166 0.682016 0.951186 \n",
|
|||
|
"75% 2.759837 3.977746 0.202966 0.792008 0.970115 \n",
|
|||
|
"max 9.474609 10.988250 1.344091 0.978264 0.997251 \n",
|
|||
|
"\n",
|
|||
|
" r \n",
|
|||
|
"count 4839.000000 \n",
|
|||
|
"mean 0.813992 \n",
|
|||
|
"std 0.123230 \n",
|
|||
|
"min 0.020267 \n",
|
|||
|
"25% 0.762655 \n",
|
|||
|
"50% 0.841513 \n",
|
|||
|
"75% 0.898809 \n",
|
|||
|
"max 0.989095 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 46,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Summary statistics of the per-sample test metrics collected in eva_list.\n",
"pd.DataFrame(\n",
"    eva_list,\n",
"    columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r'],\n",
").describe()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "8eb4d33a-8d03-418d-bb50-f34eef4e4bf5",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3 (ipykernel)",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.16"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 5
|
|||
|
}
|