MAE_ATMO/torch_MAE_1d_final_20_2021....

1170 lines
77 KiB
Plaintext
Raw Permalink Normal View History

2024-11-21 14:02:33 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 35,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"from PIL import Image\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b8a8cedd-536d-4a48-a1af-7c40489ef0f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f7419087810>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(42)\n",
"torch.random.manual_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c28cc123-71be-47ff-b78f-3a4d5592df39",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum pixel value in the dataset: 92.64960479736328\n"
]
}
],
"source": [
"# 定义函数来找到最大值\n",
"def find_max_pixel_value(image_dir):\n",
" max_pixel_value = 0.0\n",
" for filename in os.listdir(image_dir):\n",
" if filename.endswith('.npy'):\n",
" image_path = os.path.join(image_dir, filename)\n",
" image = np.load(image_path).astype(np.float32)\n",
" max_pixel_value = max(max_pixel_value, image.max())\n",
" return max_pixel_value\n",
"\n",
"# 计算图像数据中的最大像素值\n",
"image_dir = './2022data/new_train_2021/train/' \n",
"max_pixel_value = find_max_pixel_value(image_dir)\n",
"\n",
"print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dbfe80ce-4394-449c-a9a4-22ed15b2b8f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint before Generator is OK\n"
]
}
],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = np.random.choice(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.expand_dims(np.load(image_path).astype(np.float32), axis=2) / max_pixel_value # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# 实例化数据集和数据加载器\n",
"image_dir = './2022data/new_train_2021/train/'\n",
"mask_dir = './2022data/new_train_2021/mask/20/'\n",
"\n",
"print(f\"checkpoint before Generator is OK\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "41da7319-9795-441d-bde8-8cf390365099",
"metadata": {},
"outputs": [],
"source": [
"train_set = NO2Dataset(image_dir, mask_dir)\n",
"train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)\n",
"val_set = NO2Dataset('./2022data/new_train_2021/valid/', mask_dir)\n",
"val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
"test_set = NO2Dataset('./2022data/new_train_2021/test/', mask_dir)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Input\")\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Masked\")\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Recovery\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "645114e8-65a4-4867-b3fe-23395288e855",
"metadata": {},
"outputs": [],
"source": [
"class Conv(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
" super(Conv, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
"metadata": {},
"outputs": [],
"source": [
"class ConvBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
" bias=False):\n",
" super(ConvBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
" dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
" norm_layer(out_channels),\n",
" nn.ReLU()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "31ecf247-e98b-4977-a145-782914a042bd",
"metadata": {},
"outputs": [],
"source": [
"class SeparableBNReLU(nn.Sequential):\n",
" def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
" super(SeparableBNReLU, self).__init__(\n",
" nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
" padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
" # 分离卷积,仅调整空间信息\n",
" norm_layer(in_channels), # 对输入通道进行归一化\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), # 这里进行升维操作\n",
" nn.ReLU6()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
"metadata": {},
"outputs": [],
"source": [
"class ResidualBlock(nn.Module):\n",
" def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
" super(ResidualBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(out_channels)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(out_channels)\n",
"\n",
" # 如果输入和输出通道不一致,进行降采样操作\n",
" self.downsample = downsample\n",
" if in_channels != out_channels or stride != 1:\n",
" self.downsample = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
" nn.BatchNorm2d(out_channels)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" identity = x\n",
" if self.downsample is not None:\n",
" identity = self.downsample(x)\n",
"\n",
" out = self.conv1(x)\n",
" out = self.bn1(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv2(out)\n",
" out = self.bn2(out)\n",
"\n",
" out += identity\n",
" out = self.relu(out)\n",
" return out\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
"metadata": {},
"outputs": [],
"source": [
"class Mlp(nn.Module):\n",
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
" super().__init__()\n",
" out_features = out_features or in_features\n",
" hidden_features = hidden_features or in_features\n",
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
"\n",
" self.act = act_layer()\n",
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
" self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = self.act(x)\n",
" x = self.drop(x)\n",
" x = self.fc2(x)\n",
" x = self.drop(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
"metadata": {},
"outputs": [],
"source": [
"class MultiHeadAttentionBlock(nn.Module):\n",
" def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
" super(MultiHeadAttentionBlock, self).__init__()\n",
" self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
" self.norm = nn.LayerNorm(embed_dim)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x):\n",
" # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
" B, C, H, W = x.shape\n",
" x = x.view(B, C, H * W).permute(2, 0, 1) # (B, C, H, W) -> (HW, B, C)\n",
"\n",
" # Apply multihead attention\n",
" attn_output, _ = self.attention(x, x, x)\n",
"\n",
" # Apply normalization and dropout\n",
" attn_output = self.norm(attn_output)\n",
" attn_output = self.dropout(attn_output)\n",
"\n",
" # Reshape back to (B, C, H, W)\n",
" attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
"\n",
" return attn_output"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
"metadata": {},
"outputs": [],
"source": [
"class SpatialAttentionBlock(nn.Module):\n",
" def __init__(self):\n",
" super(SpatialAttentionBlock, self).__init__()\n",
" self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
"\n",
" def forward(self, x): #(B, 64, H, W)\n",
" avg_out = torch.mean(x, dim=1, keepdim=True) #(B, 1, H, W)\n",
" max_out, _ = torch.max(x, dim=1, keepdim=True)#(B, 1, H, W)\n",
" out = torch.cat([avg_out, max_out], dim=1)#(B, 2, H, W)\n",
" out = torch.sigmoid(self.conv(out))#(B, 1, H, W)\n",
" return x * out #(B, C, H, W)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
"metadata": {},
"outputs": [],
"source": [
"class DecoderAttentionBlock(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super(DecoderAttentionBlock, self).__init__()\n",
" self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
" self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
" self.spatial_attention = SpatialAttentionBlock()\n",
"\n",
" def forward(self, x):\n",
" # 通道注意力\n",
" b, c, h, w = x.size()\n",
" avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
" max_pool = F.adaptive_max_pool2d(x, 1)\n",
"\n",
" avg_out = self.conv1(avg_pool)\n",
" max_out = self.conv1(max_pool)\n",
"\n",
" out = avg_out + max_out\n",
" out = torch.sigmoid(self.conv2(out))\n",
"\n",
" # 添加空间注意力\n",
" out = x * out\n",
" out = self.spatial_attention(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
" def __init__(self, in_channels, reduced_dim):\n",
" super(SEBlock, self).__init__()\n",
" self.se = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return x * self.se(x)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c9d176a8-bbf6-4043-ab82-1648a99d772a",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 定义Masked Autoencoder模型\n",
"class MaskedAutoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(MaskedAutoencoder, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" Conv(1, 32, kernel_size=3, stride=2),\n",
" \n",
" nn.ReLU(),\n",
" \n",
" SEBlock(32,32),\n",
" \n",
" ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
" \n",
" ResidualBlock(64,64),\n",
" \n",
" SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
" \n",
" MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
" \n",
" SEBlock(128, 128)\n",
" )\n",
" self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" DecoderAttentionBlock(32),\n",
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" \n",
" DecoderAttentionBlock(16),\n",
" nn.ReLU(),\n",
" \n",
" nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 修改为 output_padding=1\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" encoded = self.mlp(self.encoder(x))\n",
" decoded = self.decoder(encoded)\n",
" return decoded\n",
"\n",
"# 实例化模型、损失函数和优化器\n",
"model = MaskedAutoencoder()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
"metadata": {},
"outputs": [],
"source": [
"# 训练函数\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
" model.train()\n",
" running_loss = 0.0\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" optimizer.zero_grad()\n",
" reconstructed = model(X)\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
"metadata": {},
"outputs": [],
"source": [
"# 评估函数\n",
"def evaluate(model, device, data_loader, criterion):\n",
" model.eval()\n",
" running_loss = 0.0\n",
" with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" reconstructed = model(X)\n",
" if batch_idx == 8:\n",
" rand_ind = np.random.randint(0, len(y))\n",
" # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "296ba6bd-2239-4948-b278-7edcb29bfd14",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# 数据准备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "743d1000-561e-4444-8b49-88346c14f28b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Train Loss: 5.541112303206351, Val Loss: 0.8067406771030832\n",
"Epoch 2, Train Loss: 0.33060450623344884, Val Loss: 0.14416445189333976\n",
"Epoch 3, Train Loss: 0.14130625223251922, Val Loss: 0.07359389453492265\n",
"Epoch 4, Train Loss: 0.10518970966866587, Val Loss: 0.054381779930058945\n",
"Epoch 5, Train Loss: 0.09058622275218148, Val Loss: 0.0465024342720813\n",
"Epoch 6, Train Loss: 0.08342431517521189, Val Loss: 0.042179942210303974\n",
"Epoch 7, Train Loss: 0.0774571797686868, Val Loss: 0.03831239916542743\n",
"Epoch 8, Train Loss: 0.0720803240385555, Val Loss: 0.03571732088606408\n",
"Epoch 9, Train Loss: 0.06799363104247413, Val Loss: 0.03523132728135332\n",
"Epoch 10, Train Loss: 0.06398953810597943, Val Loss: 0.03190892792128502\n",
"Epoch 11, Train Loss: 0.06091008874914639, Val Loss: 0.030253113742838515\n",
"Epoch 12, Train Loss: 0.058550740303718936, Val Loss: 0.03257738580887622\n",
"Epoch 13, Train Loss: 0.05582124731094085, Val Loss: 0.027309948182169426\n",
"Epoch 14, Train Loss: 0.05444160232369879, Val Loss: 0.03076436184346676\n",
"Epoch 15, Train Loss: 0.053529950248896195, Val Loss: 0.026180010566369018\n",
"Epoch 16, Train Loss: 0.05092262584375421, Val Loss: 0.02586523879398691\n",
"Epoch 17, Train Loss: 0.05036925500297265, Val Loss: 0.026086220715908295\n",
"Epoch 18, Train Loss: 0.04870546900922746, Val Loss: 0.025190358426659665\n",
"Epoch 19, Train Loss: 0.04829096387533312, Val Loss: 0.024286496195387332\n",
"Epoch 20, Train Loss: 0.047801207734552105, Val Loss: 0.024341319628218387\n",
"Epoch 21, Train Loss: 0.0463638727533958, Val Loss: 0.023777516439874122\n",
"Epoch 22, Train Loss: 0.04561143496505103, Val Loss: 0.02462407554242205\n",
"Epoch 23, Train Loss: 0.04455085273469444, Val Loss: 0.02330890230517438\n",
"Epoch 24, Train Loss: 0.04402760396489, Val Loss: 0.023676151687160453\n",
"Epoch 25, Train Loss: 0.04317896270294808, Val Loss: 0.02370590161769948\n",
"Epoch 26, Train Loss: 0.042474492900842764, Val Loss: 0.027188481287436284\n",
"Epoch 27, Train Loss: 0.0410688633324474, Val Loss: 0.022131468387360267\n",
"Epoch 28, Train Loss: 0.04015502951775504, Val Loss: 0.021191479004126913\n",
"Epoch 29, Train Loss: 0.039912018183190213, Val Loss: 0.02161621072507919\n",
"Epoch 30, Train Loss: 0.039861640131930685, Val Loss: 0.02177569658515301\n",
"Epoch 31, Train Loss: 0.03960100152038016, Val Loss: 0.02101492334870582\n",
"Epoch 32, Train Loss: 0.03872588457083632, Val Loss: 0.020859015748855916\n",
"Epoch 33, Train Loss: 0.038754463954045706, Val Loss: 0.02414150171457453\n",
"Epoch 34, Train Loss: 0.03809461849233394, Val Loss: 0.019819263804783212\n",
"Epoch 35, Train Loss: 0.03751421304952606, Val Loss: 0.021835624696092404\n",
"Epoch 36, Train Loss: 0.03734014398118915, Val Loss: 0.022214002510968674\n",
"Epoch 37, Train Loss: 0.03706552038260442, Val Loss: 0.019966345795608582\n",
"Epoch 38, Train Loss: 0.036659476251113376, Val Loss: 0.019636615555971227\n",
"Epoch 39, Train Loss: 0.036727246869586214, Val Loss: 0.01936723065978669\n",
"Epoch 40, Train Loss: 0.03633688813333666, Val Loss: 0.020286126339689216\n",
"Epoch 41, Train Loss: 0.035810444339186745, Val Loss: 0.019208339934653425\n",
"Epoch 42, Train Loss: 0.03550744545714694, Val Loss: 0.01972645398308622\n",
"Epoch 43, Train Loss: 0.035464368694651444, Val Loss: 0.019827014984602622\n",
"Epoch 44, Train Loss: 0.03506896948678128, Val Loss: 0.02004316175713184\n",
"Epoch 45, Train Loss: 0.03495936513298732, Val Loss: 0.019192911129682622\n",
"Epoch 46, Train Loss: 0.03483127771841038, Val Loss: 0.018953541115401908\n",
"Epoch 47, Train Loss: 0.03463402198171545, Val Loss: 0.018771914527454275\n",
"Epoch 48, Train Loss: 0.03408609382302712, Val Loss: 0.018758975068463923\n",
"Epoch 49, Train Loss: 0.03452993459054502, Val Loss: 0.018336334998937363\n",
"Epoch 50, Train Loss: 0.034099031547441594, Val Loss: 0.019093293062549956\n",
"Epoch 51, Train Loss: 0.03445967665947644, Val Loss: 0.018671645683811067\n",
"Epoch 52, Train Loss: 0.03385696139263544, Val Loss: 0.017988349291238378\n",
"Epoch 53, Train Loss: 0.03406877570117997, Val Loss: 0.018068110510865425\n",
"Epoch 54, Train Loss: 0.03348344178721968, Val Loss: 0.018683044398401644\n",
"Epoch 55, Train Loss: 0.033462831668094196, Val Loss: 0.01905706045316889\n",
"Epoch 56, Train Loss: 0.033128469962637686, Val Loss: 0.01867042989172834\n",
"Epoch 57, Train Loss: 0.0332745431941607, Val Loss: 0.019846445504338183\n",
"Epoch 58, Train Loss: 0.03308211129081812, Val Loss: 0.01826892614840193\n",
"Epoch 59, Train Loss: 0.03278694228766415, Val Loss: 0.022516568488580115\n",
"Epoch 60, Train Loss: 0.03246014836659122, Val Loss: 0.01806999350640368\n",
"Epoch 61, Train Loss: 0.0331528295534814, Val Loss: 0.01772232149588935\n",
"Epoch 62, Train Loss: 0.03278059815674757, Val Loss: 0.01812060377461479\n",
"Epoch 63, Train Loss: 0.032278176842141994, Val Loss: 0.01805540711242468\n",
"Epoch 64, Train Loss: 0.03201383460521874, Val Loss: 0.018378542062449963\n",
"Epoch 65, Train Loss: 0.03193402631005003, Val Loss: 0.017855498166952994\n",
"Epoch 66, Train Loss: 0.03141010671326545, Val Loss: 0.01813684691219254\n",
"Epoch 67, Train Loss: 0.03162443816969528, Val Loss: 0.017312405214823308\n",
"Epoch 68, Train Loss: 0.03134946423997569, Val Loss: 0.017035803282038964\n",
"Epoch 69, Train Loss: 0.030821436257884565, Val Loss: 0.017176391457782148\n",
"Epoch 70, Train Loss: 0.030857550524241103, Val Loss: 0.01778144468652441\n",
"Epoch 71, Train Loss: 0.03145846045935927, Val Loss: 0.017036813350909567\n",
"Epoch 72, Train Loss: 0.03082356479425522, Val Loss: 0.01754499076211706\n",
"Epoch 73, Train Loss: 0.03057446662929997, Val Loss: 0.016873343847692013\n",
"Epoch 74, Train Loss: 0.030142722530482793, Val Loss: 0.017114325763380275\n",
"Epoch 75, Train Loss: 0.0297475472960764, Val Loss: 0.017896422284080626\n",
"Epoch 76, Train Loss: 0.02986417829462912, Val Loss: 0.016979403338058197\n",
"Epoch 77, Train Loss: 0.030155790255440722, Val Loss: 0.016632370690399027\n",
"Epoch 78, Train Loss: 0.02987812078698019, Val Loss: 0.017218250702036187\n",
"Epoch 79, Train Loss: 0.02965712085761855, Val Loss: 0.016456886016307994\n",
"Epoch 80, Train Loss: 0.029867385275068537, Val Loss: 0.016108868465303107\n",
"Epoch 81, Train Loss: 0.029616706633726054, Val Loss: 0.016850862830401735\n",
"Epoch 82, Train Loss: 0.02933939000190535, Val Loss: 0.017380977188177566\n",
"Epoch 83, Train Loss: 0.028856007063591024, Val Loss: 0.016677292380878266\n",
"Epoch 84, Train Loss: 0.029245234613793088, Val Loss: 0.016243027404267738\n",
"Epoch 85, Train Loss: 0.029124773610218438, Val Loss: 0.016707272605693088\n",
"Epoch 86, Train Loss: 0.02889745979731941, Val Loss: 0.01667517395888237\n",
"Epoch 87, Train Loss: 0.028780636237522143, Val Loss: 0.015974930111081042\n",
"Epoch 88, Train Loss: 0.0290858921784479, Val Loss: 0.01647984809143112\n",
"Epoch 89, Train Loss: 0.028605496862513125, Val Loss: 0.015814711419033244\n",
"Epoch 90, Train Loss: 0.02866147620092451, Val Loss: 0.01892404787321674\n",
"Epoch 91, Train Loss: 0.028418820038174107, Val Loss: 0.01616615823846548\n",
"Epoch 92, Train Loss: 0.028970944983637437, Val Loss: 0.015930495700462066\n",
"Epoch 93, Train Loss: 0.02812033420796767, Val Loss: 0.015577566691060016\n",
"Epoch 94, Train Loss: 0.027900781900042276, Val Loss: 0.016411741838810293\n",
"Epoch 95, Train Loss: 0.028156488249215756, Val Loss: 0.015642933785281282\n",
"Epoch 96, Train Loss: 0.027669002046495413, Val Loss: 0.01564073005810063\n",
"Epoch 97, Train Loss: 0.02797757544084988, Val Loss: 0.01616466465465566\n",
"Epoch 98, Train Loss: 0.027837259815813517, Val Loss: 0.01699387704200567\n",
"Epoch 99, Train Loss: 0.02773604567291814, Val Loss: 0.015504092572534338\n",
"Epoch 100, Train Loss: 0.02741758727020746, Val Loss: 0.015247883136443634\n",
"Epoch 101, Train Loss: 0.02707562789562705, Val Loss: 0.015558899360451293\n",
"Epoch 102, Train Loss: 0.027159787612832578, Val Loss: 0.015182257392146487\n",
"Epoch 103, Train Loss: 0.027029822105239625, Val Loss: 0.014660503893615083\n",
"Epoch 104, Train Loss: 0.02699657593878497, Val Loss: 0.016841756120482658\n",
"Epoch 105, Train Loss: 0.026641362756051144, Val Loss: 0.015178967544690091\n",
"Epoch 106, Train Loss: 0.026524744587222385, Val Loss: 0.015554199926555157\n",
"Epoch 107, Train Loss: 0.026474817848289083, Val Loss: 0.015399079710403656\n",
"Epoch 108, Train Loss: 0.02636850485990269, Val Loss: 0.014777421396463476\n",
"Epoch 109, Train Loss: 0.02637453050322413, Val Loss: 0.015275213094626336\n",
"Epoch 110, Train Loss: 0.02607358055282659, Val Loss: 0.016890957614684357\n",
"Epoch 111, Train Loss: 0.026133586770709285, Val Loss: 0.015139183485286032\n",
"Epoch 112, Train Loss: 0.02617257334302924, Val Loss: 0.014704703016484038\n",
"Epoch 113, Train Loss: 0.026084138217840926, Val Loss: 0.014918764835183925\n",
"Epoch 114, Train Loss: 0.025832627078512777, Val Loss: 0.01494563212420078\n",
"Epoch 115, Train Loss: 0.02605823659307837, Val Loss: 0.014487974504207043\n",
"Epoch 116, Train Loss: 0.025865597622936103, Val Loss: 0.014469134779845147\n",
"Epoch 117, Train Loss: 0.025718001264166693, Val Loss: 0.013978753100208779\n",
"Epoch 118, Train Loss: 0.02561279770624233, Val Loss: 0.01455160214545879\n",
"Epoch 119, Train Loss: 0.025601031165295295, Val Loss: 0.015720585519646075\n",
"Epoch 120, Train Loss: 0.025754293742806685, Val Loss: 0.013814986822135906\n",
"Epoch 121, Train Loss: 0.02534578327408231, Val Loss: 0.014853738644655714\n",
"Epoch 122, Train Loss: 0.02561174121006752, Val Loss: 0.014788057021004088\n",
"Epoch 123, Train Loss: 0.02533768888859622, Val Loss: 0.014425865988782111\n",
"Epoch 124, Train Loss: 0.025395122024293847, Val Loss: 0.014166925221364549\n",
"Epoch 125, Train Loss: 0.025411863934940996, Val Loss: 0.014836331670905681\n",
"Epoch 126, Train Loss: 0.025214647187420048, Val Loss: 0.01417682920285362\n",
"Epoch 127, Train Loss: 0.024879908288079025, Val Loss: 0.014164314981787763\n",
"Epoch 128, Train Loss: 0.02494473186126501, Val Loss: 0.014208773448270685\n",
"Epoch 129, Train Loss: 0.02468084254381755, Val Loss: 0.013683844337913585\n",
"Epoch 130, Train Loss: 0.0248352900521066, Val Loss: 0.014833704508999561\n",
"Epoch 131, Train Loss: 0.024615347561231404, Val Loss: 0.016790931608448637\n",
"Epoch 132, Train Loss: 0.024628470901806445, Val Loss: 0.013669065913145846\n",
"Epoch 133, Train Loss: 0.024401855987433486, Val Loss: 0.014544136485362307\n",
"Epoch 134, Train Loss: 0.02425686465054311, Val Loss: 0.014493834742523254\n",
"Epoch 135, Train Loss: 0.02475559137552801, Val Loss: 0.013708425725394107\n",
"Epoch 136, Train Loss: 0.024078373256026818, Val Loss: 0.014549214828838693\n",
"Epoch 137, Train Loss: 0.024223965633891325, Val Loss: 0.013578887454214249\n",
"Epoch 138, Train Loss: 0.024396276563010383, Val Loss: 0.013251736344016612\n",
"Epoch 139, Train Loss: 0.024004749286161586, Val Loss: 0.013333805103568321\n",
"Epoch 140, Train Loss: 0.02389194364700697, Val Loss: 0.014107016430414737\n",
"Epoch 141, Train Loss: 0.023637132873005923, Val Loss: 0.013322851898029764\n",
"Epoch 142, Train Loss: 0.023719912169605582, Val Loss: 0.014070579683051464\n",
"Epoch 143, Train Loss: 0.02377868151418579, Val Loss: 0.013563529806251222\n",
"Epoch 144, Train Loss: 0.02362075615619312, Val Loss: 0.014379620492616867\n",
"Epoch 145, Train Loss: 0.023822628134713236, Val Loss: 0.01308334250240884\n",
"Epoch 146, Train Loss: 0.02378806389406719, Val Loss: 0.013488665500536878\n",
"Epoch 147, Train Loss: 0.023415484050821767, Val Loss: 0.01323556466067725\n",
"Epoch 148, Train Loss: 0.023618425456889434, Val Loss: 0.013999837430867744\n",
"Epoch 149, Train Loss: 0.023620333203875563, Val Loss: 0.013482759567968388\n",
"Epoch 150, Train Loss: 0.02325812268969232, Val Loss: 0.012960578988682716\n",
"Test Loss: 0.028638894522660656\n"
]
}
],
"source": [
"model = model.to(device)\n",
"\n",
"num_epochs = 150\n",
"train_losses = list()\n",
"val_losses = list()\n",
"for epoch in range(num_epochs):\n",
" train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
" train_losses.append(train_loss)\n",
" val_loss = evaluate(model, device, val_loader, criterion)\n",
" val_losses.append(val_loss)\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# 测试模型\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f70bbbcb3d0>"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAABXEklEQVR4nO3deXhU1f0G8Hf2yWSZbGSSQCBh30KCLBFxJxLUUnEFSmWpS12w2tSNtoBKNYDKjyoIra0LVgVtxVoXFCMgYlgDIjsikABZSEIyyUwy6/39cTITBgIzE5LcLO/neeZJcufOnXMCMq/nfs85CkmSJBARERG1YUq5G0BERETkDwMLERERtXkMLERERNTmMbAQERFRm8fAQkRERG0eAwsRERG1eQwsRERE1OYxsBAREVGbp5a7Ac3B7Xbj1KlTCA8Ph0KhkLs5REREFABJklBdXY3ExEQolRcfQ+kQgeXUqVNISkqSuxlERETUBIWFhejWrdtFz+kQgSU8PByA6HBERITMrSEiIqJAmM1mJCUleT/HL6ZDBBbPbaCIiAgGFiIionYmkHIOFt0SERFRm8fAQkRERG0eAwsRERG1eR2ihoWIiDoeSZLgdDrhcrnkbgpdApVKBbVafcnLjjCwEBFRm2O321FUVASr1Sp3U6gZGAwGJCQkQKvVNvkaDCxERNSmuN1uHD16FCqVComJidBqtVwUtJ2SJAl2ux2nT5/G0aNH0adPH78LxF0IAwsREbUpdrsdbrcbSUlJMBgMcjeHLlFISAg0Gg2OHz8Ou90OvV7fpOuw6JaIiNqkpv6fOLU9zfFnyb8NRERE1OYxsBAREVGbx8BCRETUBiUnJ2Px4sXNcq3169dDoVCgsrKyWa4nBxbdEhERNZNrr70W6enpzRI0tm3bhtDQ0EtvVAfBwHIRDpcbOZ8fgFuSMOum/tCpVXI3iYiI2jFJkuByuaBW+//47dKlSyu0qP3gLaGLkCTgjU1H8db3x2BzuuVuDhFRpyVJEqx2pywPSZICauP06dOxYcMG/PWvf4VCoYBCocBbb70FhUKBL774AsOGDYNOp8N3332HI0eO4JZbboHJZEJYWBhGjBiBr7/+2ud6594SUigU+Mc//oFbb70VBoMBffr0wSeffNLk3+l//vMfDBo0CDqdDsnJyXj55Zd9nn/ttdfQp08f6PV6mEwm3HHHHd7n/v3vfyM1NRUhISGIiYlBZmYmLBZLk9sSCI6wXIRK2bBQkcsV2F9YIiJqfrUOFwbO+VKW9973XBYMWv8fl3/9619x6NAhDB48GM899xwAYO/evQCAp59+Gi+99BJ69uyJqKgoFBYW4qabbsLzzz8PnU6HFStWYPz48Th48CC6d+9+wfd49tlnsXDhQrz44ot49dVXMWXKFBw/fhzR0dFB9WnHjh2466678Mwzz2DixIn4/vvv8dBDDyEmJgbTp0/H9u3b8bvf/Q7vvPMOrrjiClRUVGDjxo0AgKKiIkyePBkLFy7ErbfeiurqamzcuDHgYNdUDCwXcVZegdPNwEJERBdmNBqh1WphMBgQHx8PADhw4AAA4LnnnsMNN9zgPTc6OhppaWnen+fNm4fVq1fjk08+wcyZMy/4HtOnT8fkyZMBAC+88AJeeeUVbN26FePGjQuqrYsWLcKYMWMwe/ZsAEDfvn2xb98+vPjii5g+fToKCgoQGhqKX/ziFwgPD0ePHj0wdOhQACKwOJ1O3HbbbejRowcAIDU1Naj3bwoGlotQKBRQKxVwuiW4Wzg5EhHRhYVoVNj3XJZs732phg8f7vNzTU0NnnnmGXz22WfeAFBbW4uCgoKLXmfIkCHe70NDQxEREYHS0tKg27N//37ccsstPsdGjx6NxYsXw+Vy4YYbbkCPHj3Qs2dPjBs3DuPGjfPeikpLS8OYMWOQmpqKrKwsjB07FnfccQeioqKCbkcwWMPih+e2EEdYiIjko1AoYNCqZXk0xz5G5872efzxx7F69Wq88MIL2LhxI3bt2oXU1FTY7faLXkej0Zz3e3G7m7/GMjw8HPn5+Xj//feRkJCAOXPmIC0tDZWVlVCpVFi7di2++OILDBw4EK+++ir69euHo0ePNns7zsbA4oe6PrCwhoWIiPzRarVwuVx+z9u0aROmT5+OW2+9FampqYiPj8exY8davoH1BgwYgE2bNp3Xpr59+0KlEiNKarUamZmZWLhwIXbv3o1jx47hm2++ASCC0ujRo/Hss89i586d0Gq1WL16dYu2mbeE/GgYYeEsISIiurjk5GRs2bIFx44dQ1hY2AVHP/r06YOPPvoI48ePh0KhwOzZs1tkpORC/vCHP2DEiBGYN28eJk6ciLy8PCxZsgSvvfYaAODTTz/Fzz//jKuvvhpRUVH4/PPP4Xa70a9fP2zZsgW5ubkYO3Ys4uLisGXLFpw+fRoDBgxo0TZzhMUPtUr8ily8JURERH48/vjjUKlUGDhwILp06XLBmpRFixYhKioKV1xxBcaPH4+srCxcdtllrdbOyy67DB988AFWrlyJwYMHY86cOXjuuecwffp0AEBkZCQ++ugjXH/99RgwYACWL1+O999/H4MGDUJERAS+/fZb3HTTTejbty/+/Oc/4+WXX8aNN97Yom1WSC09D6kVmM1mGI1GVFVVISIiolmvPeL5r3G62oYvHr0KAxKa99pERHS+uro6HD16FCkpKdDr9XI3h5rBhf5Mg/n85giLH94aFo6wEBERyYaBxQ+lgrOEiIiobXvggQcQFhbW6OOBBx6Qu3nNgkW3fqhVnhEWFt0SEVHb9Nxzz+Hxxx9v9LnmLpWQCwOLH95ZQpzWTEREbVRcXBzi4uLkbkaL4i0hP7w1LO2/NpmIiKjdYmDxQ6XktGYiIiK5MbD4oebS/ERERLJjYPFDxaX5iYiIZMfA4gdHWIiIiOTHwOKHigvHERFRK0lOTsbixYsDOlehUODjjz9u0fa0JQwsfnjWYeHmh0RERPJhYPHDs9ItR1iIiIjkw8DiB2tYiIjaAEkC7BZ5HgGuw/X3v/8diYmJcJ8zIn/LLbfgN7/5DY4cOYJbbrkFJpMJYWFhGDFiBL7++utm+xX9+OOPuP766xESEoKYmBjcf//9qKmp8T6/fv16jBw5EqGhoYiMjMTo0aNx/PhxAMAPP/yA6667DuHh4YiIiMCwYcOwffv2Zmtbc+BKt3541mFxM7AQEcnHYQVeSJTnvf94CtCG+j3tzjvvxCOPPIJ169ZhzJgxAICKigqsWbMGn3/+OWpqanDTTTfh+eefh06nw4oVKzB+/HgcPHgQ3bt3v6QmWiwWZGVlYdSoUdi2bRtKS0tx7733YubMmXjrrbfgdDoxYcIE3HfffXj//fdht9uxdetWKOrvIkyZMgVDhw7FsmXLoFKpsGvXLmg0mktqU3NjYPGDIyxERBSIqKgo3HjjjXjvvfe8geXf//43YmNjcd1110GpVCItLc17/rx587B69Wp88sknmDlz5iW993vvvYe6ujqsWLECoaEiXC1ZsgTjx4/HggULoNFoUFVVhV/84hfo1asXAGDAgAHe1xcUFOCJJ55A//79AQB9+vS5pPa0BAYWP1Qq1rAQEclOYxAjHXK9d4CmTJmC++67D6+99hp0Oh3effddTJo0CUqlEjU1NXjmmWfw2WefoaioCE6nE7W1tSgoKLjkJu7fvx9paWnesAIAo0ePhtvtxsGDB3H11Vdj+vTpyMrKwg033IDMzEzcddddSEhIAABkZ2fj3nvvxTvvvIPMzEzceeed3mDTVrCGxQ+OsBARtQEKhbgtI8ej/rZJIMaPHw9JkvDZZ5+hsLAQGzduxJQpUwAAjz/+OFavXo0XXngBGzduxK5du5Camgq73d5SvzUfb775JvLy8nDFFVdg1apV6Nu3LzZv3gwAeOaZZ7B3717cfPPN+OabbzBw4ECsXr26VdoVKAYWPxrWYeG0ZiIiuji9Xo/bbrsN7777Lt5//33069cPl112GQBg06ZNmD59Om699VakpqYiPj4ex44da5b3HTBgAH744QdYLBbvsU2bNkGpVKJfv37eY0OHDsW
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tr_ind = list(range(len(train_losses)))\n",
"val_ind = list(range(len(val_losses)))\n",
"plt.plot(train_losses[1:], label='train_loss')\n",
"plt.plot(val_losses[1:], label='val_loss')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "a8467686-0655-4056-8e01-56299eb89d7c",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "efc96935-bbe0-4ca9-b11a-931cdcfc3bed",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = 2 * np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "dae7427e-548e-4276-a4ea-bc9b279d44e8",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "589e6d80-228d-4e8a-968a-e7477c5e0e24",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>76.000000</td>\n",
" <td>76.000000</td>\n",
" <td>76.000000</td>\n",
" <td>76.000000</td>\n",
" <td>76.000000</td>\n",
" <td>76.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.839756</td>\n",
" <td>3.116629</td>\n",
" <td>0.182020</td>\n",
" <td>0.887434</td>\n",
" <td>0.984924</td>\n",
" <td>0.942480</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.141333</td>\n",
" <td>0.218296</td>\n",
" <td>0.009663</td>\n",
" <td>0.010754</td>\n",
" <td>0.001488</td>\n",
" <td>0.005507</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>1.518803</td>\n",
" <td>2.497997</td>\n",
" <td>0.160656</td>\n",
" <td>0.862124</td>\n",
" <td>0.981807</td>\n",
" <td>0.930738</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>1.725160</td>\n",
" <td>2.954316</td>\n",
" <td>0.175752</td>\n",
" <td>0.880888</td>\n",
" <td>0.984067</td>\n",
" <td>0.939023</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.819375</td>\n",
" <td>3.084015</td>\n",
" <td>0.180615</td>\n",
" <td>0.887742</td>\n",
" <td>0.984906</td>\n",
" <td>0.942333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.973080</td>\n",
" <td>3.293117</td>\n",
" <td>0.189375</td>\n",
" <td>0.895522</td>\n",
" <td>0.986135</td>\n",
" <td>0.946845</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>2.156417</td>\n",
" <td>3.538444</td>\n",
" <td>0.205928</td>\n",
" <td>0.909979</td>\n",
" <td>0.988135</td>\n",
" <td>0.954369</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa r\n",
"count 76.000000 76.000000 76.000000 76.000000 76.000000 76.000000\n",
"mean 1.839756 3.116629 0.182020 0.887434 0.984924 0.942480\n",
"std 0.141333 0.218296 0.009663 0.010754 0.001488 0.005507\n",
"min 1.518803 2.497997 0.160656 0.862124 0.981807 0.930738\n",
"25% 1.725160 2.954316 0.175752 0.880888 0.984067 0.939023\n",
"50% 1.819375 3.084015 0.180615 0.887742 0.984906 0.942333\n",
"75% 1.973080 3.293117 0.189375 0.895522 0.986135 0.946845\n",
"max 2.156417 3.538444 0.205928 0.909979 0.988135 0.954369"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "6278442c-3ecb-4e92-b901-0f0e0e43d8af",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" for i, sample in enumerate(rev_data):\n",
" used_mask = mask_rev[i]\n",
" data_label = sample[0] * used_mask\n",
" recon_no2 = rev_recon[i][0] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "3d095141-79e2-4f4f-b31f-54fed1996781",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>4839.000000</td>\n",
" <td>4839.000000</td>\n",
" <td>4839.000000</td>\n",
" <td>4839.000000</td>\n",
" <td>4839.000000</td>\n",
" <td>4839.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.833400</td>\n",
" <td>2.618025</td>\n",
" <td>0.181476</td>\n",
" <td>0.631467</td>\n",
" <td>0.937835</td>\n",
" <td>0.813992</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>1.185956</td>\n",
" <td>1.683402</td>\n",
" <td>0.071764</td>\n",
" <td>0.260356</td>\n",
" <td>0.053726</td>\n",
" <td>0.123230</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.248986</td>\n",
" <td>0.337919</td>\n",
" <td>0.075559</td>\n",
" <td>-3.769637</td>\n",
" <td>0.103451</td>\n",
" <td>0.020267</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.843516</td>\n",
" <td>1.179310</td>\n",
" <td>0.138988</td>\n",
" <td>0.537750</td>\n",
" <td>0.924173</td>\n",
" <td>0.762655</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.335556</td>\n",
" <td>1.939605</td>\n",
" <td>0.165166</td>\n",
" <td>0.682016</td>\n",
" <td>0.951186</td>\n",
" <td>0.841513</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>2.759837</td>\n",
" <td>3.977746</td>\n",
" <td>0.202966</td>\n",
" <td>0.792008</td>\n",
" <td>0.970115</td>\n",
" <td>0.898809</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>9.474609</td>\n",
" <td>10.988250</td>\n",
" <td>1.344091</td>\n",
" <td>0.978264</td>\n",
" <td>0.997251</td>\n",
" <td>0.989095</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4839.000000 4839.000000 4839.000000 4839.000000 4839.000000 \n",
"mean 1.833400 2.618025 0.181476 0.631467 0.937835 \n",
"std 1.185956 1.683402 0.071764 0.260356 0.053726 \n",
"min 0.248986 0.337919 0.075559 -3.769637 0.103451 \n",
"25% 0.843516 1.179310 0.138988 0.537750 0.924173 \n",
"50% 1.335556 1.939605 0.165166 0.682016 0.951186 \n",
"75% 2.759837 3.977746 0.202966 0.792008 0.970115 \n",
"max 9.474609 10.988250 1.344091 0.978264 0.997251 \n",
"\n",
" r \n",
"count 4839.000000 \n",
"mean 0.813992 \n",
"std 0.123230 \n",
"min 0.020267 \n",
"25% 0.762655 \n",
"50% 0.841513 \n",
"75% 0.898809 \n",
"max 0.989095 "
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eb4d33a-8d03-418d-bb50-f34eef4e4bf5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}