896 lines
53 KiB
Plaintext
896 lines
53 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import os\n",
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"import torch.optim as optim\n",
|
||
"from torch.utils.data import DataLoader, Dataset, random_split\n",
|
||
"from PIL import Image\n",
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import cv2"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "e0afbbc4-cd35-49f7-986f-2c0a6fff5ec1",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<torch._C.Generator at 0x7f6d6be638f0>"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
"# Seed the NumPy and PyTorch global RNGs so runs are reproducible\n",
"np.random.seed(0)\n",
"torch.manual_seed(0)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "95baeec7-508b-480c-b598-aecab7497a99",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Maximum pixel value in the dataset: 107.49169921875\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# Normalisation constant: the largest value in one channel across a directory of .npy arrays\n",
"def find_max_pixel_value(image_dir, channel=0):\n",
"    \"\"\"Return the maximum of ``image[:, :, channel]`` over every ``.npy`` file in ``image_dir``.\n",
"\n",
"    Args:\n",
"        image_dir: Directory to scan; files not ending in ``.npy`` are ignored.\n",
"        channel: Index on the last axis to inspect (default 0, as before).\n",
"\n",
"    Returns:\n",
"        The maximum as a Python float.\n",
"\n",
"    Raises:\n",
"        ValueError: If the directory holds no ``.npy`` files; previously this\n",
"            silently returned 0.0, which the dataset then divides every image by.\n",
"    \"\"\"\n",
"    max_value = None\n",
"    for filename in os.listdir(image_dir):\n",
"        if not filename.endswith('.npy'):\n",
"            continue\n",
"        image = np.load(os.path.join(image_dir, filename)).astype(np.float32)\n",
"        file_max = float(image[:, :, channel].max())\n",
"        # Seed from the first file instead of 0.0 so all-negative data is handled correctly\n",
"        max_value = file_max if max_value is None else max(max_value, file_max)\n",
"    if max_value is None:\n",
"        raise ValueError(f\"no .npy files found in {image_dir!r}\")\n",
"    return max_value\n",
|
"\n",
"# Compute the normalisation constant from the training split\n",
"image_dir = './out_mat/96/train/'\n",
"max_pixel_value = find_max_pixel_value(image_dir)\n",
"print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"id": "9a8fe22d-5029-427f-bae8-01934a0d5c35",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"checkpoint before Generator is OK\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"class NO2Dataset(Dataset):\n",
"    \"\"\"Pairs each NO2 field (.npy) with a randomly drawn binary mask (.jpg).\n",
"\n",
"    ``__getitem__`` returns ``(X, y, mask)``, each a float32 tensor of shape (1, H, W):\n",
"    ``X`` is channel 0 of the array divided by the normaliser and multiplied by the\n",
"    mask, ``y`` is the same channel without masking, and ``mask`` is 1.0 where the\n",
"    mask image is above ``mask_threshold`` and 0.0 elsewhere. A new mask is drawn\n",
"    with ``np.random.choice`` on every access.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, image_dir, mask_dir, max_pixel_value=None, mask_threshold=0):\n",
"        \"\"\"\n",
"        Args:\n",
"            image_dir: Directory of ``.npy`` arrays indexed as (H, W, C); only channel 0 is used.\n",
"            mask_dir: Directory of ``.jpg`` mask images (presumably the same H x W -- not checked).\n",
"            max_pixel_value: Divisor used to normalise the arrays. When None (the default),\n",
"                the notebook-global ``max_pixel_value`` is read at load time, as before.\n",
"            mask_threshold: Grey levels strictly above this become 1.0. The default 0 is\n",
"                the original ``mask != 0``; JPEG artefacts can leave small non-zero values\n",
"                near mask edges, so a higher value (e.g. 127) may be preferable.\n",
"\n",
"        Raises:\n",
"            ValueError: If ``mask_dir`` contains no ``.jpg`` files.\n",
"        \"\"\"\n",
"        self.image_dir = image_dir\n",
"        self.mask_dir = mask_dir\n",
"        self.max_pixel_value = max_pixel_value\n",
"        self.mask_threshold = mask_threshold\n",
"        # os.listdir order is arbitrary; sort for a stable index -> file mapping\n",
"        self.image_filenames = sorted(f for f in os.listdir(image_dir) if f.endswith('.npy'))\n",
"        self.mask_filenames = sorted(f for f in os.listdir(mask_dir) if f.endswith('.jpg'))\n",
"        if not self.mask_filenames:\n",
"            # Fail here rather than inside a DataLoader worker on first access\n",
"            raise ValueError(f\"no .jpg mask files found in {mask_dir!r}\")\n",
"\n",
"    def __len__(self):\n",
"        return len(self.image_filenames)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
"        mask_name = np.random.choice(self.mask_filenames)\n",
"        mask_path = os.path.join(self.mask_dir, mask_name)\n",
"\n",
"        # Explicit normaliser if given, otherwise the global computed earlier in the notebook\n",
"        scale = self.max_pixel_value if self.max_pixel_value is not None else max_pixel_value\n",
"\n",
"        # Channel 0 only, kept as (H, W, 1)\n",
"        image = np.load(image_path).astype(np.float32)[:, :, :1] / scale\n",
"\n",
"        # Greyscale mask -> binary float32 array of shape (H, W, 1)\n",
"        with Image.open(mask_path) as mask_image:\n",
"            mask_gray = np.array(mask_image.convert('L'), dtype=np.float32)\n",
"        mask = (mask_gray > self.mask_threshold).astype(np.float32)[:, :, np.newaxis]\n",
"\n",
"        # Model input zeroes pixels where mask == 0; the target is the full field\n",
"        X = image * mask\n",
"        y = image\n",
"\n",
"        # (H, W, C) -> (C, H, W)\n",
"        X = np.transpose(X, (2, 0, 1))\n",
"        y = np.transpose(y, (2, 0, 1))\n",
"        mask = np.transpose(mask, (2, 0, 1))\n",
"\n",
"        return (\n",
"            torch.tensor(X, dtype=torch.float32),\n",
"            torch.tensor(y, dtype=torch.float32),\n",
"            torch.tensor(mask, dtype=torch.float32),\n",
"        )\n",
|
"\n",
"# Directories used to build the datasets in the next cell\n",
"image_dir = './out_mat/96/train/'\n",
"mask_dir = './out_mat/96/mask/20/'\n",
"\n",
"print(\"checkpoint before Generator is OK\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"id": "ddbc13ba-a0e8-477e-895e-371a78085bac",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Training, validation and test datasets all draw masks from the same directory\n",
"dataset = NO2Dataset(image_dir, mask_dir)\n",
"val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n",
"test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"\n",
"# Only training batches are shuffled\n",
"dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
"val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"id": "70797703-1619-4be7-b965-5506b3d1e775",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Show ground truth, masked input and reconstruction side by side\n",
"def visualize_feature(input_feature, masked_feature, output_feature, title):\n",
"    \"\"\"Plot element ``[0]`` of three tensors as a 1x3 row of images.\n",
"\n",
"    Args:\n",
"        input_feature: Ground-truth tensor, shown as \"<title> Input\".\n",
"        masked_feature: Masked model input, shown as \"<title> Masked\".\n",
"        output_feature: Model output, shown as \"<title> Recovery\"; it is detached\n",
"            first so tensors that require grad can be plotted.\n",
"        title: Prefix for every panel title.\n",
"    \"\"\"\n",
"    plt.figure(figsize=(12, 6))\n",
"    panels = [\n",
"        (input_feature[0], \" Input\"),\n",
"        (masked_feature[0], \" Masked\"),\n",
"        (output_feature[0].detach(), \" Recovery\"),\n",
"    ]\n",
"    for position, (image, suffix) in enumerate(panels, start=1):\n",
"        plt.subplot(1, 3, position)\n",
"        plt.imshow(image.cpu().numpy(), cmap='RdYlGn_r')\n",
"        plt.title(title + suffix)\n",
"    plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"id": "aeda3567-4c4d-496b-9570-9ae757b45e72",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"cuda\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"# Re-seed NumPy and PyTorch for reproducibility\n",
"torch.manual_seed(0)\n",
"np.random.seed(0)\n",
"\n",
"# Prefer the GPU when CUDA is available\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"print(device)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"id": "f821d0a1-dfee-483e-b081-68c963bdb8a0",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Convolutional masked autoencoder for single-channel fields\n",
"class MaskedAutoencoderBase(nn.Module):\n",
"    \"\"\"Three stride-2 convolutions down, three stride-2 transposed convolutions up.\n",
"\n",
"    For an (N, 1, H, W) input with H and W divisible by 8 (e.g. 96 -> 48 -> 24 -> 12)\n",
"    the output has the same shape. The final Sigmoid keeps outputs in (0, 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Layer order (and hence Sequential indices / state_dict keys) matches the original\n",
"        self.encoder = nn.Sequential(\n",
"            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),    # H -> H/2\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),   # -> H/4\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> H/8\n",
"            nn.ReLU(),\n",
"        )\n",
"        self.decoder = nn.Sequential(\n",
"            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> H/4\n",
"            nn.ReLU(),\n",
"            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),   # -> H/2\n",
"            nn.ReLU(),\n",
"            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),    # -> H\n",
"            nn.Sigmoid(),  # inputs are scaled by the training-set maximum, i.e. roughly [0, 1]\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return self.decoder(self.encoder(x))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"id": "2dc47416-511e-4874-abaf-30dd912a0e7d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def masked_mse_loss(preds, target, mask):\n",
"    \"\"\"Mean squared error averaged over the pixels selected by ``mask``.\n",
"\n",
"    Args:\n",
"        preds: Model output, e.g. (N, 1, H, W).\n",
"        target: Ground truth with the same shape as ``preds``.\n",
"        mask: Weights broadcastable to ``preds``; in this notebook 1.0 where the\n",
"            mask image is non-zero and 0.0 elsewhere.\n",
"\n",
"    Returns:\n",
"        Scalar tensor sum(mask * (preds - target)**2) / sum(mask), or 0 when the\n",
"        mask selects nothing (instead of 0/0 = NaN).\n",
"\n",
"    NOTE(review): NO2Dataset multiplies the input by this mask, so mask == 1 marks\n",
"    pixels the model can SEE. If the goal is to inpaint the hidden pixels, the loss\n",
"    should probably be weighted by ``1 - mask`` instead -- confirm intent.\n",
"    \"\"\"\n",
"    squared_error = (preds - target) ** 2\n",
"    # Weight every element directly. The old `.mean(dim=-1)` produced an (N, 1, H)\n",
"    # tensor that broadcast against the (N, 1, H, W) mask into (N, N, H, W), mixing\n",
"    # samples across the batch (and only broadcasting at all because H == W).\n",
"    weights = mask.expand_as(squared_error)\n",
"    total_weight = weights.sum().clamp_min(torch.finfo(squared_error.dtype).eps)\n",
"    return (squared_error * weights).sum() / total_weight"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"id": "2e77e837-071c-46d0-9779-80bb333db800",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Model, optimiser, and a plain MSE criterion that train_epoch/evaluate accept but do not use\n",
"model = MaskedAutoencoderBase()\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(params=model.parameters(), lr=1e-4)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Training loop for one epoch\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
"    \"\"\"Run one optimisation pass and return the mean of the per-batch masked MSE.\n",
"\n",
"    Args:\n",
"        model: Network mapping the masked input X to a reconstruction.\n",
"        device: Device each batch is moved to.\n",
"        data_loader: Iterable of (X, y, mask) batches.\n",
"        criterion: Unused -- kept so existing call sites still work; the loss is\n",
"            always masked_mse_loss.\n",
"        optimizer: Optimiser over ``model``'s parameters.\n",
"\n",
"    Returns:\n",
"        Sum of batch losses divided by the number of batches.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``data_loader`` yields no batches (previously an\n",
"            UnboundLocalError on ``batch_idx``).\n",
"    \"\"\"\n",
"    model.train()\n",
"    running_loss = 0.0\n",
"    num_batches = 0\n",
"    for X, y, mask in data_loader:\n",
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"        optimizer.zero_grad()\n",
"        reconstructed = model(X)\n",
"        loss = masked_mse_loss(reconstructed, y, mask)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item()\n",
"        num_batches += 1\n",
"    if num_batches == 0:\n",
"        raise ValueError(\"data_loader yielded no batches\")\n",
"    return running_loss / num_batches"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Evaluation: mean masked MSE over a loader, without gradient tracking\n",
"def evaluate(model, device, data_loader, criterion, visualize_batch=None):\n",
"    \"\"\"Return the mean of the per-batch masked MSE over ``data_loader``.\n",
"\n",
"    Args:\n",
"        model: Network mapping the masked input X to a reconstruction.\n",
"        device: Device each batch is moved to.\n",
"        data_loader: Iterable of (X, y, mask) batches.\n",
"        criterion: Unused -- kept so existing call sites still work; the loss is\n",
"            always masked_mse_loss.\n",
"        visualize_batch: Optional batch index. When set, one random sample from that\n",
"            batch is plotted with visualize_feature. The default None plots nothing,\n",
"            matching the previously commented-out debug plot.\n",
"\n",
"    Returns:\n",
"        Sum of batch losses divided by the number of batches.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``data_loader`` yields no batches (previously an\n",
"            UnboundLocalError on ``batch_idx``).\n",
"    \"\"\"\n",
"    model.eval()\n",
"    running_loss = 0.0\n",
"    num_batches = 0\n",
"    with torch.no_grad():\n",
"        for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
"            X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
"            reconstructed = model(X)\n",
"            if visualize_batch is not None and batch_idx == visualize_batch:\n",
"                sample = np.random.randint(0, len(y))\n",
"                visualize_feature(y[sample], X[sample], reconstructed[sample], title='NO_2')\n",
"            running_loss += masked_mse_loss(reconstructed, y, mask).item()\n",
"            num_batches += 1\n",
"    if num_batches == 0:\n",
"        raise ValueError(\"data_loader yielded no batches\")\n",
"    return running_loss / num_batches"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"id": "743d1000-561e-4444-8b49-88346c14f28b",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 1, Train Loss: 2.4377448216936233, Val Loss: 0.1723788405087457\n",
|
||
"Epoch 2, Train Loss: 0.09637997932197374, Val Loss: 0.07621741728551353\n",
|
||
"Epoch 3, Train Loss: 0.06397618102934657, Val Loss: 0.06195200451496822\n",
|
||
"Epoch 4, Train Loss: 0.052692288621974906, Val Loss: 0.052603201690449644\n",
|
||
"Epoch 5, Train Loss: 0.045533701719529036, Val Loss: 0.0462518873721806\n",
|
||
"Epoch 6, Train Loss: 0.040426999678095564, Val Loss: 0.04118765834996949\n",
|
||
"Epoch 7, Train Loss: 0.03643315702979787, Val Loss: 0.0370612932619319\n",
|
||
"Epoch 8, Train Loss: 0.03297993362074691, Val Loss: 0.0338741072189452\n",
|
||
"Epoch 9, Train Loss: 0.030229569176595177, Val Loss: 0.03180063916231269\n",
|
||
"Epoch 10, Train Loss: 0.028299911767600827, Val Loss: 0.03058352780097456\n",
|
||
"Epoch 11, Train Loss: 0.026935207724357337, Val Loss: 0.029766072282817826\n",
|
||
"Epoch 12, Train Loss: 0.026076676769618782, Val Loss: 0.028107319638800265\n",
|
||
"Epoch 13, Train Loss: 0.02534967841821139, Val Loss: 0.027272115475428637\n",
|
||
"Epoch 14, Train Loss: 0.024701394381349166, Val Loss: 0.02684043228292643\n",
|
||
"Epoch 15, Train Loss: 0.0240272392550011, Val Loss: 0.02594853615138068\n",
|
||
"Epoch 16, Train Loss: 0.0233813104438083, Val Loss: 0.025640942656726978\n",
|
||
"Epoch 17, Train Loss: 0.02310016915273438, Val Loss: 0.02571806650775582\n",
|
||
"Epoch 18, Train Loss: 0.022718923658792054, Val Loss: 0.024644668200122778\n",
|
||
"Epoch 19, Train Loss: 0.022323213453052576, Val Loss: 0.024273945435659208\n",
|
||
"Epoch 20, Train Loss: 0.02199719715685223, Val Loss: 0.02410240029332353\n",
|
||
"Epoch 21, Train Loss: 0.021530815467024535, Val Loss: 0.02380427871066243\n",
|
||
"Epoch 22, Train Loss: 0.021460241776262743, Val Loss: 0.0232450627346537\n",
|
||
"Epoch 23, Train Loss: 0.02090326771050977, Val Loss: 0.022885078564286232\n",
|
||
"Epoch 24, Train Loss: 0.020652044475363774, Val Loss: 0.022562191390724323\n",
|
||
"Epoch 25, Train Loss: 0.02051923798985387, Val Loss: 0.022203324724044373\n",
|
||
"Epoch 26, Train Loss: 0.020149177833767743, Val Loss: 0.022744494337421744\n",
|
||
"Epoch 27, Train Loss: 0.020068248640300268, Val Loss: 0.022425833088693333\n",
|
||
"Epoch 28, Train Loss: 0.019720358143529397, Val Loss: 0.02253118777341807\n",
|
||
"Epoch 29, Train Loss: 0.01939903690288084, Val Loss: 0.021765351378873213\n",
|
||
"Epoch 30, Train Loss: 0.01943497322989922, Val Loss: 0.021345259649540062\n",
|
||
"Epoch 31, Train Loss: 0.019241397384928458, Val Loss: 0.02124041018646155\n",
|
||
"Epoch 32, Train Loss: 0.01906546402464144, Val Loss: 0.021633521083797982\n",
|
||
"Epoch 33, Train Loss: 0.01884070100512302, Val Loss: 0.021043253979131357\n",
|
||
"Epoch 34, Train Loss: 0.01874133140855785, Val Loss: 0.02059999839472237\n",
|
||
"Epoch 35, Train Loss: 0.01853996916544851, Val Loss: 0.021178998303279947\n",
|
||
"Epoch 36, Train Loss: 0.018260161060412106, Val Loss: 0.020367807639178944\n",
|
||
"Epoch 37, Train Loss: 0.01830708983233956, Val Loss: 0.020017842692670536\n",
|
||
"Epoch 38, Train Loss: 0.018042967790675362, Val Loss: 0.020187884722071798\n",
|
||
"Epoch 39, Train Loss: 0.017922898732197056, Val Loss: 0.019615614786744118\n",
|
||
"Epoch 40, Train Loss: 0.017794321282236486, Val Loss: 0.019430582606191956\n",
|
||
"Epoch 41, Train Loss: 0.017688655022656517, Val Loss: 0.019477688401603875\n",
|
||
"Epoch 42, Train Loss: 0.017460078103512383, Val Loss: 0.018902005530448993\n",
|
||
"Epoch 43, Train Loss: 0.01727662416638441, Val Loss: 0.018832763184362382\n",
|
||
"Epoch 44, Train Loss: 0.017280888195599666, Val Loss: 0.019056980081124983\n",
|
||
"Epoch 45, Train Loss: 0.017114856775012312, Val Loss: 0.018604515495696174\n",
|
||
"Epoch 46, Train Loss: 0.016909640970858234, Val Loss: 0.018437264904157438\n",
|
||
"Epoch 47, Train Loss: 0.016691252999185946, Val Loss: 0.01889144025965413\n",
|
||
"Epoch 48, Train Loss: 0.016869753608079047, Val Loss: 0.018732781104965887\n",
|
||
"Epoch 49, Train Loss: 0.01653263871179243, Val Loss: 0.01850963812043418\n",
|
||
"Epoch 50, Train Loss: 0.01653244017520875, Val Loss: 0.0178856217344083\n",
|
||
"Epoch 51, Train Loss: 0.016499577624874823, Val Loss: 0.01781756227919415\n",
|
||
"Epoch 52, Train Loss: 0.016335643743249504, Val Loss: 0.01821571894323648\n",
|
||
"Epoch 53, Train Loss: 0.016375035212406415, Val Loss: 0.017511379168327176\n",
|
||
"Epoch 54, Train Loss: 0.016288986672428948, Val Loss: 0.017456448650849398\n",
|
||
"Epoch 55, Train Loss: 0.01623404509517137, Val Loss: 0.017827068525018978\n",
|
||
"Epoch 56, Train Loss: 0.016188283936615196, Val Loss: 0.017475027326883663\n",
|
||
"Epoch 57, Train Loss: 0.01605349867359588, Val Loss: 0.017256822728955033\n",
|
||
"Epoch 58, Train Loss: 0.015958637990610022, Val Loss: 0.017457437256712522\n",
|
||
"Epoch 59, Train Loss: 0.016034694237001774, Val Loss: 0.017437012713235705\n",
|
||
"Epoch 60, Train Loss: 0.0158486066956483, Val Loss: 0.017560158175096582\n",
|
||
"Epoch 61, Train Loss: 0.015632042563275286, Val Loss: 0.01692103194211846\n",
|
||
"Epoch 62, Train Loss: 0.015540152108608677, Val Loss: 0.01698271286632143\n",
|
||
"Epoch 63, Train Loss: 0.01545496231043025, Val Loss: 0.01699626362368242\n",
|
||
"Epoch 64, Train Loss: 0.015430795162488398, Val Loss: 0.01687317063559347\n",
|
||
"Epoch 65, Train Loss: 0.015489797350732191, Val Loss: 0.017046043955123248\n",
|
||
"Epoch 66, Train Loss: 0.015236956011682179, Val Loss: 0.0172197060214717\n",
|
||
"Epoch 67, Train Loss: 0.015348140916755895, Val Loss: 0.016508253249548265\n",
|
||
"Epoch 68, Train Loss: 0.015228347097519055, Val Loss: 0.016413471842212462\n",
|
||
"Epoch 69, Train Loss: 0.01516882229025997, Val Loss: 0.01686259738600521\n",
|
||
"Epoch 70, Train Loss: 0.015173258315593574, Val Loss: 0.01757873013726811\n",
|
||
"Epoch 71, Train Loss: 0.015156847716678986, Val Loss: 0.016662339123883353\n",
|
||
"Epoch 72, Train Loss: 0.015105586064507088, Val Loss: 0.016890839868183457\n",
|
||
"Epoch 73, Train Loss: 0.014925161955887051, Val Loss: 0.015931842709655194\n",
|
||
"Epoch 74, Train Loss: 0.014886363126497947, Val Loss: 0.016006485308840204\n",
|
||
"Epoch 75, Train Loss: 0.015015289608531735, Val Loss: 0.015968994154080526\n",
|
||
"Epoch 76, Train Loss: 0.014806462892968403, Val Loss: 0.015919692327838336\n",
|
||
"Epoch 77, Train Loss: 0.014728168116962653, Val Loss: 0.015852669684855797\n",
|
||
"Epoch 78, Train Loss: 0.014845167781319914, Val Loss: 0.016079049404543726\n",
|
||
"Epoch 79, Train Loss: 0.014719554133998435, Val Loss: 0.015957326447563387\n",
|
||
"Epoch 80, Train Loss: 0.014635249268281404, Val Loss: 0.015849308388780303\n",
|
||
"Epoch 81, Train Loss: 0.014474964379800849, Val Loss: 0.015526832887597049\n",
|
||
"Epoch 82, Train Loss: 0.014369143295641007, Val Loss: 0.015485089967277512\n",
|
||
"Epoch 83, Train Loss: 0.014446225396076743, Val Loss: 0.015848276135859204\n",
|
||
"Epoch 84, Train Loss: 0.014476079110537419, Val Loss: 0.015343323600158762\n",
|
||
"Epoch 85, Train Loss: 0.014672522836378174, Val Loss: 0.015515949938501885\n",
|
||
"Epoch 86, Train Loss: 0.014440825409545568, Val Loss: 0.015224411166203556\n",
|
||
"Epoch 87, Train Loss: 0.014462759978980111, Val Loss: 0.015663697370397512\n",
|
||
"Epoch 88, Train Loss: 0.01440465696262971, Val Loss: 0.015856551353944773\n",
|
||
"Epoch 89, Train Loss: 0.014255739579146559, Val Loss: 0.015246662380757616\n",
|
||
"Epoch 90, Train Loss: 0.014205876624202756, Val Loss: 0.015011716536732752\n",
|
||
"Epoch 91, Train Loss: 0.014259663818216924, Val Loss: 0.015085076996639593\n",
|
||
"Epoch 92, Train Loss: 0.014251617286978156, Val Loss: 0.015133185506756627\n",
|
||
"Epoch 93, Train Loss: 0.014119144052302723, Val Loss: 0.015415464166496227\n",
|
||
"Epoch 94, Train Loss: 0.014192042264053554, Val Loss: 0.015254960033986995\n",
|
||
"Epoch 95, Train Loss: 0.014140318196855094, Val Loss: 0.017451592276234235\n",
|
||
"Epoch 96, Train Loss: 0.014092271857890502, Val Loss: 0.015359595265072672\n",
|
||
"Epoch 97, Train Loss: 0.01409529693843574, Val Loss: 0.015055305060388437\n",
|
||
"Epoch 98, Train Loss: 0.014136464546688578, Val Loss: 0.015083992547953307\n",
|
||
"Epoch 99, Train Loss: 0.013914715411792103, Val Loss: 0.014718598477653604\n",
|
||
"Epoch 100, Train Loss: 0.013870610982518305, Val Loss: 0.01483591334588492\n",
|
||
"Test Loss: 0.010182651478874807\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"model = model.to(device)\n",
"\n",
"# Train, recording the mean train and validation loss of every epoch\n",
"num_epochs = 100\n",
"train_losses = []\n",
"val_losses = []\n",
"for epoch in range(num_epochs):\n",
"    train_loss = train_epoch(model, device, dataloader, criterion, optimizer)\n",
"    train_losses.append(train_loss)\n",
"    val_loss = evaluate(model, device, val_loader, criterion)\n",
"    val_losses.append(val_loss)\n",
"    print(f'Epoch {epoch + 1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# Final score on the held-out test split\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<matplotlib.legend.Legend at 0x7f6b7999d190>"
|
||
]
|
||
},
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA3U0lEQVR4nO3deXxU9b3/8fc5M8mEQBJAzIKGxQrIJjsYuD+XGkVUCtpaS7kFtNprC61KbSttpS4PjdWKtFblen0o11bEDdCLKwURkIiAxooLbkiokoBSEhKWJHO+vz/mzGSiBDIkOYcwr+fjcR4hM2f5zjch857POef7tYwxRgAAAD6x/W4AAABIboQRAADgK8IIAADwFWEEAAD4ijACAAB8RRgBAAC+IowAAABfEUYAAICvgn43oCkcx9EXX3yhjIwMWZbld3MAAEATGGO0Z88ede3aVbbdeP2jTYSRL774Qvn5+X43AwAAHIFt27bpxBNPbPT5NhFGMjIyJEVeTGZmps+tAQAATVFZWan8/PzY+3hj2kQYiZ6ayczMJIwAANDGHO4SCy5gBQAAvkoojBQVFWnEiBHKyMhQdna2Jk6cqM2bNx9ym/nz58uyrAZLWlpasxoNAACOHQmFkVdffVXTp0/X66+/rmXLlqm2tlbnnnuuqqurD7ldZmamtm/fHlu2bt3arEYDAIBjR0LXjLz44osNvp8/f76ys7O1ceNGnX766Y1uZ1mWcnNzj6yFAIBjUjgcVm1trd/NQDMEAgEFg8FmD7vRrAtYKyoqJEmdO3c+5HpVVVXq3r27HMfR0KFDddttt6l///6Nrn/gwAEdOHAg9n1lZWVzmgkAOMpUVVXpX//6l4wxfjcFzZSenq68vDylpqYe8T4sc4S/CY7j6Dvf+Y52796tNWvWNLpecXGxPvroI5166qmqqKjQn/70J61atUrvvvtuo/cc33jjjbrpppu+8XhFRQV30wBAGxcOh/XRRx8pPT1dxx9/PINZtlHGGNXU1Gjnzp0Kh8Pq1avXNwY2q6ysVFZW1mHfv484jPz0pz/VCy+8oDVr1hxyIJOvq62tVd++fTVp0iTdcsstB13nYJWR/Px8wggAHAP279+vLVu2qEePHmrXrp3fzUEz7d27V1u3blXPnj2/cYNKU8PIEZ2mmTFjhpYuXapVq1YlFEQkKSUlRUOGDNHHH3/c6DqhUEihUOhImgYAaCOoiBwbDjXMe5P3kcjKxhjNmDFDixcv1ooVK9SzZ8+EDxgOh/XOO+8oLy8v4W0BAMCxJ6HKyPTp07VgwQI988wzysjIUFlZmSQpKysrVmqbMmWKTjjhBBUVFUmSbr75Zp122mk6+eSTtXv3bt15553aunWrrrjiihZ+KQAAoC1KqDJy//33q6KiQmeeeaby8vJiy+OPPx5bp7S0VNu3b499/+9//1tXXnml+vbtq/PPP1+VlZVau3at+vXr13KvAgCANqRHjx6aO3dui+xr5cqVsixLu3fvbpH9+SGhykhTrnVduXJlg+/vvvtu3X333Qk1CgCAo82ZZ56pwYMHt0iIWL9+vdq3b9/8Rh0j2sREea3lwdWf6l//3qdJI7upT+6hZxQEAOBQjDEKh8MKBg//1nr88cd70KK2I6knynvune2av/Yzbf3q0MPZAwBajzFGe2vqfFmaOrrFtGnT9Oqrr+rPf/5zbJ616NxrL7zwgoYNG6ZQKKQ1a9bok08+0YQJE5STk6MOHTpoxIgR+sc//tFgf18/TWNZlh588EFddNFFSk9PV69evfTss88ecZ8+/fTT6t+/v0KhkHr06KG77rqrwfP33XefevXqpbS0NOXk5Oh73/te7LmnnnpKAwcOVLt27XTcccepsLDwsNO+NFdSV0aCduS2MocRAAHAN/tqw+o3+yVfjv3ezWOVnnr4t8I///nP+vDDDzVgwADdfPPNkqR3331XknT99dfrT3/6k0466SR16tRJ27Zt0/
nnn69bb71VoVBIjzzyiMaPH6/NmzerW7dujR7jpptu0h133KE777xT99xzjyZPnqytW7cedpTzr9u4caO+//3v68Ybb9Sll16qtWvX6mc/+5mOO+44TZs2TRs2bNAvfvEL/e1vf9Po0aO1a9curV69WpK0fft2TZo0SXfccYcuuugi7dmzR6tXr271kXKTOozY7j3udQ5hBADQuKysLKWmpio9PT0219oHH3wgKXLX6DnnnBNbt3Pnzho0aFDs+1tuuUWLFy/Ws88+qxkzZjR6jGnTpmnSpEmSpNtuu01/+ctf9MYbb+i8885LqK1z5szR2WefrRtuuEGS1Lt3b7333nu68847NW3aNJWWlqp9+/a68MILlZGRoe7du2vIkCGSImGkrq5OF198sbp37y5JGjhwYELHPxJJHUaCgUgYCRNGAMA37VICeu/msb4du7mGDx/e4PuqqirdeOONeu6552Jv7vv27VNpaekh93PqqafG/t2+fXtlZmZqx44dCbfn/fff14QJExo8NmbMGM2dO1fhcFjnnHOOunfvrpNOOknnnXeezjvvvNjpoUGDBunss8/WwIEDNXbsWJ177rn63ve+p06dOiXcjkQk9TUj0coIYQQA/GNZltJTg74sLTEK7Nfvirnuuuu0ePFi3XbbbVq9erVKSko0cOBA1dTUHHI/KSkp3+gXx3Ga3b6vy8jI0JtvvqnHHntMeXl5mj17tgYNGqTdu3crEAho2bJleuGFF9SvXz/dc8896tOnj7Zs2dLi7YiX1GEkes0Ip2kAAIeTmpqqcDh82PVee+01TZs2TRdddJEGDhyo3NxcffbZZ63fQFffvn312muvfaNNvXv3ViAQqQQFg0EVFhbqjjvu0D//+U999tlnWrFihaRICBozZoxuuukmvfXWW0pNTdXixYtbtc1JfZom4I6n7xBGAACH0aNHD61bt06fffaZOnTo0GjVolevXlq0aJHGjx8vy7J0ww03tEqFozG//OUvNWLECN1yyy269NJLVVxcrL/+9a+67777JElLly7Vp59+qtNPP12dOnXS888/L8dx1KdPH61bt07Lly/Xueeeq+zsbK1bt047d+5U3759W7XNSV0ZCbivnsoIAOBwrrvuOgUCAfXr10/HH398o9eAzJkzR506ddLo0aM1fvx4jR07VkOHDvWsnUOHDtUTTzyhhQsXasCAAZo9e7ZuvvlmTZs2TZLUsWNHLVq0SN/+9rfVt29fzZs3T4899pj69++vzMxMrVq1Sueff7569+6t3//+97rrrrs0bty4Vm2zZVr7fp0W0NQpiBM1/dE39dw723XzhP6aUtCjxfYLAGjc/v37tWXLloNOOY+251A/z6a+fyd1ZcSOXjMSPurzGAAAx6ykDiPRC1i5mwYAcLS66qqr1KFDh4MuV111ld/NaxFJfgGrG0aO/jNVAIAkdfPNN+u666476HMteemCn5I7jDDOCADgKJedna3s7Gy/m9Gqkvo0TYARWAEA8F1yhxHmpgEAwHfJHUZiF7B6NxgNAABoKKnDSP3dND43BACAJJbUYYTKCAAA/iOMiMoIAMAbPXr00Ny5c5u0rmVZWrJkSau252hBGBGVEQAA/EQYEYOeAQDgp6QOIwwHDwBHAWOkmmp/lgQ+jD7wwAPq2rWrnK9V0ydMmKDLL79cn3zyiSZMmKCcnBx16NBBI0aM0D/+8Y8W66Z33nlH3/72t9WuXTsdd9xx+slPfqKqqqrY8ytXrtTIkSPVvn17dezYUWPGjNHWrVslSW+//bbOOussZWRkKDMzU8OGDdOGDRtarG3NldQjsDJRHgAcBWr3Srd19efYv/1CSm3fpFUvueQS/fznP9crr7yis88+W5K0a9cuvfjii3r++edVVVWl888/X7feeqtCoZAeeeQRjR8/Xps3b1a3bt2a1czq6mqNHTtWBQUFWr9+vXbs2KErrrhCM2bM0Pz581VXV6eJEyfqyiuv1GOPPaaamh
q98cYbstzxtCZPnqwhQ4bo/vvvVyAQUElJiVJSUprVppaU1GEkyGkaAEATderUSePGjdOCBQtiYeSpp55Sly5ddNZZZ8m2bQ0aNCi2/i233KLFixfr2Wef1YwZM5p17AULFmj//v165JFH1L59JDz99a9/1fjx4/XHP/5RKSkpqqio0IUXXqhvfetbkqS+ffvGti8tLdWvfvUrnXLKKZKkXr16Nas9LS2pw4jN3DQA4L+U9EiFwq9jJ2Dy5Mm68sordd999ykUCunRRx/VD37wA9m2raqqKt1444167rnntH37dtXV1Wnfvn0qLS1tdjPff/99DRo0KBZEJGnMmDFyHEebN2/W6aefrmnTpmns2LE655xzVFhYqO9///vKy8uTJM2cOVNXXHGF/va3v6mwsFCXXHJJLLQcDbhmRIQRAPCVZUVOlfixuB9Km2r8+PEyxui5557Ttm3btHr1ak2ePFmSdN1112nx4sW67bbbtHr1apWUlGjgwIGqqalpjV77hocffljFxcUaPXq0Hn/8cfXu3Vuvv/66JOnGG2/Uu+++qwsuuEArVqxQv379tHjxYk/a1RRJHUYChBEAQALS0tJ08cUX69FHH9Vjjz2mPn36aOjQoZKk1157TdOmTdNFF12kgQMHKjc3V5999lmLHLdv3756++23VV1dHXvstddek23b6tOnT+yxIUOGaNasWVq7dq0GDBigBQsWxJ7r3bu3rr32Wr388su6+OKL9fDDD7dI21pCkoeRyMtnojwAQFNNnjxZzz33nB566KFYVUSKXIexaNEilZSU6O2339YPf/jDb9x505xjpqWlaerUqdq0aZNeeeUV/fznP9ePfvQj5eTkaMuWLZo1a5aKi4u1detWvfzyy/roo4/Ut29f7du3TzNmzNDKlSu1detWvfbaa1q/fn2Da0r8ltTXjERP0ziEEQBAE337299W586dtXnzZv3whz+MPT5nzhxdfvnlGj16tLp06aLf/OY3qqysbJFjpqen66WXXtLVV1+tESNGKD09Xd/97nc1Z86c2PMffPCB/vd//1dfffWV8vLyNH36dP3Xf/2X6urq9NVXX2nKlCkqLy9Xly5ddPHFF+umm25qkba1BMuYo/9WksrKSmVlZamiokKZmZkttt+nNv5L1z35ts7ofbz+9/KRLbZfAEDj9u/fry1btqhnz55KS0vzuzlopkP9PJv6/p3Up2lilZGjP48BAHDMSuowwqBnAAA/PProo+rQocNBl/79+/vdPM9xzYi4mwYA4K3vfOc7GjVq1EGfO5pGRvVKUocRJsoDAPghIyNDGRkZfjfjqJHUp2kC7mA33NoLAN5rA/dPoAla4ueY3GEkwK29AOC1QCAgSZ6NTIrWtXfvXknNO72U3KdpqIwAgOeCwaDS09O1c+dOpaSkyLaT+nNxm2WM0d69e7Vjxw517NgxFjKPRFKHkfoLWFtmhDwAwOFZlqW8vDxt2bJFW7du9bs5aKaOHTsqNze3WftI6jDC3DQA4I/U1FT16tWLUzVtXEpKSrMqIlGEERFGAMAPtm0zAiskJfsFrNzaCwCA7wgjksKMwAoAgG8II+JuGgAA/JTUYSTo3k7GRHkAAPgnqcNIwH31VEYAAPBPkoeRyMvnbhoAAPyT3GHE4tZeAAD8ltxhJMAFrAAA+C2pw0h0OHgmygMAwD9JHUZsJsoDAMB3SR1GopURieoIAAB+SeowYseFEaojAAD4I6nDSHxlhDtqAADwR1KHkUB8GGEUVgAAfEEYcTFZHgAA/kjuMGJRGQEAwG9JHUZs21I0j9Q5jr+NAQAgSSUURoqKijRixAhlZGQoOztbEydO1ObNmw+73ZNPPqlTTjlFaWlpGjhwoJ5//vkjbnBLqx/4zOeGAACQpBIKI6+++qqmT5+u119/XcuWLVNtba3OPfdcVVdXN7rN2rVrNWnSJP34xz/WW2+9pYkTJ2rixInatGlTsxvfEuoHPiONAADgB8uYI7
9YYufOncrOztarr76q008//aDrXHrppaqurtbSpUtjj5122mkaPHiw5s2b16TjVFZWKisrSxUVFcrMzDzS5h5U/9kvqromrFd/daa6H9e+RfcNAEAya+r7d7OuGamoqJAkde7cudF1iouLVVhY2OCxsWPHqri4uNFtDhw4oMrKygZLa4neUcM4IwAA+OOIw4jjOLrmmms0ZswYDRgwoNH1ysrKlJOT0+CxnJwclZWVNbpNUVGRsrKyYkt+fv6RNvOwCCMAAPjriMPI9OnTtWnTJi1cuLAl2yNJmjVrlioqKmLLtm3bWvwYUQE70gXc2gsAgD+CR7LRjBkztHTpUq1atUonnnjiIdfNzc1VeXl5g8fKy8uVm5vb6DahUEihUOhImpawgBvH6hj0DAAAXyRUGTHGaMaMGVq8eLFWrFihnj17HnabgoICLV++vMFjy5YtU0FBQWItbSXBaGWE0zQAAPgiocrI9OnTtWDBAj3zzDPKyMiIXfeRlZWldu3aSZKmTJmiE044QUVFRZKkq6++WmeccYbuuusuXXDBBVq4cKE2bNigBx54oIVfypGJXTPCaRoAAHyRUGXk/vvvV0VFhc4880zl5eXFlscffzy2TmlpqbZv3x77fvTo0VqwYIEeeOABDRo0SE899ZSWLFlyyItevcQFrAAA+CuhykhThiRZuXLlNx675JJLdMkllyRyKM8QRgAA8FdSz00j1U+WRxgBAMAfhBE7Ohw8YQQAAD8kfRgJBqIT5RFGAADwQ9KHkfqJ8ggjAAD4IenDSJALWAEA8FXShxGbMAIAgK+SPowEYxewOj63BACA5JT0YSR6N43DCKwAAPiCMBKtjDBRHgAAvkj6MBKkMgIAgK+SPoxway8AAP5K+jASHfSMu2kAAPBH0oeRgB3pAsIIAAD+IIxECiOEEQAAfEIYoTICAICvCCNuD3ABKwAA/iCMUBkBAMBXSR9GmCgPAAB/JX0YCRBGAADwFWEkGkYYgRUAAF8QRqiMAADgK8IIE+UBAOArwojFRHkAAPiJMBKtjDiOzy0BACA5JX0Yqb+11+eGAACQpJI+jNixMEIaAQDAD0kfRoKx0zRcMwIAgB+SPoxErxlxCCMAAPiCMEJlBAAAXyV9GImepuHWXgAA/JH0YcRm0DMAAHyV9GGEyggAAP5K+jBiW1wzAgCAn5I+jAQDTJQHAICfkj6MBOxIFxBGAADwB2GE0zQAAPiKMMKgZwAA+IowwqBnAAD4KunDSP2svYQRAAD8kPRhJEAYAQDAV4QRwggAAL4ijETDCCOwAgDgC8IIlREAAHxFGIndTeP43BIAAJJT0oeR2ER5ZBEAAHyR9GGkfqI80ggAAH5I+jBSP1Gezw0BACBJJX0Yic5NE6YyAgCALwgjDAcPAICvCCNMlAcAgK8II1RGAADwVdKHkaAd6QKHEVgBAPBF0ocRN4tQGQEAwCdJH0ailRFjuG4EAAA/JH0Yid7aKzFZHgAAfiCMBOLCCJURAAA8l/RhJDo3jUQYAQDAD0kfRuy40zRcxAoAgPcSDiOrVq3S+PHj1bVrV1mWpSVLlhxy/ZUrV8qyrG8sZWVlR9rmFkVlBAAAfyUcRqqrqzVo0CDde++9CW23efNmbd++PbZkZ2cneuhWYRNGAADwVTDRDcaNG6dx48YlfKDs7Gx17Ngx4e28ELQt1TmGMAIAgA88u2Zk8ODBysvL0znnnKPXXnvtkOseOHBAlZWVDZbWFB0Snlt7AQDwXquHkby8PM2bN09PP/20nn76aeXn5+vMM8/Um2++2eg2RUVFysrKii35+fmt2sZYGAkTRgAA8FrCp2kS1adPH/Xp0yf2/ejRo/XJJ5/o7rvv1t/+9reDbjNr1izNnDkz9n1lZWWrBpL6yfKcVjsGAAA4uFYPIwczcuRIrVmzptHnQ6GQQqGQZ+2JhhEmywMAwHu+jDNSUlKivLw8Pw59UMFYZYQwAgCA1xKujFRVVenjjz+Ofb9lyx
aVlJSoc+fO6tatm2bNmqXPP/9cjzzyiCRp7ty56tmzp/r376/9+/frwQcf1IoVK/Tyyy+33Ktoptg1I4QRAAA8l3AY2bBhg84666zY99FrO6ZOnar58+dr+/btKi0tjT1fU1OjX/7yl/r888+Vnp6uU089Vf/4xz8a7MNv0cnyCCMAAHjPMubov1CisrJSWVlZqqioUGZmZovv///dsULbdu3T4p+N1pBunVp8/wAAJKOmvn8n/dw0EpURAAD8RBhR/K29hBEAALxGGJEUtCPd4BBGAADwHGFE9ZPlURkBAMB7hBHVjzPC3DQAAHiPMKL6yghz0wAA4D3CiBiBFQAAPxFGxNw0AAD4iTCi+nFGqIwAAOA9woikYMCtjBBGAADwHGFEkk1lBAAA3xBGFHdrr+P43BIAAJIPYURxt/aSRQAA8BxhRFRGAADwE2FE9bf2MmsvAADeI4yIWXsBAPATYURURgAA8BNhRPWDnjFRHgAA3iOMqH7QMybKAwDAe4QRxZ2moTICAIDnCCOKO03DNSMAAHiOMCIpYEe6gbtpAADwHmFEUsDtBSbKAwDAe4QRURkBAMBPhBHFDwdPGAEAwGuEEcVPlEcYAQDAa4QR1VdGOE0DAID3CCOqH2eEC1gBAPAeYURMlAcAgJ8II6o/TeMwAisAAJ4jjEiyLSojAAD4hTCiuInyHMfnlgAAkHwII6qvjHBrLwAA3iOMiEHPAADwE2FE9XfTEEYAAPAeYUTc2gsAgJ8II6IyAgCAnwgjIowAAOAnwoi4gBUAAD8RRhR3ay8jsAIA4DnCiOIHPSOMAADgNcKIpIAd6QbCCAAA3iOMSAowAisAAL4hjIhxRgAA8BNhRPVhxCGMAADgOcKIqIwAAOAnwogYZwQAAD8RRsQIrAAA+IkwIk7TAADgJ8KI4i5gZQRWAAA8RxhRXGUk7PjcEgAAkg9hRPUXsHKWBgAA7xFGVD9RXp1DZQQAAK8RRsREeQAA+IkwIuamAQDAT4QRxd9NIxnuqAEAwFOEEUlBu74bqI4AAOAtwoikuCzCwGcAAHiMMCIqIwAA+CnhMLJq1SqNHz9eXbt2lWVZWrJkyWG3WblypYYOHapQKKSTTz5Z8+fPP4Kmtp74ykiYa0YAAPBUwmGkurpagwYN0r333tuk9bds2aILLrhAZ511lkpKSnTNNdfoiiuu0EsvvZRwY1tLg8pImDACAICXgoluMG7cOI0bN67J68+bN089e/bUXXfdJUnq27ev1qxZo7vvvltjx45N9PCtwr2ZRhKVEQAAvNbq14wUFxersLCwwWNjx45VcXFxo9scOHBAlZWVDZbWZFlW7PZerhkBAMBbrR5GysrKlJOT0+CxnJwcVVZWat++fQfdpqioSFlZWbElPz+/tZtZP1keYQQAAE8dlXfTzJo1SxUVFbFl27ZtrX7M6CisDmEEAABPJXzNSKJyc3NVXl7e4LHy8nJlZmaqXbt2B90mFAopFAq1dtMaCFIZAQDAF61eGSkoKNDy5csbPLZs2TIVFBS09qETEmCyPAAAfJFwGKmqqlJJSYlKSkokRW7dLSkpUWlpqaTIKZYpU6bE1r/qqqv06aef6te//rU++OAD3XfffXriiSd07bXXtswraCFMlgcAgD8SDiMbNmzQkCFDNGTIEEnSzJkzNWTIEM2ePVuStH379lgwkaSePXvqueee07JlyzRo0CDdddddevDBB4+a23qj6i9gdXxuCQAAySXha0bOPPPMQ85se7DRVc8880y99dZbiR7KU7GZe8kiAAB46qi8m8YPVEYAAPAHYcQVq4wwAisAAJ4ijLhilRHmpgEAwFOEEVeQ4eABAPAFYcRlR2/t5TQNAACeIoy4ggFGYAUAwA+EERdz0wAA4A/CiItZewEA8AdhxBW0I11BZQQAAG8RRlxuFqEyAgCAxwgjrmhlhFt7AQDwFmHEZTPOCAAAviCMuBj0DAAAfxBGXNG7aRj0DAAAbxFGXNFxRr
iAFQAAbxFGXAF3BNZw2PG5JQAAJBfCiCsQm5vG54YAAJBkCCOu+gtYqYwAAOAlwogrdgErWQQAAE8RRlwBKiMAAPiCMOJiojwAAPxBGHFFwwgT5QEA4C3CiIvKCAAA/iCMuIKMwAoAgC8II67YRHkMNAIAgKcII64gp2kAAPAFYcQVHYHV4TQNAACeIoy4AnakK6iMAADgLcKIK+D2BLf2AgDgLcKIi8oIAAD+IIy46ifKI4wAAOAlwojLJowAAOALwoiLyggAAP4gjLiojAAA4A/CiItBzwAA8AdhxBWIVUYcn1sCAEByIYy4oiOwMjUNAADeIoy4ggEqIwAA+IEw4rItLmAFAMAPhBEXt/YCAOAPwogrwN00AAD4gjDiioYRJsoDAMBbhBEXlREAAPxBGHEFuGYEAABfEEZchBEAAPxBGHEF7UhXEEYAAPAWYcQVcHsibAgjAAB4iTDiCriVkTrGgwcAwFOEEVd0bhqHyggAAJ4ijLi4tRcAAH8QRlzRifIY9AwAAG8RRlzRifKojAAA4C3CiIuJ8gAA8AdhxMWgZwAA+IMw4iKMAADgD8KIK3aahlt7AQDwFGHEZcdVRgyBBAAAzxBGXNHKiMSpGgAAvHREYeTee+9Vjx49lJaWplGjRumNN95odN358+fLsqwGS1pa2hE3uLXY8WGEyggAAJ5JOIw8/vjjmjlzpv7whz/ozTff1KBBgzR27Fjt2LGj0W0yMzO1ffv22LJ169ZmNbo1UBkBAMAfCYeROXPm6Morr9Rll12mfv36ad68eUpPT9dDDz3U6DaWZSk3Nze25OTkNKvRrSE66JlEGAEAwEsJhZGamhpt3LhRhYWF9TuwbRUWFqq4uLjR7aqqqtS9e3fl5+drwoQJevfddw95nAMHDqiysrLB0tqojAAA4I+EwsiXX36pcDj8jcpGTk6OysrKDrpNnz599NBDD+mZZ57R3//+dzmOo9GjR+tf//pXo8cpKipSVlZWbMnPz0+kmUckEBdGGBIeAADvtPrdNAUFBZoyZYoGDx6sM844Q4sWLdLxxx+v//7v/250m1mzZqmioiK2bNu2rbWbKcuyFM0jTJYHAIB3goms3KVLFwUCAZWXlzd4vLy8XLm5uU3aR0pKioYMGaKPP/640XVCoZBCoVAiTWsRQdtWTdihMgIAgIcSqoykpqZq2LBhWr58eewxx3G0fPlyFRQUNGkf4XBY77zzjvLy8hJrqQdstze4ZgQAAO8kVBmRpJkzZ2rq1KkaPny4Ro4cqblz56q6ulqXXXaZJGnKlCk64YQTVFRUJEm6+eabddppp+nkk0/W7t27deedd2rr1q264oorWvaVtICgbUtyCCMAAHgo4TBy6aWXaufOnZo9e7bKyso0ePBgvfjii7GLWktLS2Xb9QWXf//737ryyitVVlamTp06adiwYVq7dq369evXcq+ihUQvYuU0DQAA3rFMG5iIpbKyUllZWaqoqFBmZmarHWfoLcu0q7pGL197unrnZLTacQAASAZNff9mbpo4scpI+KjPZwAAHDMII3EC7iisztFfLAIA4JhBGInDNSMAAHiPMBInGIiEkbDj+NwSAACSB2EkTvQ0TZgsAgCAZwgjcepP05BGAADwCmEkTjSMkEUAAPAOYSQOlREAALxHGIkTtKPXjHA3DQAAXiGMxLEJIwAAeI4wEofKCAAA3iOMxLGjt/YyAisAAJ4hjMSpH/SMMAIAgFcII3ECdqQ7mCgPAADvEEbiuIURTtMAAOAhwkicaGWE0zQAAHiHMBIn4PYGYQQAAO8QRuIEqYwAAOA5wkic+uHgCSMAAHiFMBKnfqI8wggAAF4hjMShMgIAgPcII3EC7gisDrf2AgDgGcJInIA70AiDngEA4B3CSJwAc9MAAOA5wkicQGzWXsfnlgAAkDwII3GCXMAKAIDnCCNxuLUXAADvEUbicGsvAADeI4zEoTICAID3CCNxqIwAAOA9wkicYOxuGsIIAA
BeIYzEsQkjAAB4jjASh8oIAADeI4zEsRmBFQAAzxFG4jDoGQAA3iOMxAkEIt0RZqI8AAA8QxiJw0R5AAB4L3nDiONImxZJT0yV9ldI4gJWAAD8kLxhxLKklbdL7y2RPnheErf2AgDgh+QOIwMujvz73UWSqIwAAOCH5A0jktTfDSOfrJD27oobDt7xsVEAACSX5A4jx/eWcgZKTp30/v/FTZTnc7sAAEgiyR1GJGnARZGvm56mMgIAgA8II9FTNZ+tVrsDX0mSGGYEAADvEEY695S6DpWMo9zPX5IkhamMAADgGcKIJA34riQppzRyi28dpREAADxDGJGk/hMlSVk71ytXX8lhBFYAADxDGJGkrBOlbgWSpAsC65goDwAADxFGotwLWccHilW5r077asI+NwgAgORAGInqN0HGsjXY/kRp1dv04/9dTyABAMADhJGojBxZPf5DkvTdlHVa+8lXumz+G9pbU+dzwwAAOLYRRuK5p2p+kbpEt4YeUfmWdzXtofWqPkAgAQCgtRBG4g38nnTiCAXq9mmy9aJeCf1SV31+vebcf7/e+9cuGe6yAQCgxVmmDbzDVlZWKisrSxUVFcrMzGzdgxkjfbpSWjdP+vDF+jaYdJWkDNL+Hmer1+iJ6nlSr9ZtBwAAbVxT378JI4fy1Sfa9co9Sn3vaXVwKhs8VWp11ZcZp8jknqrOJ4/QiX1PU0pGF+/aBgDAUY4w0pKcsKo/W6/Sdc8odcsK9TywWbb1zW7bbWXpy7Tu2pd5kuzje6t9Xm91OqGXMnN7ykrL8r7dAAD4iDDSmu35qlxb3lmjPVs2KmXnO8rbu1ndVH7obawM7QrmaF9aturSs6WMXKV2zFO7TrnqkNVFGR27KKV9RyktS0ppLwVSJMvy5gUBANAKCCMechyjbWU7VL5lk/Z8/r7Mzg/VrnKLOtV8rlynXJ2tqoT3GZatWiukWjtN4UBIjp0iY6fKBFKlQIqcYDs5KR3kpLSXSe0gpbaXFQy5S6rsYEh2MEV2ICA7kCI7EFQgGFQgGFIgNaRAME12SsgNPQHJsiXbjnyVItfOyP3VsGwp2E4KhqRgWuSrHZTsQGTb6FfLIkABAGKa+v4dPJKd33vvvbrzzjtVVlamQYMG6Z577tHIkSMbXf/JJ5/UDTfcoM8++0y9evXSH//4R51//vlHcuijkm1b6t41R9275kg6u8Fz+2vD+mznl9r1+ceqKvtUNRXbZfaUKbi3XGn7v1R67b/VzqlSprVXmdqrdOuAJCkgRwGzT2nhfVIbG3vNkSUjS0a2jOV+lSXHCsixAgorGPu3YwVkLFuSLeP+O/LVkrECkqy4x2wZRb7achQwdbJNWLbCsuTIKBKIoutZVqQVkhX5alkydoqMnRIJd4FUd7/1x42EMSsuWEW2tU1YlnFkKSzLGMkOyNgpkh2UsYOSFZSxbfdrZD+2HFlOnbttnSwpEv4CqTKB+iDY4NiWLUtOpMcsRbaxbDf8pcSFQFuyLFmxr5G+txTZyJIkOygrEJQVSJHlbmdblix3fUu2LMtIxqlfZEX2bwdlByIh07YDsmxbltsfkiLrOnX1i3v0aLvq+9GuX2w78hoC0dcRjPVxZHPL3W847ms4sl4gxd02NdK+qOgLN6Z+fScsycQFZftrS1z7jBNZ1zhuAD8Y0/A5y6r/eQRSGranOUx8O0zczyMY6btEOWEpXBNZHLcfg6FIH/Khoe1xwlLt3kjl/Eh+H45yCYeRxx9/XDNnztS8efM0atQozZ07V2PHjtXmzZuVnZ39jfXXrl2rSZMmqaioSBdeeKEWLFigiRMn6s0339SAAQNa5EUczdJSAurRNUc9uuZIGnPQdRzHqHJ/rb6oqtG/K6tUXbVHe/dVaf/eKh3YW6Wa/XsVrj2g2pr9cmprFK7dL7tun4J11Up19iq1bq9Snb2yTZ1sp1ZBp0ZBUyvbhCWFFXDfSIMKK9WqU6pqlao6pahOQYUjb+xyZFtGth
wZ980h+jWosFJVqzTVKk01B71eJp6taFXFiRVX3B0CxyRHlvt7f/j1ov+vJEUj+2G3C8uO/I+1AnJiIzLU7yvyfzgs2zjuv+sOud8apcixApE9GCf6cUF1VorqlKJaK0W1VqrCVsA90qHVHykuWMaei/w7EuYly+2Fhs+7r8X65mPGsiPx3LJjH3Di22NJskzk71jkGI6M7MiHHQUUjvWZkR0Nn9Fju/uWFekBW07sA060LyUjy5jYR6zI/iIfqMJW5Gu0Xfraz7e+jdFjmvrvTeTnFvvwIRPJoVakhyRLqc4+tQvvUXp4j9KcakmSI1v7A+21P9BB++wM1doht28iH+CcRn5asQ917gc/y/07bUmS+/oyJvxRed86tbEfc6tK+DTNqFGjNGLECP31r3+VJDmOo/z8fP385z/X9ddf/431L730UlVXV2vp0qWxx0477TQNHjxY8+bNa9Ixj/bTNG2BMUZ1jlFd2KjWcVQXNqoLO6p13K9hR7XhyPOOiS6R7cKOUdgYOY4UdhyZcK2cujqFnToZJywnXCfjOHKMI8dxZBwjxwlHvg87Mk6dHMeRE45UCGwTluXUyTK1spz6dcJOWMapcz/hRj7pWnI/ISssyzGyFHkubOzIqSz3z65jrNgfVsu4lRIj93VY7mtxZJs6BZxa2aY28lWOAu4fsugfIMlE/mS4n1AdWXJMQI6ksCJfLRPdLhxXoXHcP2Z1suXIka1aE/ljWGsCkjEKKKygqVVQdUoxtfVBMG5xZEfaLkvGRP5wRepJ4djX+D+Q8QFSqv9Ta8tRiuoUkKMUt3oUXSv6h7X+bSjySHR/AXcJKhz781pfaVKs/yMtsuP2Wb/vuPqYApbjtqf+daS4+45sW68u8pNwe8NS0G1HiuqUYh26TOgYS2H3bdiWUdByjvj/DJBsPrhwkU4ZfvbhV0xAq5ymqamp0caNGzVr1qzYY7Ztq7CwUMXFxQfdpri4WDNnzmzw2NixY7VkyZJGj3PgwAEdOHAg9n1lZWWj66JpLMtSSsBSSkBqpxYqK+OoZ9xQGQ2Y0Y8e0aD2jfXdbYy7jnG3cdzHovsIxD/mznJtWXJPA0X25ZjIc44xqnEi6+47ZFsjLTBGseM7xjQ8BWWMZMKRuOi2xRgT+WQbPQUT21f9jo17Csc4TiT+OGEZOTLG3c597ZZVH9dir8MNhUaWwiayrZxaKVwnhWtlnLpYeHUOEgqNe5bHanC6J1o9jJ42qo9xkahpx/rWNmEFTE3klJ9TFwnoxriDMEa+Gisgx7JlrGAkYNpBOdHTkVaKHDsg26mT7dTIdmoUCB+QZcJxlQE3+Dp1ssL7ZYdrZIVrIq/T7Q9J7qf/yM/Cck8lRl9T7PfDqX+Nsc+67s+o/vSiG12NcX+WkZ6Of12R7RzJcSKn39yf/dd/wMZIsiOnZp3YqU4TqZbELfWnNetfT+z3yXFkjCPZwdipYxNX8YhULNyKlBOWZWplO3WSCcty3FDtVmXcH7bbT/WnU+o/MES+Olb052w3+DBQX1sxqrPTdCAlUweCGToQzFSNnabUcLVSa6uUWrdHobpKBZyaSHXFPdUaqeZ8vToS/XAXlpzIBzbjno6O/m5L0uCu/o2flVAY+fLLLxUOh5WTk9Pg8ZycHH3wwQcH3aasrOyg65eVlTV6nKKiIt10002JNA3AQViWpYAlBQ5baAcA/xyVV8HMmjVLFRUVsWXbtm1+NwkAALSShCojXbp0USAQUHl5wzE1ysvLlZube9BtcnNzE1pfkkKhkEKhUCJNAwAAbVRClZHU1FQNGzZMy5cvjz3mOI6WL1+ugoKCg25TUFDQYH1JWrZsWaPrAwCA5JLwrb0zZ87U1KlTNXz4cI0cOVJz585VdXW1LrvsMknSlClTdMIJJ6ioqEiSdPXVV+uMM87QXXfdpQsuuEALFy7Uhg0b9MADD7TsKwEAAG1SwmHk0ksv1c6dOzV79myVlZVp8ODBevHFF2MXqZaWls
qOG5Bl9OjRWrBggX7/+9/rt7/9rXr16qUlS5YkxRgjAADg8BgOHgAAtIqmvn8flXfTAACA5EEYAQAAviKMAAAAXxFGAACArwgjAADAV4QRAADgK8IIAADwVcKDnvkhOhRKZWWlzy0BAABNFX3fPtyQZm0ijOzZs0eSlJ+f73NLAABAovbs2aOsrKxGn28TI7A6jqMvvvhCGRkZsiyrxfZbWVmp/Px8bdu2jZFdWxl97R362lv0t3foa++0VF8bY7Rnzx517dq1wVQxX9cmKiO2bevEE09stf1nZmbyi+0R+to79LW36G/v0NfeaYm+PlRFJIoLWAEAgK8IIwAAwFdJHUZCoZD+8Ic/KBQK+d2UYx597R362lv0t3foa+943ddt4gJWAABw7ErqyggAAPAfYQQAAPiKMAIAAHxFGAEAAL5K6jBy7733qkePHkpLS9OoUaP0xhtv+N2kNq+oqEgjRoxQRkaGsrOzNXHiRG3evLnBOvv379f06dN13HHHqUOHDvrud7+r8vJyn1p8bLj99ttlWZauueaa2GP0c8v6/PPP9Z//+Z867rjj1K5dOw0cOFAbNmyIPW+M0ezZs5WXl6d27dqpsLBQH330kY8tbpvC4bBuuOEG9ezZU+3atdO3vvUt3XLLLQ3mNqGvj8yqVas0fvx4de3aVZZlacmSJQ2eb0q/7tq1S5MnT1ZmZqY6duyoH//4x6qqqmp+40ySWrhwoUlNTTUPPfSQeffdd82VV15pOnbsaMrLy/1uWps2duxY8/DDD5tNmzaZkpISc/7555tu3bqZqqqq2DpXXXWVyc/PN8uXLzcbNmwwp512mhk9erSPrW7b3njjDdOjRw9z6qmnmquvvjr2OP3ccnbt2mW6d+9upk2bZtatW2c+/fRT89JLL5mPP/44ts7tt99usrKyzJIlS8zbb79tvvOd75iePXuaffv2+djytufWW281xx13nFm6dKnZsmWLefLJJ02HDh3Mn//859g69PWRef75583vfvc7s2jRIiPJLF68uMHzTenX8847zwwaNMi8/vrrZvXq1ebkk082kyZNanbbkjaMjBw50kyfPj32fTgcNl27djVFRUU+turYs2PHDiPJvPrqq8YYY3bv3m1SUlLMk08+GVvn/fffN5JMcXGxX81ss/bs2WN69eplli1bZs4444xYGKGfW9ZvfvMb8x//8R+NPu84jsnNzTV33nln7LHdu3ebUChkHnvsMS+aeMy44IILzOWXX97gsYsvvthMnjzZGENft5Svh5Gm9Ot7771nJJn169fH1nnhhReMZVnm888/b1Z7kvI0TU1NjTZu3KjCwsLYY7Ztq7CwUMXFxT627NhTUVEhSercubMkaePGjaqtrW3Q96eccoq6detG3x+B6dOn64ILLmjQnxL93NKeffZZDR8+XJdccomys7M1ZMgQ/c///E/s+S1btqisrKxBf2dlZWnUqFH0d4JGjx6t5cuX68MPP5Qkvf3221qzZo3GjRsnib5uLU3p1+LiYnXs2FHDhw+PrVNYWCjbtrVu3bpmHb9NTJTX0r788kuFw2Hl5OQ0eDwnJ0cffPCBT6069jiOo2uuuUZjxozRgAEDJEllZWVKTU1Vx44dG6ybk5OjsrIyH1rZdi1cuFBvvvmm1q9f/43n6OeW9emnn+r+++/XzJkz9dvf/lbr16/XL37xC6Wmpmrq1KmxPj3Y3xT6OzHXX3+9KisrdcoppygQCCgcDuvWW2/V5MmTJYm+biVN6deysjJlZ2c3eD4YDKpz587N7vukDCPwxvTp07Vp0yatWbPG76Ycc7Zt26arr75ay5YtU1pamt/NOeY5jqPhw4frtttukyQNGTJEmzZt0rx58zR16lSfW3dseeKJJ/Too49qwYIF6t+/v0pKSnTNNdeoa9eu9PUxLClP03Tp0kWBQOAbdxaUl5crNzfXp1YdW2bMmKGlS5fqlVde0Yknnhh7PDc3VzU1Ndq9e3eD9en7xG
zcuFE7duzQ0KFDFQwGFQwG9eqrr+ovf/mLgsGgcnJy6OcWlJeXp379+jV4rG/fviotLZWkWJ/yN6X5fvWrX+n666/XD37wAw0cOFA/+tGPdO2116qoqEgSfd1amtKvubm52rFjR4Pn6+rqtGvXrmb3fVKGkdTUVA0bNkzLly+PPeY4jpYvX66CggIfW9b2GWM0Y8YMLV68WCtWrFDPnj0bPD9s2DClpKQ06PvNmzertLSUvk/A2WefrXfeeUclJSWxZfjw4Zo8eXLs3/RzyxkzZsw3blH/8MMP1b17d0lSz549lZub26C/KysrtW7dOvo7QXv37pVtN3xrCgQCchxHEn3dWprSrwUFBdq9e7c2btwYW2fFihVyHEejRo1qXgOadflrG7Zw4UITCoXM/PnzzXvvvWd+8pOfmI4dO5qysjK/m9am/fSnPzVZWVlm5cqVZvv27bFl7969sXWuuuoq061bN7NixQqzYcMGU1BQYAoKCnxs9bEh/m4aY+jnlvTGG2+YYDBobr31VvPRRx+ZRx991KSnp5u///3vsXVuv/1207FjR/PMM8+Yf/7zn2bChAncbnoEpk6dak444YTYrb2LFi0yXbp0Mb/+9a9j69DXR2bPnj3mrbfeMm+99ZaRZObMmWPeeusts3XrVmNM0/r1vPPOM0OGDDHr1q0za9asMb169eLW3ua65557TLdu3UxqaqoZOXKkef311/1uUpsn6aDLww8/HFtn37595mc/+5np1KmTSU9PNxdddJHZvn27f40+Rnw9jNDPLev//u//zIABA0woFDKnnHKKeeCBBxo87ziOueGGG0xOTo4JhULm7LPPNps3b/aptW1XZWWlufrqq023bt1MWlqaOemkk8zvfvc7c+DAgdg69PWReeWVVw7693nq1KnGmKb161dffWUmTZpkOnToYDIzM81ll11m9uzZ0+y2WcbEDWsHAADgsaS8ZgQAABw9CCMAAMBXhBEAAOArwggAAPAVYQQAAPiKMAIAAHxFGAEAAL4ijAAAAF8RRgAAgK8IIwAAwFeEEQAA4CvCCAAA8NX/B/RM7in27jEDAAAAAElFTkSuQmCC",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Learning curves: training vs. validation loss.\n",
|
||
"# NOTE(review): the x axis is the list index -- assumed to be one entry per epoch; confirm against the training loop.\n",
|
||
"fig, ax = plt.subplots()\n",
|
||
"ax.plot(train_losses, label='train_loss')\n",
|
||
"ax.plot(val_losses, label='val_loss')\n",
|
||
"ax.set(xlabel='epoch', ylabel='loss', title='Training vs. validation loss')\n",
|
||
"ax.legend(loc='best');"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"id": "8be48f80-a6e6-4b05-87ef-3adbf0bef576",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"id": "cff8cba9-aba9-4347-8e1a-f169df8313c2",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Per-batch evaluation on the test set, scored only on the inpainted pixels.\n",
|
||
"eva_list = list()\n",
|
||
"device = 'cpu'\n",
|
||
"model = model.to(device)\n",
|
||
"# Switch off train-only behaviour (dropout, batch-norm statistics) if the model has any.\n",
|
||
"model.eval()\n",
|
||
"with torch.no_grad():\n",
|
||
"    for X, y, mask in test_loader:\n",
|
||
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
||
"        reconstructed = model(X)\n",
|
||
"        # Invert the mask: pixels where mask == 0 are the region the model has to fill in.\n",
|
||
"        fill_region = torch.squeeze(mask, dim=1) == 0\n",
|
||
"        if not fill_region.any():\n",
|
||
"            continue  # nothing to score; sklearn metrics raise on empty inputs\n",
|
||
"        # Rescale to the original units (assumes the dataset divided by max_pixel_value -- confirm).\n",
|
||
"        data_label = (torch.squeeze(y, dim=1) * max_pixel_value)[fill_region].cpu().numpy()\n",
|
||
"        recon_no2 = (torch.squeeze(reconstructed, dim=1) * max_pixel_value)[fill_region].cpu().numpy()\n",
|
||
"        mae = mean_absolute_error(data_label, recon_no2)\n",
|
||
"        rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
|
||
"        mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
|
||
"        r2 = r2_score(data_label, recon_no2)\n",
|
||
"        eva_list.append([mae, rmse, mape, r2])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"id": "edd09b0b-4496-4b88-a581-d1203aad05ce",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>mae</th>\n",
|
||
" <th>rmse</th>\n",
|
||
" <th>mape</th>\n",
|
||
" <th>r2</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>count</th>\n",
|
||
" <td>75.000000</td>\n",
|
||
" <td>75.000000</td>\n",
|
||
" <td>75.000000</td>\n",
|
||
" <td>75.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>mean</th>\n",
|
||
" <td>1.669181</td>\n",
|
||
" <td>2.722375</td>\n",
|
||
" <td>0.228690</td>\n",
|
||
" <td>0.825025</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>std</th>\n",
|
||
" <td>0.101549</td>\n",
|
||
" <td>0.229373</td>\n",
|
||
" <td>0.027824</td>\n",
|
||
" <td>0.023960</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>min</th>\n",
|
||
" <td>1.456919</td>\n",
|
||
" <td>2.206495</td>\n",
|
||
" <td>0.147438</td>\n",
|
||
" <td>0.751642</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25%</th>\n",
|
||
" <td>1.600787</td>\n",
|
||
" <td>2.569844</td>\n",
|
||
" <td>0.210564</td>\n",
|
||
" <td>0.815437</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>50%</th>\n",
|
||
" <td>1.663539</td>\n",
|
||
" <td>2.723380</td>\n",
|
||
" <td>0.228493</td>\n",
|
||
" <td>0.826285</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>75%</th>\n",
|
||
" <td>1.726697</td>\n",
|
||
" <td>2.848122</td>\n",
|
||
" <td>0.248380</td>\n",
|
||
" <td>0.837574</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>max</th>\n",
|
||
" <td>1.998206</td>\n",
|
||
" <td>3.443690</td>\n",
|
||
" <td>0.287797</td>\n",
|
||
" <td>0.881901</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" mae rmse mape r2\n",
|
||
"count 75.000000 75.000000 75.000000 75.000000\n",
|
||
"mean 1.669181 2.722375 0.228690 0.825025\n",
|
||
"std 0.101549 0.229373 0.027824 0.023960\n",
|
||
"min 1.456919 2.206495 0.147438 0.751642\n",
|
||
"25% 1.600787 2.569844 0.210564 0.815437\n",
|
||
"50% 1.663539 2.723380 0.228493 0.826285\n",
|
||
"75% 1.726697 2.848122 0.248380 0.837574\n",
|
||
"max 1.998206 3.443690 0.287797 0.881901"
|
||
]
|
||
},
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Distribution of the per-batch test metrics collected by the evaluation loop.\n",
|
||
"pd.DataFrame.from_records(eva_list, columns=['mae', 'rmse', 'mape', 'r2']).describe()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"id": "7e0a48b0-be9a-429b-a77f-3fe413c1aae7",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def cal_ioa(y_true, y_pred):\n",
|
||
"    \"\"\"Willmott's index of agreement d between observations and predictions.\n",
|
||
"\n",
|
||
"    d = 1 - sum((P_i - O_i)**2) / sum((|P_i - mean(O)| + |O_i - mean(O)|)**2)\n",
|
||
"\n",
|
||
"    Both deviations in the denominator are measured from the *observed* mean;\n",
|
||
"    by the triangle inequality this keeps d in [0, 1] (1 = perfect agreement).\n",
|
||
"\n",
|
||
"    Parameters\n",
|
||
"    ----------\n",
|
||
"    y_true : array-like\n",
|
||
"        Observed values O.\n",
|
||
"    y_pred : array-like\n",
|
||
"        Predicted values P, same shape as y_true.\n",
|
||
"\n",
|
||
"    Returns\n",
|
||
"    -------\n",
|
||
"    float\n",
|
||
"        The index of agreement, or NaN when every observation and prediction\n",
|
||
"        equals the observed mean (the ratio is 0 / 0).\n",
|
||
"    \"\"\"\n",
|
||
"    y_true = np.asarray(y_true, dtype=np.float64)\n",
|
||
"    y_pred = np.asarray(y_pred, dtype=np.float64)\n",
|
||
"    mean_observed = y_true.mean()\n",
|
||
"\n",
|
||
"    numerator = np.sum((y_pred - y_true) ** 2)\n",
|
||
"    # Bug fix: the prediction term must also be centred on the observed mean.\n",
|
||
"    # Centring it on the predicted mean (as before) can shrink the denominator\n",
|
||
"    # below the numerator and produce negative scores.\n",
|
||
"    denominator = np.sum((np.abs(y_pred - mean_observed) + np.abs(y_true - mean_observed)) ** 2)\n",
|
||
"    if denominator == 0:\n",
|
||
"        return np.nan\n",
|
||
"    return float(1.0 - numerator / denominator)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"id": "1263f067-2d88-4321-900d-29aa2a84df12",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Per-frame evaluation on the test set: one metric row per sample, scored only on the inpainted pixels.\n",
|
||
"eva_list_frame = list()\n",
|
||
"device = 'cpu'\n",
|
||
"model = model.to(device)\n",
|
||
"# Switch off train-only behaviour (dropout, batch-norm statistics) if the model has any.\n",
|
||
"model.eval()\n",
|
||
"with torch.no_grad():\n",
|
||
"    for X, y, mask in test_loader:\n",
|
||
"        X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
||
"        reconstructed = model(X)\n",
|
||
"        # Invert the mask: pixels where mask == 0 are the region the model has to fill in.\n",
|
||
"        fill_region = torch.squeeze(mask, dim=1) == 0\n",
|
||
"        # Rescale to the original units (assumes the dataset divided by max_pixel_value -- confirm).\n",
|
||
"        rev_data = y * max_pixel_value\n",
|
||
"        rev_recon = reconstructed * max_pixel_value\n",
|
||
"        for i in range(rev_data.shape[0]):\n",
|
||
"            used_mask = fill_region[i]\n",
|
||
"            if not used_mask.any():\n",
|
||
"                continue  # no reconstructed pixels in this frame; sklearn metrics raise on empty inputs\n",
|
||
"            # Channel 0 only, restricted to the fill region; convert to numpy once for every metric.\n",
|
||
"            data_label = rev_data[i][0][used_mask].cpu().numpy()\n",
|
||
"            recon_no2 = rev_recon[i][0][used_mask].cpu().numpy()\n",
|
||
"            mae = mean_absolute_error(data_label, recon_no2)\n",
|
||
"            rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
|
||
"            # MAPE divides by the true values, so frames with near-zero targets inflate it.\n",
|
||
"            mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
|
||
"            r2 = r2_score(data_label, recon_no2)\n",
|
||
"            ioa = cal_ioa(data_label, recon_no2)\n",
|
||
"            # NaN (with a RuntimeWarning) when either series is constant over the region.\n",
|
||
"            r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
|
||
"            eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"id": "27289a64-0405-48e3-bec3-ad0a612988a6",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>mae</th>\n",
|
||
" <th>rmse</th>\n",
|
||
" <th>mape</th>\n",
|
||
" <th>r2</th>\n",
|
||
" <th>ioa</th>\n",
|
||
" <th>r</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>count</th>\n",
|
||
" <td>4739.000000</td>\n",
|
||
" <td>4739.000000</td>\n",
|
||
" <td>4739.000000</td>\n",
|
||
" <td>4739.000000</td>\n",
|
||
" <td>4739.000000</td>\n",
|
||
" <td>4739.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>mean</th>\n",
|
||
" <td>1.657513</td>\n",
|
||
" <td>2.352886</td>\n",
|
||
" <td>0.232260</td>\n",
|
||
" <td>0.406715</td>\n",
|
||
" <td>0.823672</td>\n",
|
||
" <td>0.747644</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>std</th>\n",
|
||
" <td>0.897659</td>\n",
|
||
" <td>1.318793</td>\n",
|
||
" <td>0.234080</td>\n",
|
||
" <td>0.877368</td>\n",
|
||
" <td>0.184708</td>\n",
|
||
" <td>0.191853</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>min</th>\n",
|
||
" <td>0.546354</td>\n",
|
||
" <td>0.695038</td>\n",
|
||
" <td>0.066870</td>\n",
|
||
" <td>-30.315991</td>\n",
|
||
" <td>-1.254103</td>\n",
|
||
" <td>-0.392216</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25%</th>\n",
|
||
" <td>1.042898</td>\n",
|
||
" <td>1.472388</td>\n",
|
||
" <td>0.137216</td>\n",
|
||
" <td>0.313886</td>\n",
|
||
" <td>0.782728</td>\n",
|
||
" <td>0.671206</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>50%</th>\n",
|
||
" <td>1.465436</td>\n",
|
||
" <td>2.072718</td>\n",
|
||
" <td>0.174430</td>\n",
|
||
" <td>0.610684</td>\n",
|
||
" <td>0.879875</td>\n",
|
||
" <td>0.805015</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>75%</th>\n",
|
||
" <td>1.976618</td>\n",
|
||
" <td>2.785021</td>\n",
|
||
" <td>0.234618</td>\n",
|
||
" <td>0.757136</td>\n",
|
||
" <td>0.929170</td>\n",
|
||
" <td>0.879403</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>max</th>\n",
|
||
" <td>9.007959</td>\n",
|
||
" <td>12.398485</td>\n",
|
||
" <td>3.290891</td>\n",
|
||
" <td>0.973600</td>\n",
|
||
" <td>0.993247</td>\n",
|
||
" <td>0.987535</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" mae rmse mape r2 ioa \\\n",
|
||
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
|
||
"mean 1.657513 2.352886 0.232260 0.406715 0.823672 \n",
|
||
"std 0.897659 1.318793 0.234080 0.877368 0.184708 \n",
|
||
"min 0.546354 0.695038 0.066870 -30.315991 -1.254103 \n",
|
||
"25% 1.042898 1.472388 0.137216 0.313886 0.782728 \n",
|
||
"50% 1.465436 2.072718 0.174430 0.610684 0.879875 \n",
|
||
"75% 1.976618 2.785021 0.234618 0.757136 0.929170 \n",
|
||
"max 9.007959 12.398485 3.290891 0.973600 0.993247 \n",
|
||
"\n",
|
||
" r \n",
|
||
"count 4739.000000 \n",
|
||
"mean 0.747644 \n",
|
||
"std 0.191853 \n",
|
||
"min -0.392216 \n",
|
||
"25% 0.671206 \n",
|
||
"50% 0.805015 \n",
|
||
"75% 0.879403 \n",
|
||
"max 0.987535 "
|
||
]
|
||
},
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Distribution of the per-frame test metrics collected by the frame-level evaluation loop.\n",
|
||
"pd.DataFrame.from_records(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"id": "c72964bf-bbc5-4773-bd5f-6a0ea674934e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Save the per-batch metric summary; create the output folder if it does not exist yet.\n",
|
||
"os.makedirs('./eva_files', exist_ok=True)\n",
|
||
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2']).describe().to_csv('./eva_files/baseline_mask_loss.csv', encoding='utf-8-sig')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"id": "acee2abc-2f3f-4d19-a6e4-85ad4d1aaacf",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"ename": "NameError",
|
||
"evalue": "name 'data' is not defined",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m visualize_feature(\u001b[43mdata\u001b[49m[\u001b[38;5;241m5\u001b[39m], masked_data[\u001b[38;5;241m5\u001b[39m], reconstructed[\u001b[38;5;241m5\u001b[39m], \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNO2\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
||
"\u001b[0;31mNameError\u001b[0m: name 'data' is not defined"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# NOTE(review): `data` and `masked_data` are not defined anywhere in this notebook, so the\n",
|
||
"# original call raised NameError. This uses the last test batch left by the evaluation loop:\n",
|
||
"# y is the target the metrics compare against and X is the model input. Confirm that this\n",
|
||
"# matches the argument order visualize_feature expects.\n",
|
||
"sample_idx = min(5, y.shape[0] - 1)  # the final batch may hold fewer than 6 samples\n",
|
||
"visualize_feature(y[sample_idx], X[sample_idx], reconstructed[sample_idx], 'NO2')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.8.16"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|