MAE_ATMO/torch_MAE_1d_baseline.ipynb

896 lines
53 KiB
Plaintext
Raw Permalink Normal View History

2024-11-21 14:02:33 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 30,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"from PIL import Image\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e0afbbc4-cd35-49f7-986f-2c0a6fff5ec1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f6d6be638f0>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(0)\n",
"torch.random.manual_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "95baeec7-508b-480c-b598-aecab7497a99",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum pixel value in the dataset: 107.49169921875\n"
]
}
],
"source": [
"# 定义函数来找到最大值\n",
"def find_max_pixel_value(image_dir):\n",
" max_pixel_value = 0.0\n",
" for filename in os.listdir(image_dir):\n",
" if filename.endswith('.npy'):\n",
" image_path = os.path.join(image_dir, filename)\n",
" image = np.load(image_path).astype(np.float32)\n",
" max_pixel_value = max(max_pixel_value, image[:, :, 0].max())\n",
" return max_pixel_value\n",
"\n",
"# 计算图像数据中的最大像素值\n",
"image_dir = './out_mat/96/train/' \n",
"max_pixel_value = find_max_pixel_value(image_dir)\n",
"\n",
"print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9a8fe22d-5029-427f-bae8-01934a0d5c35",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint before Generator is OK\n"
]
}
],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = np.random.choice(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# 实例化数据集和数据加载器\n",
"image_dir = './out_mat/96/train/'\n",
"mask_dir = './out_mat/96/mask/20/'\n",
"\n",
"print(f\"checkpoint before Generator is OK\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ddbc13ba-a0e8-477e-895e-371a78085bac",
"metadata": {},
"outputs": [],
"source": [
"dataset = NO2Dataset(image_dir, mask_dir)\n",
"dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
"val_set = NO2Dataset('./out_mat/96/valid/', mask_dir)\n",
"val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)\n",
"test_set = NO2Dataset('./out_mat/96/test/', mask_dir)\n",
"test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Input\")\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Masked\")\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
" plt.title(title + \" Recovery\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "aeda3567-4c4d-496b-9570-9ae757b45e72",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# 设置随机种子以确保结果的可重复性\n",
"torch.manual_seed(0)\n",
"np.random.seed(0)\n",
"\n",
"# 数据准备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f821d0a1-dfee-483e-b081-68c963bdb8a0",
"metadata": {},
"outputs": [],
"source": [
"# 定义Masked Autoencoder模型\n",
"class MaskedAutoencoderBase(nn.Module):\n",
" def __init__(self):\n",
" super(MaskedAutoencoderBase, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" )\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为输入数据是0-1之间的\n",
" )\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2dc47416-511e-4874-abaf-30dd912a0e7d",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2e77e837-071c-46d0-9779-80bb333db800",
"metadata": {},
"outputs": [],
"source": [
"# 实例化模型、损失函数和优化器\n",
"model = MaskedAutoencoderBase()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
"metadata": {},
"outputs": [],
"source": [
"# 训练函数\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
" model.train()\n",
" running_loss = 0.0\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" optimizer.zero_grad()\n",
" reconstructed = model(X)\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" # loss = criterion(reconstructed, y)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
"metadata": {},
"outputs": [],
"source": [
"# 评估函数\n",
"def evaluate(model, device, data_loader, criterion):\n",
" model.eval()\n",
" running_loss = 0.0\n",
" with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(data_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" reconstructed = model(X)\n",
" if batch_idx == 8:\n",
" rand_ind = np.random.randint(0, len(y))\n",
" # visualize_feature(y[rand_ind], X[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
" loss = masked_mse_loss(reconstructed, y, mask)\n",
" # loss = criterion(reconstructed, y)\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "743d1000-561e-4444-8b49-88346c14f28b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Train Loss: 2.4377448216936233, Val Loss: 0.1723788405087457\n",
"Epoch 2, Train Loss: 0.09637997932197374, Val Loss: 0.07621741728551353\n",
"Epoch 3, Train Loss: 0.06397618102934657, Val Loss: 0.06195200451496822\n",
"Epoch 4, Train Loss: 0.052692288621974906, Val Loss: 0.052603201690449644\n",
"Epoch 5, Train Loss: 0.045533701719529036, Val Loss: 0.0462518873721806\n",
"Epoch 6, Train Loss: 0.040426999678095564, Val Loss: 0.04118765834996949\n",
"Epoch 7, Train Loss: 0.03643315702979787, Val Loss: 0.0370612932619319\n",
"Epoch 8, Train Loss: 0.03297993362074691, Val Loss: 0.0338741072189452\n",
"Epoch 9, Train Loss: 0.030229569176595177, Val Loss: 0.03180063916231269\n",
"Epoch 10, Train Loss: 0.028299911767600827, Val Loss: 0.03058352780097456\n",
"Epoch 11, Train Loss: 0.026935207724357337, Val Loss: 0.029766072282817826\n",
"Epoch 12, Train Loss: 0.026076676769618782, Val Loss: 0.028107319638800265\n",
"Epoch 13, Train Loss: 0.02534967841821139, Val Loss: 0.027272115475428637\n",
"Epoch 14, Train Loss: 0.024701394381349166, Val Loss: 0.02684043228292643\n",
"Epoch 15, Train Loss: 0.0240272392550011, Val Loss: 0.02594853615138068\n",
"Epoch 16, Train Loss: 0.0233813104438083, Val Loss: 0.025640942656726978\n",
"Epoch 17, Train Loss: 0.02310016915273438, Val Loss: 0.02571806650775582\n",
"Epoch 18, Train Loss: 0.022718923658792054, Val Loss: 0.024644668200122778\n",
"Epoch 19, Train Loss: 0.022323213453052576, Val Loss: 0.024273945435659208\n",
"Epoch 20, Train Loss: 0.02199719715685223, Val Loss: 0.02410240029332353\n",
"Epoch 21, Train Loss: 0.021530815467024535, Val Loss: 0.02380427871066243\n",
"Epoch 22, Train Loss: 0.021460241776262743, Val Loss: 0.0232450627346537\n",
"Epoch 23, Train Loss: 0.02090326771050977, Val Loss: 0.022885078564286232\n",
"Epoch 24, Train Loss: 0.020652044475363774, Val Loss: 0.022562191390724323\n",
"Epoch 25, Train Loss: 0.02051923798985387, Val Loss: 0.022203324724044373\n",
"Epoch 26, Train Loss: 0.020149177833767743, Val Loss: 0.022744494337421744\n",
"Epoch 27, Train Loss: 0.020068248640300268, Val Loss: 0.022425833088693333\n",
"Epoch 28, Train Loss: 0.019720358143529397, Val Loss: 0.02253118777341807\n",
"Epoch 29, Train Loss: 0.01939903690288084, Val Loss: 0.021765351378873213\n",
"Epoch 30, Train Loss: 0.01943497322989922, Val Loss: 0.021345259649540062\n",
"Epoch 31, Train Loss: 0.019241397384928458, Val Loss: 0.02124041018646155\n",
"Epoch 32, Train Loss: 0.01906546402464144, Val Loss: 0.021633521083797982\n",
"Epoch 33, Train Loss: 0.01884070100512302, Val Loss: 0.021043253979131357\n",
"Epoch 34, Train Loss: 0.01874133140855785, Val Loss: 0.02059999839472237\n",
"Epoch 35, Train Loss: 0.01853996916544851, Val Loss: 0.021178998303279947\n",
"Epoch 36, Train Loss: 0.018260161060412106, Val Loss: 0.020367807639178944\n",
"Epoch 37, Train Loss: 0.01830708983233956, Val Loss: 0.020017842692670536\n",
"Epoch 38, Train Loss: 0.018042967790675362, Val Loss: 0.020187884722071798\n",
"Epoch 39, Train Loss: 0.017922898732197056, Val Loss: 0.019615614786744118\n",
"Epoch 40, Train Loss: 0.017794321282236486, Val Loss: 0.019430582606191956\n",
"Epoch 41, Train Loss: 0.017688655022656517, Val Loss: 0.019477688401603875\n",
"Epoch 42, Train Loss: 0.017460078103512383, Val Loss: 0.018902005530448993\n",
"Epoch 43, Train Loss: 0.01727662416638441, Val Loss: 0.018832763184362382\n",
"Epoch 44, Train Loss: 0.017280888195599666, Val Loss: 0.019056980081124983\n",
"Epoch 45, Train Loss: 0.017114856775012312, Val Loss: 0.018604515495696174\n",
"Epoch 46, Train Loss: 0.016909640970858234, Val Loss: 0.018437264904157438\n",
"Epoch 47, Train Loss: 0.016691252999185946, Val Loss: 0.01889144025965413\n",
"Epoch 48, Train Loss: 0.016869753608079047, Val Loss: 0.018732781104965887\n",
"Epoch 49, Train Loss: 0.01653263871179243, Val Loss: 0.01850963812043418\n",
"Epoch 50, Train Loss: 0.01653244017520875, Val Loss: 0.0178856217344083\n",
"Epoch 51, Train Loss: 0.016499577624874823, Val Loss: 0.01781756227919415\n",
"Epoch 52, Train Loss: 0.016335643743249504, Val Loss: 0.01821571894323648\n",
"Epoch 53, Train Loss: 0.016375035212406415, Val Loss: 0.017511379168327176\n",
"Epoch 54, Train Loss: 0.016288986672428948, Val Loss: 0.017456448650849398\n",
"Epoch 55, Train Loss: 0.01623404509517137, Val Loss: 0.017827068525018978\n",
"Epoch 56, Train Loss: 0.016188283936615196, Val Loss: 0.017475027326883663\n",
"Epoch 57, Train Loss: 0.01605349867359588, Val Loss: 0.017256822728955033\n",
"Epoch 58, Train Loss: 0.015958637990610022, Val Loss: 0.017457437256712522\n",
"Epoch 59, Train Loss: 0.016034694237001774, Val Loss: 0.017437012713235705\n",
"Epoch 60, Train Loss: 0.0158486066956483, Val Loss: 0.017560158175096582\n",
"Epoch 61, Train Loss: 0.015632042563275286, Val Loss: 0.01692103194211846\n",
"Epoch 62, Train Loss: 0.015540152108608677, Val Loss: 0.01698271286632143\n",
"Epoch 63, Train Loss: 0.01545496231043025, Val Loss: 0.01699626362368242\n",
"Epoch 64, Train Loss: 0.015430795162488398, Val Loss: 0.01687317063559347\n",
"Epoch 65, Train Loss: 0.015489797350732191, Val Loss: 0.017046043955123248\n",
"Epoch 66, Train Loss: 0.015236956011682179, Val Loss: 0.0172197060214717\n",
"Epoch 67, Train Loss: 0.015348140916755895, Val Loss: 0.016508253249548265\n",
"Epoch 68, Train Loss: 0.015228347097519055, Val Loss: 0.016413471842212462\n",
"Epoch 69, Train Loss: 0.01516882229025997, Val Loss: 0.01686259738600521\n",
"Epoch 70, Train Loss: 0.015173258315593574, Val Loss: 0.01757873013726811\n",
"Epoch 71, Train Loss: 0.015156847716678986, Val Loss: 0.016662339123883353\n",
"Epoch 72, Train Loss: 0.015105586064507088, Val Loss: 0.016890839868183457\n",
"Epoch 73, Train Loss: 0.014925161955887051, Val Loss: 0.015931842709655194\n",
"Epoch 74, Train Loss: 0.014886363126497947, Val Loss: 0.016006485308840204\n",
"Epoch 75, Train Loss: 0.015015289608531735, Val Loss: 0.015968994154080526\n",
"Epoch 76, Train Loss: 0.014806462892968403, Val Loss: 0.015919692327838336\n",
"Epoch 77, Train Loss: 0.014728168116962653, Val Loss: 0.015852669684855797\n",
"Epoch 78, Train Loss: 0.014845167781319914, Val Loss: 0.016079049404543726\n",
"Epoch 79, Train Loss: 0.014719554133998435, Val Loss: 0.015957326447563387\n",
"Epoch 80, Train Loss: 0.014635249268281404, Val Loss: 0.015849308388780303\n",
"Epoch 81, Train Loss: 0.014474964379800849, Val Loss: 0.015526832887597049\n",
"Epoch 82, Train Loss: 0.014369143295641007, Val Loss: 0.015485089967277512\n",
"Epoch 83, Train Loss: 0.014446225396076743, Val Loss: 0.015848276135859204\n",
"Epoch 84, Train Loss: 0.014476079110537419, Val Loss: 0.015343323600158762\n",
"Epoch 85, Train Loss: 0.014672522836378174, Val Loss: 0.015515949938501885\n",
"Epoch 86, Train Loss: 0.014440825409545568, Val Loss: 0.015224411166203556\n",
"Epoch 87, Train Loss: 0.014462759978980111, Val Loss: 0.015663697370397512\n",
"Epoch 88, Train Loss: 0.01440465696262971, Val Loss: 0.015856551353944773\n",
"Epoch 89, Train Loss: 0.014255739579146559, Val Loss: 0.015246662380757616\n",
"Epoch 90, Train Loss: 0.014205876624202756, Val Loss: 0.015011716536732752\n",
"Epoch 91, Train Loss: 0.014259663818216924, Val Loss: 0.015085076996639593\n",
"Epoch 92, Train Loss: 0.014251617286978156, Val Loss: 0.015133185506756627\n",
"Epoch 93, Train Loss: 0.014119144052302723, Val Loss: 0.015415464166496227\n",
"Epoch 94, Train Loss: 0.014192042264053554, Val Loss: 0.015254960033986995\n",
"Epoch 95, Train Loss: 0.014140318196855094, Val Loss: 0.017451592276234235\n",
"Epoch 96, Train Loss: 0.014092271857890502, Val Loss: 0.015359595265072672\n",
"Epoch 97, Train Loss: 0.01409529693843574, Val Loss: 0.015055305060388437\n",
"Epoch 98, Train Loss: 0.014136464546688578, Val Loss: 0.015083992547953307\n",
"Epoch 99, Train Loss: 0.013914715411792103, Val Loss: 0.014718598477653604\n",
"Epoch 100, Train Loss: 0.013870610982518305, Val Loss: 0.01483591334588492\n",
"Test Loss: 0.010182651478874807\n"
]
}
],
"source": [
"model = model.to(device)\n",
"\n",
"num_epochs = 100\n",
"train_losses = list()\n",
"val_losses = list()\n",
"for epoch in range(num_epochs):\n",
" train_loss = train_epoch(model, device, dataloader, criterion, optimizer)\n",
" train_losses.append(train_loss)\n",
" val_loss = evaluate(model, device, val_loader, criterion)\n",
" val_losses.append(val_loss)\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# 测试模型\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f6b7999d190>"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA3U0lEQVR4nO3deXxU9b3/8fc5M8mEQBJAzIKGxQrIJjsYuD+XGkVUCtpaS7kFtNprC61KbSttpS4PjdWKtFblen0o11bEDdCLKwURkIiAxooLbkiokoBSEhKWJHO+vz/mzGSiBDIkOYcwr+fjcR4hM2f5zjch857POef7tYwxRgAAAD6x/W4AAABIboQRAADgK8IIAADwFWEEAAD4ijACAAB8RRgBAAC+IowAAABfEUYAAICvgn43oCkcx9EXX3yhjIwMWZbld3MAAEATGGO0Z88ede3aVbbdeP2jTYSRL774Qvn5+X43AwAAHIFt27bpxBNPbPT5NhFGMjIyJEVeTGZmps+tAQAATVFZWan8/PzY+3hj2kQYiZ6ayczMJIwAANDGHO4SCy5gBQAAvkoojBQVFWnEiBHKyMhQdna2Jk6cqM2bNx9ym/nz58uyrAZLWlpasxoNAACOHQmFkVdffVXTp0/X66+/rmXLlqm2tlbnnnuuqqurD7ldZmamtm/fHlu2bt3arEYDAIBjR0LXjLz44osNvp8/f76ys7O1ceNGnX766Y1uZ1mWcnNzj6yFAIBjUjgcVm1trd/NQDMEAgEFg8FmD7vRrAtYKyoqJEmdO3c+5HpVVVXq3r27HMfR0KFDddttt6l///6Nrn/gwAEdOHAg9n1lZWVzmgkAOMpUVVXpX//6l4wxfjcFzZSenq68vDylpqYe8T4sc4S/CY7j6Dvf+Y52796tNWvWNLpecXGxPvroI5166qmqqKjQn/70J61atUrvvvtuo/cc33jjjbrpppu+8XhFRQV30wBAGxcOh/XRRx8pPT1dxx9/PINZtlHGGNXU1Gjnzp0Kh8Pq1avXNwY2q6ysVFZW1mHfv484jPz0pz/VCy+8oDVr1hxyIJOvq62tVd++fTVp0iTdcsstB13nYJWR/Px8wggAHAP279+vLVu2qEePHmrXrp3fzUEz7d27V1u3blXPnj2/cYNKU8PIEZ2mmTFjhpYuXapVq1YlFEQkKSUlRUOGDNHHH3/c6DqhUEihUOhImgYAaCOoiBwbDjXMe5P3kcjKxhjNmDFDixcv1ooVK9SzZ8+EDxgOh/XOO+8oLy8v4W0BAMCxJ6HKyPTp07VgwQI988wzysjIUFlZmSQpKysrVmqbMmWKTjjhBBUVFUmSbr75Zp122mk6+eSTtXv3bt15553aunWrrrjiihZ+KQAAoC1KqDJy//33q6KiQmeeeaby8vJiy+OPPx5bp7S0VNu3b499/+9//1tXXnml+vbtq/PPP1+VlZVau3at+vXr13KvAgCANqRHjx6aO3dui+xr5cqVsixLu3fvbpH9+SGhykhTrnVduXJlg+/vvvtu3X333Qk1CgCAo82ZZ56pwYMHt0iIWL9+vdq3b9/8Rh0j2sREea3lwdWf6l//3qdJI7upT+6hZxQEAOBQjDEKh8MKBg//1nr88cd70KK2I6knynvune2av/Yzbf3q0MPZAwBajzFGe2vqfFmaOrrFtGnT9Oqrr+rPf/5zbJ616NxrL7zwgoYNG6ZQKKQ1a9bok08+0YQJE5STk6MOHTpoxIgR+sc//tFgf18/TWNZlh588EFddNFFSk9PV69evfTss88ecZ8+/fTT6t+/v0KhkHr06KG77rqrwfP33XefevXqpbS0NOXk5Oh73/te7LmnnnpKAwcOVLt27XTcccepsLDwsNO+NFdSV0aCduS2MocRAAHAN/tqw+o3+yVfjv3ezWOVnnr4t8I///nP+vDDDzVgwADdfPPNkqR3331XknT99dfrT3/6k0466SR16tRJ27Zt0/nnn69bb71VoVBIjzzyiMaPH6/NmzerW7dujR7jpptu0h133KE777xT99xzjyZPnqytW7cedpTzr9u4caO+//3v68Ybb9Sll16qtWvX6mc/+5mOO+44TZs2TRs2bNAvfvEL/e1vf9Po0aO1a9curV69WpK0fft2TZo0SXfccYcuuugi7dmzR6tXr271kXKTOozY7j3udQ5hBADQuKysLKWmpio9PT0219oHH3wgKXLX6DnnnBNbt3Pnzho0aFDs+1tuuUWLFy/Ws88+qxkzZjR6jGnTpmnSpEmSpNtuu01/+ctf9MYbb+i8885LqK1z5szR2WefrRtuuEGS1Lt3b7333nu68847NW3aNJWWlqp9+/a68MILlZGRoe7du2vIkCGSImGkrq5OF198sbp37y5JGjhwYELHPxJJHUaCgUgYCRNGAMA37VICeu/msb4du7mGDx/e4PuqqirdeOONeu6552Jv7vv27VNpaekh93PqqafG/t2+fXtlZmZqx44dCbfn/fff14QJExo8NmbMGM2dO1fhcFjnnHOOunfvrpNOOknnnXeezjvvvNjpoUGDBunss8/WwIEDNXbsWJ177rn63ve+p06dOiXcjkQk9TUj0coIYQQA/GNZltJTg74sLTEK7Nfvirnuuuu0ePFi3XbbbVq9erVKSko0cOBA1dTUHHI/KSkp3+gXx3Ga3b6vy8jI0JtvvqnHHntMeXl5mj17tgYNGqTdu3crEAho2bJleuGFF9SvXz/dc8896tOnj7Zs2dLi7YiX1GEkes0Ip2kAAIeTmpqqcDh82PVee+01TZs2TRdddJEGDhyo3NxcffbZZ63fQFffvn312muvfaNNvXv3ViAQqQQFg0EVFhbqjjvu0D//+U999tlnWrFihaRICBozZoxuuukmvfXWW0pNTdXixYtbtc1JfZom4I6n7xBGAACH0aNHD61bt06fffaZOnTo0GjVolevXlq0aJHGjx8vy7J0ww03tEqFozG//OUvNWLECN1yyy269NJLVVxcrL/+9a+67777JElLly7Vp59+qtNPP12dOnXS888/L8dx1KdPH61bt07Lly/Xueeeq+zsbK1bt047d+5U3759W7XNSV0ZCbivnsoIAOBwrrvuOgUCAfXr10/HH398o9eAzJkzR506ddLo0aM1fvx4jR07VkOHDvWsnUOHDtUTTzyhhQsXasCAAZo9e7ZuvvlmTZs2TZLUsWNHLVq0SN/+9rfVt29fzZs3T4899pj69++vzMxMrVq1Sueff7569+6t3//+97rrrrs0bty4Vm2zZVr7fp0W0NQpiBM1/dE39dw723XzhP6aUtCjxfYLAGjc/v37tWXLloNOOY+251A/z6a+fyd1ZcSOXjMSPurzGAAAx6ykDiPRC1i5mwYAcLS66qqr1KFDh4MuV111ld/NaxFJfgGrG0aO/jNVAIAkdfPNN+u666476HMteemCn5I7jDDOCADgKJedna3s7Gy/m9Gqkvo0TYARWAEA8F1yhxHmpgEAwHfJHUZiF7B6NxgNAABoKKnDSP3dND43BACAJJbUYYTKCAAA/iOMiMoIAMAbPXr00Ny5c5u0rmVZWrJkSau252hBGBGVEQAA/EQYEYOeAQDgp6QOIwwHDwBHAWOkmmp/lgQ+jD7wwAPq2rWrnK9V0ydMmKDLL79cn3zyiSZMmKCcnBx16NBBI0aM0D/+8Y8W66Z33nlH3/72t9WuXTsdd9xx+slPfqKqqqrY8ytXrtTIkSPVvn17dezYUWPGjNHWrVslSW+//bbOOussZWRkKDMzU8OGDdOGDRtarG3NldQjsDJRHgAcBWr3Srd19efYv/1CSm3fpFUvueQS/fznP9crr7yis88+W5K0a9cuvfjii3r++edVVVWl888/X7feeqtCoZAeeeQRjR8/Xps3b1a3bt2a1czq6mqNHTtWBQUFWr9+vXbs2KErrrhCM2bM0Pz581VXV6eJEyfqyiuv1GOPPaaamhq98cYbstzxtCZPnqwhQ4bo/vvvVyAQUElJiVJSUprVppaU1GEkyGkaAEATderUSePGjdOCBQtiYeSpp55Sly5ddNZ
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tr_ind = list(range(len(train_losses)))\n",
"val_ind = list(range(len(val_losses)))\n",
"plt.plot(train_losses, label='train_loss')\n",
"plt.plot(val_losses, label='val_loss')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "8be48f80-a6e6-4b05-87ef-3adbf0bef576",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "cff8cba9-aba9-4347-8e1a-f169df8313c2",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" # tr_maxs = np.transpose(maxs, (2, 0, 1))\n",
" # tr_mins = np.transpose(mins, (2, 0, 1))\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" eva_list.append([mae, rmse, mape, r2])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "edd09b0b-4496-4b88-a581-d1203aad05ce",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.669181</td>\n",
" <td>2.722375</td>\n",
" <td>0.228690</td>\n",
" <td>0.825025</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.101549</td>\n",
" <td>0.229373</td>\n",
" <td>0.027824</td>\n",
" <td>0.023960</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>1.456919</td>\n",
" <td>2.206495</td>\n",
" <td>0.147438</td>\n",
" <td>0.751642</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>1.600787</td>\n",
" <td>2.569844</td>\n",
" <td>0.210564</td>\n",
" <td>0.815437</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.663539</td>\n",
" <td>2.723380</td>\n",
" <td>0.228493</td>\n",
" <td>0.826285</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.726697</td>\n",
" <td>2.848122</td>\n",
" <td>0.248380</td>\n",
" <td>0.837574</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>1.998206</td>\n",
" <td>3.443690</td>\n",
" <td>0.287797</td>\n",
" <td>0.881901</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2\n",
"count 75.000000 75.000000 75.000000 75.000000\n",
"mean 1.669181 2.722375 0.228690 0.825025\n",
"std 0.101549 0.229373 0.027824 0.023960\n",
"min 1.456919 2.206495 0.147438 0.751642\n",
"25% 1.600787 2.569844 0.210564 0.815437\n",
"50% 1.663539 2.723380 0.228493 0.826285\n",
"75% 1.726697 2.848122 0.248380 0.837574\n",
"max 1.998206 3.443690 0.287797 0.881901"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "7e0a48b0-be9a-429b-a77f-3fe413c1aae7",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "1263f067-2d88-4321-900d-29aa2a84df12",
"metadata": {},
"outputs": [],
"source": [
"eva_list_frame = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * max_pixel_value\n",
" rev_recon = reconstructed * max_pixel_value\n",
" # todo: 这里需要只评估修补出来的模块\n",
" for i, sample in enumerate(rev_data):\n",
" used_mask = mask_rev[i]\n",
" data_label = sample[0] * used_mask\n",
" recon_no2 = rev_recon[i][0] * used_mask\n",
" data_label = data_label[used_mask==1]\n",
" recon_no2 = recon_no2[used_mask==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" r = np.corrcoef(data_label, recon_no2)[0, 1]\n",
" eva_list_frame.append([mae, rmse, mape, r2, ioa, r])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "27289a64-0405-48e3-bec3-ad0a612988a6",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" <td>4739.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.657513</td>\n",
" <td>2.352886</td>\n",
" <td>0.232260</td>\n",
" <td>0.406715</td>\n",
" <td>0.823672</td>\n",
" <td>0.747644</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.897659</td>\n",
" <td>1.318793</td>\n",
" <td>0.234080</td>\n",
" <td>0.877368</td>\n",
" <td>0.184708</td>\n",
" <td>0.191853</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.546354</td>\n",
" <td>0.695038</td>\n",
" <td>0.066870</td>\n",
" <td>-30.315991</td>\n",
" <td>-1.254103</td>\n",
" <td>-0.392216</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>1.042898</td>\n",
" <td>1.472388</td>\n",
" <td>0.137216</td>\n",
" <td>0.313886</td>\n",
" <td>0.782728</td>\n",
" <td>0.671206</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.465436</td>\n",
" <td>2.072718</td>\n",
" <td>0.174430</td>\n",
" <td>0.610684</td>\n",
" <td>0.879875</td>\n",
" <td>0.805015</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.976618</td>\n",
" <td>2.785021</td>\n",
" <td>0.234618</td>\n",
" <td>0.757136</td>\n",
" <td>0.929170</td>\n",
" <td>0.879403</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>9.007959</td>\n",
" <td>12.398485</td>\n",
" <td>3.290891</td>\n",
" <td>0.973600</td>\n",
" <td>0.993247</td>\n",
" <td>0.987535</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa \\\n",
"count 4739.000000 4739.000000 4739.000000 4739.000000 4739.000000 \n",
"mean 1.657513 2.352886 0.232260 0.406715 0.823672 \n",
"std 0.897659 1.318793 0.234080 0.877368 0.184708 \n",
"min 0.546354 0.695038 0.066870 -30.315991 -1.254103 \n",
"25% 1.042898 1.472388 0.137216 0.313886 0.782728 \n",
"50% 1.465436 2.072718 0.174430 0.610684 0.879875 \n",
"75% 1.976618 2.785021 0.234618 0.757136 0.929170 \n",
"max 9.007959 12.398485 3.290891 0.973600 0.993247 \n",
"\n",
" r \n",
"count 4739.000000 \n",
"mean 0.747644 \n",
"std 0.191853 \n",
"min -0.392216 \n",
"25% 0.671206 \n",
"50% 0.805015 \n",
"75% 0.879403 \n",
"max 0.987535 "
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "c72964bf-bbc5-4773-bd5f-6a0ea674934e",
"metadata": {},
"outputs": [],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2']).describe().to_csv('./eva_files/baseline_mask_loss.csv', encoding='utf-8-sig')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "acee2abc-2f3f-4d19-a6e4-85ad4d1aaacf",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'data' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m visualize_feature(\u001b[43mdata\u001b[49m[\u001b[38;5;241m5\u001b[39m], masked_data[\u001b[38;5;241m5\u001b[39m], reconstructed[\u001b[38;5;241m5\u001b[39m], \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNO2\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
"\u001b[0;31mNameError\u001b[0m: name 'data' is not defined"
]
}
],
"source": [
"visualize_feature(data[5], masked_data[5], reconstructed[5], 'NO2')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}