update code
This commit is contained in:
commit
bfe09d5318
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,723 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "3a2c33ed-8f78-4ce4-b5cd-7b7ffc5c8273",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using device: cuda\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torch.utils.data import DataLoader, Dataset\n",
|
||||
"from PIL import Image\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICE\"] = \"0\" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# 设置CUDA设备\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"print(f\"Using device: {device}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "85459fd1-6835-41cd-b645-553611c358e8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Maximum pixel value in the dataset: 107.49169921875\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"max_pixel_value = 107.49169921875\n",
|
||||
"\n",
|
||||
"print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "3fc0918e-103c-40a3-93bc-6171e934a7e4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"checkpoint before Generator is OK\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"class NO2Dataset(Dataset):\n",
|
||||
" \n",
|
||||
" def __init__(self, image_dir, mask_dir):\n",
|
||||
" \n",
|
||||
" self.image_dir = image_dir\n",
|
||||
" self.mask_dir = mask_dir\n",
|
||||
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
|
||||
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" \n",
|
||||
" return len(self.image_filenames)\n",
|
||||
" \n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" \n",
|
||||
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
|
||||
" mask_idx = np.random.choice(self.mask_filenames)\n",
|
||||
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
|
||||
"\n",
|
||||
" # 加载图像数据 (.npy 文件)\n",
|
||||
" image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n",
|
||||
"\n",
|
||||
" # 加载掩码数据 (.jpg 文件)\n",
|
||||
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
|
||||
"\n",
|
||||
" # 将掩码数据中非0值设为1,0值保持不变\n",
|
||||
" mask = np.where(mask != 0, 1.0, 0.0)\n",
|
||||
"\n",
|
||||
" # 保持掩码数据形状为 (96, 96, 1)\n",
|
||||
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
|
||||
"\n",
|
||||
" # 应用掩码\n",
|
||||
" masked_image = image.copy()\n",
|
||||
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
|
||||
"\n",
|
||||
" # cGAN的输入和目标\n",
|
||||
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
|
||||
" y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n",
|
||||
"\n",
|
||||
" # 转换形状为 (channels, height, width)\n",
|
||||
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
||||
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
||||
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
||||
"\n",
|
||||
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
|
||||
"\n",
|
||||
"# 实例化数据集和数据加载器\n",
|
||||
"train_dir = './out_mat/96/train/'\n",
|
||||
"valid_dir = './out_mat/96/valid/'\n",
|
||||
"test_dir = './out_mat/96/test/'\n",
|
||||
"mask_dir = './out_mat/96/mask/20/'\n",
|
||||
"\n",
|
||||
"print(f\"checkpoint before Generator is OK\")\n",
|
||||
"\n",
|
||||
"dataset = NO2Dataset(train_dir, mask_dir)\n",
|
||||
"train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
|
||||
"\n",
|
||||
"validset = NO2Dataset(valid_dir, mask_dir)\n",
|
||||
"val_loader = DataLoader(validset, batch_size=64, shuffle=False, num_workers=8)\n",
|
||||
"\n",
|
||||
"testset = NO2Dataset(test_dir, mask_dir)\n",
|
||||
"test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "a60b7019-f231-4ccb-9195-c459f3a1521d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Generator is on: cpu\n",
|
||||
"Discriminator is on: cpu\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 生成器模型\n",
|
||||
"class Generator(nn.Module):\n",
|
||||
" \n",
|
||||
" def __init__(self):\n",
|
||||
" super(Generator, self).__init__()\n",
|
||||
" self.encoder = nn.Sequential(\n",
|
||||
" nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.BatchNorm2d(64),\n",
|
||||
" nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.BatchNorm2d(128),\n",
|
||||
" )\n",
|
||||
" self.decoder = nn.Sequential(\n",
|
||||
" nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.BatchNorm2d(64),\n",
|
||||
" nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),\n",
|
||||
" nn.Tanh(),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x, mask):\n",
|
||||
" x_encoded = self.encoder(x)\n",
|
||||
" x_decoded = self.decoder(x_encoded)\n",
|
||||
"\n",
|
||||
"# x_decoded = (x_decoded + 1) / 2\n",
|
||||
"\n",
|
||||
"# x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :]\n",
|
||||
" return x_decoded\n",
|
||||
"\n",
|
||||
"# 判别器模型\n",
|
||||
"class Discriminator(nn.Module):\n",
|
||||
" \n",
|
||||
" def __init__(self):\n",
|
||||
" super(Discriminator, self).__init__()\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.BatchNorm2d(64),\n",
|
||||
" nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
|
||||
" nn.LeakyReLU(0.2, inplace=True),\n",
|
||||
" nn.BatchNorm2d(128),\n",
|
||||
" nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return self.model(x)\n",
|
||||
"\n",
|
||||
"# 将模型加载到GPU\n",
|
||||
"generator = Generator().to(device)\n",
|
||||
"discriminator = Discriminator().to(device)\n",
|
||||
"\n",
|
||||
"# 定义优化器和损失函数\n",
|
||||
"optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
|
||||
"optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
|
||||
"adversarial_loss = nn.BCELoss().to(device)\n",
|
||||
"\n",
|
||||
"# 确认模型是否在GPU上\n",
|
||||
"print(f\"Generator is on: {next(generator.parameters()).device}\")\n",
|
||||
"print(f\"Discriminator is on: {next(discriminator.parameters()).device}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "fd2f0816-7301-4381-99e4-05905ce8b093",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def masked_mse_loss(preds, target, mask):\n",
|
||||
" loss = (preds - target) ** 2\n",
|
||||
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
|
||||
" loss = (loss * (1-mask)).sum() / (1-mask).sum() # 只计算被mask的像素点的损失\n",
|
||||
" return loss"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "edf26d05-7fd6-404b-8f9d-3d2054593f7b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"generator.load_state_dict(torch.load('./models/GAN/generator-1d.pth'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "645dd325-fc70-4234-8279-bc8cbc4c5dde",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 开始训练\n",
|
||||
"epochs = 100\n",
|
||||
"for epoch in range(epochs):\n",
|
||||
" for i, (X, y, mask) in enumerate(train_loader):\n",
|
||||
" # 将数据移到 GPU 上\n",
|
||||
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
||||
" # print(f\"X is on: {X.device}, y is on: {y.device}, mask is on: {mask.device}, i = {i}\") #checkpoint\n",
|
||||
" \n",
|
||||
" valid = torch.ones((X.size(0), 1, 12, 12)).to(device)\n",
|
||||
" fake = torch.zeros((X.size(0), 1, 12, 12)).to(device)\n",
|
||||
"\n",
|
||||
" # 生成器生成图像\n",
|
||||
" optimizer_G.zero_grad()\n",
|
||||
" generated_images = generator(X, mask)\n",
|
||||
" g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * masked_mse_loss(\n",
|
||||
" generated_images, y, mask)\n",
|
||||
" g_loss.backward()\n",
|
||||
" optimizer_G.step()\n",
|
||||
"\n",
|
||||
" # 判别器训练\n",
|
||||
" optimizer_D.zero_grad()\n",
|
||||
" real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)\n",
|
||||
" fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)\n",
|
||||
" d_loss = 0.5 * (real_loss + fake_loss)\n",
|
||||
" d_loss.backward()\n",
|
||||
" optimizer_D.step()\n",
|
||||
"\n",
|
||||
" print(f\"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]\")\n",
|
||||
"\n",
|
||||
"# 保存训练好的模型\n",
|
||||
"torch.save(generator.state_dict(), './models/GAN/generator-1d.pth')\n",
|
||||
"torch.save(discriminator.state_dict(), './models/GAN/discriminator-1d.pth')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "9ab6849f-740a-49e4-9afd-23aafdc16725",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "8997686a-5812-4a92-972c-11376dfc1686",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def cal_ioa(y_true, y_pred):\n",
|
||||
" # 计算平均值\n",
|
||||
" mean_observed = np.mean(y_true)\n",
|
||||
" mean_predicted = np.mean(y_pred)\n",
|
||||
"\n",
|
||||
" # 计算IoA\n",
|
||||
" numerator = np.sum((y_true - y_pred) ** 2)\n",
|
||||
" denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
|
||||
" IoA = 1 - (numerator / denominator)\n",
|
||||
"\n",
|
||||
" return IoA"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "78db8f70-3cdb-444f-8bb5-49126957a0b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"eva_list = list()\n",
|
||||
"device = 'cpu'\n",
|
||||
"generator = generator.to(device)\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
|
||||
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
||||
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
|
||||
" reconstructed = generator(X, mask)\n",
|
||||
" rev_data = torch.squeeze(y * max_pixel_value, dim=1)\n",
|
||||
" rev_recon = torch.squeeze(reconstructed * max_pixel_value, dim=1)\n",
|
||||
" # todo: 这里需要只评估修补出来的模块\n",
|
||||
" data_label = rev_data * mask_rev\n",
|
||||
" data_label = data_label[mask_rev==1]\n",
|
||||
" recon_no2 = rev_recon * mask_rev\n",
|
||||
" recon_no2 = recon_no2[mask_rev==1]\n",
|
||||
" y_true = rev_data.flatten()\n",
|
||||
" y_pred = rev_recon.flatten()\n",
|
||||
" mae = mean_absolute_error(y_true, y_pred)\n",
|
||||
" rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
|
||||
" mape = mean_absolute_percentage_error(y_true, y_pred)\n",
|
||||
" r2 = r2_score(y_true, y_pred)\n",
|
||||
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
|
||||
" eva_list.append([mae, rmse, mape, r2, ioa])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "ffc83338-9b0f-4934-9cba-406a5e1fb0ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "33eb7a86-242a-4488-9d9b-93548e72c98b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>mae</th>\n",
|
||||
" <th>rmse</th>\n",
|
||||
" <th>mape</th>\n",
|
||||
" <th>r2</th>\n",
|
||||
" <th>ioa</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>count</th>\n",
|
||||
" <td>75.000000</td>\n",
|
||||
" <td>75.000000</td>\n",
|
||||
" <td>75.000000</td>\n",
|
||||
" <td>75.000000</td>\n",
|
||||
" <td>75.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>mean</th>\n",
|
||||
" <td>1.685609</td>\n",
|
||||
" <td>2.824579</td>\n",
|
||||
" <td>0.223852</td>\n",
|
||||
" <td>0.807483</td>\n",
|
||||
" <td>0.894409</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>std</th>\n",
|
||||
" <td>0.520285</td>\n",
|
||||
" <td>0.613299</td>\n",
|
||||
" <td>0.066827</td>\n",
|
||||
" <td>0.107566</td>\n",
|
||||
" <td>0.024969</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>min</th>\n",
|
||||
" <td>1.108756</td>\n",
|
||||
" <td>2.040964</td>\n",
|
||||
" <td>0.143461</td>\n",
|
||||
" <td>0.336193</td>\n",
|
||||
" <td>0.812887</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>25%</th>\n",
|
||||
" <td>1.338143</td>\n",
|
||||
" <td>2.462648</td>\n",
|
||||
" <td>0.176170</td>\n",
|
||||
" <td>0.780906</td>\n",
|
||||
" <td>0.883027</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>50%</th>\n",
|
||||
" <td>1.509821</td>\n",
|
||||
" <td>2.608227</td>\n",
|
||||
" <td>0.206274</td>\n",
|
||||
" <td>0.850417</td>\n",
|
||||
" <td>0.900165</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>75%</th>\n",
|
||||
" <td>1.963103</td>\n",
|
||||
" <td>3.067560</td>\n",
|
||||
" <td>0.257667</td>\n",
|
||||
" <td>0.866705</td>\n",
|
||||
" <td>0.910917</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>max</th>\n",
|
||||
" <td>3.729434</td>\n",
|
||||
" <td>5.363288</td>\n",
|
||||
" <td>0.461465</td>\n",
|
||||
" <td>0.912240</td>\n",
|
||||
" <td>0.935183</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" mae rmse mape r2 ioa\n",
|
||||
"count 75.000000 75.000000 75.000000 75.000000 75.000000\n",
|
||||
"mean 1.685609 2.824579 0.223852 0.807483 0.894409\n",
|
||||
"std 0.520285 0.613299 0.066827 0.107566 0.024969\n",
|
||||
"min 1.108756 2.040964 0.143461 0.336193 0.812887\n",
|
||||
"25% 1.338143 2.462648 0.176170 0.780906 0.883027\n",
|
||||
"50% 1.509821 2.608227 0.206274 0.850417 0.900165\n",
|
||||
"75% 1.963103 3.067560 0.257667 0.866705 0.910917\n",
|
||||
"max 3.729434 5.363288 0.461465 0.912240 0.935183"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 106,
|
||||
"id": "3844152b-c853-4311-a0a4-f57275ea78fc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rst = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape', ascending=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 140,
|
||||
"id": "8f98d947-5a20-49f5-84a1-5ae5a0f4693e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>mae</th>\n",
|
||||
" <th>rmse</th>\n",
|
||||
" <th>mape</th>\n",
|
||||
" <th>r2</th>\n",
|
||||
" <th>ioa</th>\n",
|
||||
" <th>r</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>3523</th>\n",
|
||||
" <td>1.834664</td>\n",
|
||||
" <td>2.579296</td>\n",
|
||||
" <td>0.103911</td>\n",
|
||||
" <td>0.855056</td>\n",
|
||||
" <td>0.960841</td>\n",
|
||||
" <td>0.931334</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3544</th>\n",
|
||||
" <td>1.500194</td>\n",
|
||||
" <td>1.962885</td>\n",
|
||||
" <td>0.106816</td>\n",
|
||||
" <td>0.849688</td>\n",
|
||||
" <td>0.960970</td>\n",
|
||||
" <td>0.924187</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1952</th>\n",
|
||||
" <td>1.786639</td>\n",
|
||||
" <td>2.290560</td>\n",
|
||||
" <td>0.109383</td>\n",
|
||||
" <td>0.704122</td>\n",
|
||||
" <td>0.928829</td>\n",
|
||||
" <td>0.869446</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>602</th>\n",
|
||||
" <td>2.222957</td>\n",
|
||||
" <td>2.934734</td>\n",
|
||||
" <td>0.112751</td>\n",
|
||||
" <td>0.735178</td>\n",
|
||||
" <td>0.933639</td>\n",
|
||||
" <td>0.877028</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3531</th>\n",
|
||||
" <td>2.093165</td>\n",
|
||||
" <td>2.726698</td>\n",
|
||||
" <td>0.115755</td>\n",
|
||||
" <td>0.760530</td>\n",
|
||||
" <td>0.937662</td>\n",
|
||||
" <td>0.889606</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1114</th>\n",
|
||||
" <td>1.951748</td>\n",
|
||||
" <td>2.591448</td>\n",
|
||||
" <td>0.116578</td>\n",
|
||||
" <td>0.696970</td>\n",
|
||||
" <td>0.914501</td>\n",
|
||||
" <td>0.843026</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1979</th>\n",
|
||||
" <td>2.083001</td>\n",
|
||||
" <td>2.686231</td>\n",
|
||||
" <td>0.116762</td>\n",
|
||||
" <td>0.597512</td>\n",
|
||||
" <td>0.886877</td>\n",
|
||||
" <td>0.791842</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2568</th>\n",
|
||||
" <td>2.630587</td>\n",
|
||||
" <td>3.636890</td>\n",
|
||||
" <td>0.117044</td>\n",
|
||||
" <td>0.491952</td>\n",
|
||||
" <td>0.893928</td>\n",
|
||||
" <td>0.833221</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" mae rmse mape r2 ioa r\n",
|
||||
"3523 1.834664 2.579296 0.103911 0.855056 0.960841 0.931334\n",
|
||||
"3544 1.500194 1.962885 0.106816 0.849688 0.960970 0.924187\n",
|
||||
"1952 1.786639 2.290560 0.109383 0.704122 0.928829 0.869446\n",
|
||||
"602 2.222957 2.934734 0.112751 0.735178 0.933639 0.877028\n",
|
||||
"3531 2.093165 2.726698 0.115755 0.760530 0.937662 0.889606\n",
|
||||
"1114 1.951748 2.591448 0.116578 0.696970 0.914501 0.843026\n",
|
||||
"1979 2.083001 2.686231 0.116762 0.597512 0.886877 0.791842\n",
|
||||
"2568 2.630587 3.636890 0.117044 0.491952 0.893928 0.833221"
|
||||
]
|
||||
},
|
||||
"execution_count": 140,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"rst.head(8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 141,
|
||||
"id": "a72b4ec5-b72d-43d0-8178-954f8edcb0dd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'1114', '1952', '2568', '3523', '602'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 141,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])\n",
|
||||
"find_ex"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 159,
|
||||
"id": "819950e1-d16d-42e9-a0f9-57878ebc8f89",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<Figure size 640x480 with 0 Axes>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for j in find_ex:\n",
|
||||
" ori = np.load(f'./test_img/{j}-real.npy')[0]\n",
|
||||
" pred = np.load(f'./test_img/{j}-gan-recom.npy')[0]\n",
|
||||
" mask = np.load(f'./test_img/{j}-mask.npy')\n",
|
||||
" plt.imshow(ori, cmap='RdYlGn_r')\n",
|
||||
" plt.gca().axis('off')\n",
|
||||
" plt.savefig(f'./test_img/out_fig/{j}-truth.png', bbox_inches='tight')\n",
|
||||
" plt.clf()\n",
|
||||
" \n",
|
||||
" plt.imshow(mask, cmap='gray')\n",
|
||||
" plt.gca().axis('off')\n",
|
||||
" plt.savefig(f'./test_img/out_fig/{j}-mask.png', bbox_inches='tight')\n",
|
||||
" plt.clf()\n",
|
||||
" \n",
|
||||
" mask_cp = np.where((1-mask) == 0, np.nan, (1-mask))\n",
|
||||
" plt.imshow(ori * mask_cp, cmap='RdYlGn_r')\n",
|
||||
" plt.gca().axis('off')\n",
|
||||
" plt.savefig(f'./test_img/out_fig/{j}-masked_ori.png', bbox_inches='tight')\n",
|
||||
" plt.clf()\n",
|
||||
" \n",
|
||||
" out = ori * mask + pred * (1 - mask)\n",
|
||||
" plt.imshow(out, cmap='RdYlGn_r')\n",
|
||||
" plt.gca().axis('off')\n",
|
||||
" plt.savefig(f'./test_img/out_fig/{j}-gan_out.png', bbox_inches='tight')\n",
|
||||
" plt.clf()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2926f4b3-0a15-48ae-bdaa-4e290e11b49d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,589 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torch.utils.data import DataLoader, TensorDataset, random_split\n",
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import cv2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "adf69eb9-bedb-4db7-87c4-04c23752a7c3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def load_data(pix, use_type='train'):\n",
|
||||
" datasets = list()\n",
|
||||
" file_list = [x for x in os.listdir(f\"./out_mat/{pix}/{use_type}/\") if x.endswith('.npy')]\n",
|
||||
" for file in file_list:\n",
|
||||
" file_img = np.load(f\"./out_mat/{pix}/{use_type}/{file}\")[:,:,:7]\n",
|
||||
" datasets.append(file_img)\n",
|
||||
" return np.asarray(datasets)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "e0aa628f-37b7-498a-94d7-81241c20b305",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_set = load_data(96, 'train')\n",
|
||||
"val_set = load_data(96, 'valid')\n",
|
||||
"test_set = load_data(96, 'test')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "5d5f95cb-f40c-4ead-96fe-241068408b98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def load_mask(mask_rate):\n",
|
||||
" mask_files = os.listdir(f'./out_mat/96/mask/{mask_rate}')\n",
|
||||
" masks = list()\n",
|
||||
" for file in mask_files:\n",
|
||||
" d = cv2.imread(f'./out_mat/96/mask/{mask_rate}/{file}', cv2.IMREAD_GRAYSCALE)\n",
|
||||
" d = (d > 0) * 1\n",
|
||||
" masks.append(d)\n",
|
||||
" return np.asarray(masks)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "71452a77-8158-46b2-aecf-400ad7b72df5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks = load_mask(20)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "1902e0f8-32bb-4376-8238-334260b12623",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"maxs = train_set.max(axis=0)\n",
|
||||
"mins = train_set.min(axis=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "8df9f3c3-ced8-4640-af30-b2f147dbdc96",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"26749"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(train_set)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "53664b12-fd95-4dd0-b61d-20682f8f14f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"norm_train = (train_set - mins) / (maxs-mins)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "05cb9dc8-c1df-48bf-a9dd-d084ce1d2068",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"del train_set"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "4ae39364-4cf6-49e9-b99f-6723520943b5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"norm_valid = (val_set - mins) / (maxs-mins)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "7f78b981-d079-4000-ba9f-d862e34903b1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"del val_set"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "f54eede6-e95a-4476-b822-79846c0b1079",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"norm_test = (test_set - mins) / (maxs-mins)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "e66887eb-df5e-46d3-b9c5-73af1272b27a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"del test_set"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "00afa8cd-18b4-4d71-8cab-fd140058dca3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(26749, 96, 96)"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"norm_train.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "31d91072-3878-4e3c-b6f1-09f597faf60d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trans_train = np.transpose(norm_train, (0, 3, 1, 2))\n",
|
||||
"trans_val = np.transpose(norm_valid, (0, 3, 1, 2))\n",
|
||||
"trans_test = np.transpose(norm_test, (0, 3, 1, 2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "70797703-1619-4be7-b965-5506b3d1e775",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 可视化特定特征的函数\n",
|
||||
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
|
||||
" plt.figure(figsize=(12, 6))\n",
|
||||
" plt.subplot(1, 3, 1)\n",
|
||||
" plt.imshow(input_feature[0].cpu().numpy())\n",
|
||||
" plt.title(title + \" Input\")\n",
|
||||
" plt.subplot(1, 3, 2)\n",
|
||||
" plt.imshow(masked_feature[0].cpu().numpy())\n",
|
||||
" plt.title(title + \" Masked\")\n",
|
||||
" plt.subplot(1, 3, 3)\n",
|
||||
" plt.imshow(output_feature[0].detach().cpu().numpy())\n",
|
||||
" plt.title(title + \" Recovery\")\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aeda3567-4c4d-496b-9570-9ae757b45e72",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 设置随机种子以确保结果的可重复性\n",
|
||||
"torch.manual_seed(0)\n",
|
||||
"np.random.seed(0)\n",
|
||||
"\n",
|
||||
"# 数据准备\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"print(device)\n",
|
||||
"# 将numpy数组转换为PyTorch张量\n",
|
||||
"tensor_train = torch.tensor(trans_train.astype(np.float32), device=device)\n",
|
||||
"tensor_valid = torch.tensor(trans_val.astype(np.float32), device=device)\n",
|
||||
"tensor_test = torch.tensor(trans_test.astype(np.float32), device=device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1569baeb-5a9e-48c1-a735-82d0cba8ad29",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 创建一个数据集和数据加载器\n",
|
||||
"train_set = TensorDataset(tensor_train, tensor_train) # 输出和标签相同,因为我们是自编码器\n",
|
||||
"val_set = TensorDataset(tensor_valid, tensor_valid)\n",
|
||||
"test_set = TensorDataset(tensor_test, tensor_test)\n",
|
||||
"batch_size = 64\n",
|
||||
"train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
|
||||
"val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)\n",
|
||||
"test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3c81785d-f0e6-486f-8aad-dba81d2ec146",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def mask_data(data, device, masks):\n",
|
||||
" mask_inds = np.random.choice(masks.shape[0], data.shape[0])\n",
|
||||
" mask = torch.from_numpy(masks[mask_inds]).to(device)\n",
|
||||
" tmp_first_channel = data[:, 0, :, :] * mask\n",
|
||||
" masked_data = torch.clone(data)\n",
|
||||
" masked_data[:, 0, :, :] = tmp_first_channel\n",
|
||||
" return masked_data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SEBlock(nn.Module):\n",
|
||||
" def __init__(self, in_channels, reduced_dim):\n",
|
||||
" super(SEBlock, self).__init__()\n",
|
||||
" self.se = nn.Sequential(\n",
|
||||
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
|
||||
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
|
||||
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return x * self.se(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 定义Masked Autoencoder模型\n",
|
||||
"class MaskedAutoencoder(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MaskedAutoencoder, self).__init__()\n",
|
||||
" self.encoder = nn.Sequential(\n",
|
||||
" nn.Conv2d(7, 32, kernel_size=3, stride=2, padding=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" SEBlock(128, 128)\n",
|
||||
" )\n",
|
||||
" self.decoder = nn.Sequential(\n",
|
||||
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.ConvTranspose2d(16, 7, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
|
||||
" nn.Sigmoid() # 使用Sigmoid是因为输入数据是0-1之间的\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" encoded = self.encoder(x)\n",
|
||||
" decoded = self.decoder(encoded)\n",
|
||||
" return decoded\n",
|
||||
"\n",
|
||||
"# 实例化模型、损失函数和优化器\n",
|
||||
"model = MaskedAutoencoder()\n",
|
||||
"criterion = nn.MSELoss()\n",
|
||||
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 训练函数\n",
|
||||
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
|
||||
" model.train()\n",
|
||||
" running_loss = 0.0\n",
|
||||
" for batch_idx, (data, _) in enumerate(data_loader):\n",
|
||||
" masked_data = mask_data(data, device, masks)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" reconstructed = model(masked_data)\n",
|
||||
" loss = criterion(reconstructed, data)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" running_loss += loss.item()\n",
|
||||
" return running_loss / (batch_idx + 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 评估函数\n",
|
||||
"def evaluate(model, device, data_loader, criterion):\n",
|
||||
" model.eval()\n",
|
||||
" running_loss = 0.0\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for batch_idx, (data, _) in enumerate(data_loader):\n",
|
||||
" data = data.to(device)\n",
|
||||
" masked_data = mask_data(data, device, masks)\n",
|
||||
" reconstructed = model(masked_data)\n",
|
||||
" if batch_idx == 8:\n",
|
||||
" rand_ind = np.random.randint(0, len(data))\n",
|
||||
" visualize_feature(data[rand_ind], masked_data[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
|
||||
" loss = criterion(reconstructed, data)\n",
|
||||
" running_loss += loss.item()\n",
|
||||
" return running_loss / (batch_idx + 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a1847c78-cbc6-4560-bb49-4dc6e9b8bbd0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 测试函数\n",
|
||||
"def test(model, device, data_loader):\n",
|
||||
" model.eval()\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for batch_idx, (data, _) in enumerate(data_loader):\n",
|
||||
" data = data.to(device)\n",
|
||||
" masked_data = mask_data(data, device, masks)\n",
|
||||
" masked_ind = np.argwhere(masked_data[0][0]==0)\n",
|
||||
" reconstructed = model(masked_data)\n",
|
||||
" recon_no2 = reconstructed[0][0]\n",
|
||||
" ori_no2 = data[0][0]\n",
|
||||
" return"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "743d1000-561e-4444-8b49-88346c14f28b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = model.to(device)\n",
|
||||
"\n",
|
||||
"num_epochs = 100\n",
|
||||
"train_losses = list()\n",
|
||||
"val_losses = list()\n",
|
||||
"for epoch in range(num_epochs):\n",
|
||||
" train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
|
||||
" train_losses.append(train_loss)\n",
|
||||
" val_loss = evaluate(model, device, val_loader, criterion)\n",
|
||||
" val_losses.append(val_loss)\n",
|
||||
" print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
|
||||
"\n",
|
||||
"# 测试模型\n",
|
||||
"test_loss = evaluate(model, device, test_loader, criterion)\n",
|
||||
"print(f'Test Loss: {test_loss}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tr_ind = list(range(len(train_losses)))\n",
|
||||
"val_ind = list(range(len(val_losses)))\n",
|
||||
"plt.plot(train_losses, label='train_loss')\n",
|
||||
"plt.plot(val_losses, label='val_loss')\n",
|
||||
"plt.legend(loc='best')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cff8cba9-aba9-4347-8e1a-f169df8313c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with torch.no_grad():\n",
|
||||
" device = 'cpu'\n",
|
||||
" for batch_idx, (data, _) in enumerate(test_loader):\n",
|
||||
" model = model.to(device)\n",
|
||||
" data = data.to(device)\n",
|
||||
" masked_data = mask_data(data, device, masks)\n",
|
||||
" reconstructed = model(masked_data)\n",
|
||||
" tr_maxs = np.transpose(maxs, (2, 0, 1))\n",
|
||||
" tr_mins = np.transpose(mins, (2, 0, 1))\n",
|
||||
" rev_data = data * (tr_maxs - tr_mins) + tr_mins\n",
|
||||
" rev_recon = reconstructed * (tr_maxs - tr_mins) + tr_mins\n",
|
||||
" data_label = ((rev_data!=0) * (masked_data==0) * rev_data)[:, 0]\n",
|
||||
" recon_no2 = ((rev_data!=0) * (masked_data==0) * rev_recon)[:, 0]\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "565f2e0a-1689-4a03-9fc1-15519b1cdaee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"real = data_label.flatten()\n",
|
||||
"pred = recon_no2.flatten()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e1e8f71e-855a-41ea-b62f-095514af66a3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0a6eea29-cd3e-4712-ad73-589bcf7b88be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mean_squared_error(real, pred)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a85feefb-aa3a-4bb9-86ac-7cc6938a47e8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mean_absolute_percentage_error(real, pred)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c2bfda87-3de8-4a06-969f-d346f4447cf6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"r2_score(real, pred)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d1955fc0-490a-40d5-8b3c-dd6e5beed235",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mean_absolute_error(real, pred)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "acee2abc-2f3f-4d19-a6e4-85ad4d1aaacf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"visualize_feature(data[5], masked_data[5], reconstructed[5], 'NO2')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"62"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len('The total $R^2$ for under 40\\% missing data test set was 0.88.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "215938c7-d514-48e7-a460-088dcd7927ae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,608 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "fa295d87-946f-402b-9d97-1127ee9a33a0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torch.utils.data import DataLoader, Dataset, random_split\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import os\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"MAX_VALUE = 107.49169921875"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "c6dd8e35-02e3-491c-b4be-a874cf1054ba",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"device(type='cuda')"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"device"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "2f151caf-43d1-4d59-a111-96ad5e6bc38b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class GrayScaleDataset(Dataset):\n",
|
||||
" def __init__(self, data_dir):\n",
|
||||
" self.data_dir = data_dir\n",
|
||||
" self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.file_list)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" file_path = os.path.join(self.data_dir, self.file_list[idx])\n",
|
||||
" data = np.load(file_path)[:,:,0] / MAX_VALUE\n",
|
||||
" return torch.tensor(data, dtype=torch.float32).unsqueeze(0)\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "3ecd7bd0-15a0-4420-95e1-066e4d023cd3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class NO2Dataset(Dataset):\n",
|
||||
" \n",
|
||||
" def __init__(self, image_dir, mask_dir):\n",
|
||||
" \n",
|
||||
" self.image_dir = image_dir\n",
|
||||
" self.mask_dir = mask_dir\n",
|
||||
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
|
||||
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" \n",
|
||||
" return len(self.image_filenames)\n",
|
||||
" \n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" \n",
|
||||
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
|
||||
" mask_idx = idx % len(self.mask_filenames)\n",
|
||||
" mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])\n",
|
||||
"\n",
|
||||
" # 加载图像数据 (.npy 文件)\n",
|
||||
" image = np.load(image_path).astype(np.float32)[:,:,:1] / MAX_VALUE # 形状为 (96, 96, 1)\n",
|
||||
"\n",
|
||||
" # 加载掩码数据 (.jpg 文件)\n",
|
||||
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
|
||||
"\n",
|
||||
" # 将掩码数据中非0值设为1,0值保持不变\n",
|
||||
" mask = np.where(mask != 0, 1.0, 0.0)\n",
|
||||
"\n",
|
||||
" # 保持掩码数据形状为 (96, 96, 1)\n",
|
||||
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
|
||||
"\n",
|
||||
" # 应用掩码\n",
|
||||
" masked_image = image.copy()\n",
|
||||
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
|
||||
"\n",
|
||||
" # cGAN的输入和目标\n",
|
||||
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
|
||||
" y = image[:, :, 0:1] # 目标输出为NO2数据,形状为 (96, 96, 1)\n",
|
||||
"\n",
|
||||
" # 转换形状为 (channels, height, width)\n",
|
||||
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
||||
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
||||
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
|
||||
"\n",
|
||||
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "36752a6d-329a-464d-a329-f02206bf63b0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class PatchMasking:\n",
|
||||
" def __init__(self, patch_size, mask_ratio):\n",
|
||||
" self.patch_size = patch_size\n",
|
||||
" self.mask_ratio = mask_ratio\n",
|
||||
"\n",
|
||||
" def __call__(self, x):\n",
|
||||
" batch_size, C, H, W = x.shape\n",
|
||||
" num_patches = (H // self.patch_size) * (W // self.patch_size)\n",
|
||||
" num_masked = int(num_patches * self.mask_ratio)\n",
|
||||
" \n",
|
||||
" # 为每个样本生成独立的mask\n",
|
||||
" masks = []\n",
|
||||
" for _ in range(batch_size):\n",
|
||||
" mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)\n",
|
||||
" mask[:num_masked] = 1\n",
|
||||
" mask = mask[torch.randperm(num_patches)]\n",
|
||||
" mask = mask.view(H // self.patch_size, W // self.patch_size)\n",
|
||||
" mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)\n",
|
||||
" masks.append(mask)\n",
|
||||
" \n",
|
||||
" # 将所有mask堆叠成一个批量张量\n",
|
||||
" masks = torch.stack(masks, dim=0)\n",
|
||||
" masks = torch.unsqueeze(masks, dim=1)\n",
|
||||
" \n",
|
||||
" # 应用mask到输入x上\n",
|
||||
" masked_x = x * (1 - masks.float())\n",
|
||||
" return masked_x, masks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "0db0d920-8de2-4bad-9b99-67eed152644d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Mlp(nn.Module):\n",
|
||||
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
|
||||
" super().__init__()\n",
|
||||
" out_features = out_features or in_features\n",
|
||||
" hidden_features = hidden_features or in_features\n",
|
||||
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
|
||||
"\n",
|
||||
" self.act = act_layer()\n",
|
||||
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
|
||||
" self.drop = nn.Dropout(drop, inplace=True)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.fc1(x)\n",
|
||||
" x = self.act(x)\n",
|
||||
" x = self.drop(x)\n",
|
||||
" x = self.fc2(x)\n",
|
||||
" x = self.drop(x)\n",
|
||||
" return x"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "cb27d3a7-77ed-4110-96bd-bcc4880964d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ViTEncoder(nn.Module):\n",
|
||||
" def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256):\n",
|
||||
" super(ViTEncoder, self).__init__()\n",
|
||||
" self.patch_size = patch_size\n",
|
||||
" self.dim = dim\n",
|
||||
" self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)\n",
|
||||
" \n",
|
||||
" # 定义 Transformer 编码器层\n",
|
||||
" encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim)\n",
|
||||
" self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.patch_embedding(x)\n",
|
||||
" x = x.flatten(2).transpose(1, 2) # 形状变为 (batch_size, num_patches, dim)\n",
|
||||
" x = self.transformer_encoder(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"class ConvDecoder(nn.Module):\n",
|
||||
" def __init__(self, dim=128, patch_size=8, img_size=96):\n",
|
||||
" super(ConvDecoder, self).__init__()\n",
|
||||
" self.dim = dim\n",
|
||||
" self.patch_size = patch_size\n",
|
||||
" self.img_size = img_size\n",
|
||||
" self.decoder = nn.Sequential(\n",
|
||||
" nn.ConvTranspose2d(dim, 128, kernel_size=patch_size, stride=patch_size),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.ConvTranspose2d(128, 1, kernel_size=3, stride=1, padding=1)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" # x = x.transpose(1, 2).view(-1, self.dim, self.img_size // self.patch_size, self.img_size // self.patch_size)\n",
|
||||
" x = self.decoder(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"class MAEModel(nn.Module):\n",
|
||||
" def __init__(self, encoder, decoder):\n",
|
||||
" super(MAEModel, self).__init__()\n",
|
||||
" self.encoder = encoder\n",
|
||||
" self.decoder = decoder\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" encoded = self.encoder(x)\n",
|
||||
" decoded = self.decoder(encoded)\n",
|
||||
" return decoded"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "783e62af-7f6a-40bd-a423-be63fe98a655",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def masked_mse_loss(preds, target, mask):\n",
|
||||
" loss = (preds - target) ** 2\n",
|
||||
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
|
||||
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
|
||||
" return loss"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "baeffdf0-cdc2-44c4-972a-e2e671635d6a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):\n",
|
||||
" model.to(device)\n",
|
||||
" for epoch in range(epochs):\n",
|
||||
" model.train()\n",
|
||||
" train_loss = 0\n",
|
||||
" for data in train_loader:\n",
|
||||
" data = data.to(device)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" masked_data, mask = PatchMasking(patch_size=16, mask_ratio=0.2)(data)\n",
|
||||
" output = model(masked_data)\n",
|
||||
" loss = masked_mse_loss(output, data, mask)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" train_loss += loss.item()\n",
|
||||
" train_loss /= len(train_loader)\n",
|
||||
"\n",
|
||||
" model.eval()\n",
|
||||
" val_loss = 0\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for data in val_loader:\n",
|
||||
" data = data.to(device)\n",
|
||||
" masked_data, mask = PatchMasking(patch_size=16, mask_ratio=0.2)(data)\n",
|
||||
" output = model(masked_data)\n",
|
||||
" loss = masked_mse_loss(output, data, mask)\n",
|
||||
" val_loss += loss.item()\n",
|
||||
" val_loss /= len(val_loader)\n",
|
||||
"\n",
|
||||
" print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "bb524f86-aa7d-44ee-b13e-b9ba4e5b3a0b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dir = './out_mat/96/train/'\n",
|
||||
"train_dataset = GrayScaleDataset(train_dir)\n",
|
||||
"\n",
|
||||
"val_dir = './out_mat/96/valid/'\n",
|
||||
"val_dataset = GrayScaleDataset(val_dir)\n",
|
||||
"\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
|
||||
"val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "7d6d07a4-31f1-4350-a487-b583db979381",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"encoder = ViTEncoder()\n",
|
||||
"decoder = ConvDecoder()\n",
|
||||
"model = MAEModel(encoder, decoder)\n",
|
||||
"criterion = nn.MSELoss()\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "8ee33651-f5f0-4b92-96e9-a84e32725b44",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([128, 128, 6, 6])"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a.transpose(1, 2).reshape(-1, 128, 6, 6).shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "a5684758-bc6d-45b0-b885-da37820ca5ac",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "RuntimeError",
|
||||
"evalue": "Given groups=1, weight of size [256, 128, 1, 1], expected input[1, 32, 144, 128] to have 128 channels, but got 32 channels instead",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[15], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m train_loader:\n\u001b[1;32m 2\u001b[0m a \u001b[38;5;241m=\u001b[39m encoder(i)\n\u001b[0;32m----> 3\u001b[0m b \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m c \u001b[38;5;241m=\u001b[39m decoder(b)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
||||
"Cell \u001b[0;32mIn[12], line 13\u001b[0m, in \u001b[0;36mMlp.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 13\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfc1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact(x)\n\u001b[1;32m 15\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdrop(x)\n",
|
||||
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:460\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 460\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv2d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 454\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 455\u001b[0m _pair(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 456\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 457\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m: Given groups=1, weight of size [256, 128, 1, 1], expected input[1, 32, 144, 128] to have 128 channels, but got 32 channels instead"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for i in train_loader:\n",
|
||||
" a = encoder(i)\n",
|
||||
" b = model.mlp(a)\n",
|
||||
" c = decoder(b)\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "09b04e16-3257-4890-b736-a6c7274561e0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"train_model(model, train_loader, val_loader, epochs=100, criterion=criterion, optimizer=optimizer, device=device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "b0c5cf4b-aca2-4781-8b47-bf2a46269635",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_set = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')\n",
|
||||
"test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "56653f37-899a-47d6-8d50-e456b4ad1835",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "f1ecbd05-7aa3-43ae-8bc2-aa44d19689b9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def cal_ioa(y_true, y_pred):\n",
|
||||
" # 计算平均值\n",
|
||||
" mean_observed = np.mean(y_true)\n",
|
||||
" mean_predicted = np.mean(y_pred)\n",
|
||||
"\n",
|
||||
" # 计算IoA\n",
|
||||
" numerator = np.sum((y_true - y_pred) ** 2)\n",
|
||||
" denominator = 2 * np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
|
||||
" IoA = 1 - (numerator / denominator)\n",
|
||||
"\n",
|
||||
" return IoA"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "e840b789-bf68-4b4d-a8d3-c5362c310349",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"eva_list = list()\n",
|
||||
"device = 'cpu'\n",
|
||||
"model = model.to(device)\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
|
||||
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
|
||||
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
|
||||
" reconstructed = model(X)\n",
|
||||
" rev_data = y * MAX_VALUE\n",
|
||||
" rev_recon = reconstructed * MAX_VALUE\n",
|
||||
" # todo: 这里需要只评估修补出来的模块\n",
|
||||
" data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
|
||||
" data_label = data_label[mask_rev==1]\n",
|
||||
" recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
|
||||
" recon_no2 = recon_no2[mask_rev==1]\n",
|
||||
" mae = mean_absolute_error(data_label, recon_no2)\n",
|
||||
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
|
||||
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
|
||||
" r2 = r2_score(data_label, recon_no2)\n",
|
||||
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
|
||||
" eva_list.append([mae, rmse, mape, r2, ioa])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "41fa754d-1eee-43a2-9e39-a0254719be30",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>mae</th>\n",
|
||||
" <th>rmse</th>\n",
|
||||
" <th>mape</th>\n",
|
||||
" <th>r2</th>\n",
|
||||
" <th>ioa</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>count</th>\n",
|
||||
" <td>149.000000</td>\n",
|
||||
" <td>149.000000</td>\n",
|
||||
" <td>149.000000</td>\n",
|
||||
" <td>149.000000</td>\n",
|
||||
" <td>149.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>mean</th>\n",
|
||||
" <td>7.068207</td>\n",
|
||||
" <td>9.016465</td>\n",
|
||||
" <td>0.814727</td>\n",
|
||||
" <td>-0.952793</td>\n",
|
||||
" <td>0.564749</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>std</th>\n",
|
||||
" <td>0.659118</td>\n",
|
||||
" <td>0.774556</td>\n",
|
||||
" <td>0.054147</td>\n",
|
||||
" <td>0.162851</td>\n",
|
||||
" <td>0.033048</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>min</th>\n",
|
||||
" <td>5.609327</td>\n",
|
||||
" <td>7.113544</td>\n",
|
||||
" <td>0.599120</td>\n",
|
||||
" <td>-1.402735</td>\n",
|
||||
" <td>0.461420</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>25%</th>\n",
|
||||
" <td>6.613351</td>\n",
|
||||
" <td>8.499699</td>\n",
|
||||
" <td>0.782008</td>\n",
|
||||
" <td>-1.049951</td>\n",
|
||||
" <td>0.544980</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>50%</th>\n",
|
||||
" <td>7.086443</td>\n",
|
||||
" <td>9.045812</td>\n",
|
||||
" <td>0.811261</td>\n",
|
||||
" <td>-0.938765</td>\n",
|
||||
" <td>0.567080</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>75%</th>\n",
|
||||
" <td>7.495309</td>\n",
|
||||
" <td>9.530408</td>\n",
|
||||
" <td>0.848900</td>\n",
|
||||
" <td>-0.849266</td>\n",
|
||||
" <td>0.586134</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>max</th>\n",
|
||||
" <td>8.663801</td>\n",
|
||||
" <td>10.995004</td>\n",
|
||||
" <td>0.984343</td>\n",
|
||||
" <td>-0.591799</td>\n",
|
||||
" <td>0.630479</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" mae rmse mape r2 ioa\n",
|
||||
"count 149.000000 149.000000 149.000000 149.000000 149.000000\n",
|
||||
"mean 7.068207 9.016465 0.814727 -0.952793 0.564749\n",
|
||||
"std 0.659118 0.774556 0.054147 0.162851 0.033048\n",
|
||||
"min 5.609327 7.113544 0.599120 -1.402735 0.461420\n",
|
||||
"25% 6.613351 8.499699 0.782008 -1.049951 0.544980\n",
|
||||
"50% 7.086443 9.045812 0.811261 -0.938765 0.567080\n",
|
||||
"75% 7.495309 9.530408 0.848900 -0.849266 0.586134\n",
|
||||
"max 8.663801 10.995004 0.984343 -0.591799 0.630479"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b15bbdc-cb87-4648-b22f-72917b8c1e6b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue