update code

zhaojinghao 2024-11-21 14:02:33 +08:00
commit bfe09d5318
22 changed files with 20114 additions and 0 deletions

527
build_data.ipynb Normal file

File diff suppressed because one or more lines are too long

723
build_gan-1d.ipynb Normal file

@@ -0,0 +1,723 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3a2c33ed-8f78-4ce4-b5cd-7b7ffc5c8273",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"import os\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"os.environ[\"CUDA_VISIBLE_DEVICE\"] = \"0\" \n",
"\n",
"\n",
"# 设置CUDA设备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "85459fd1-6835-41cd-b645-553611c358e8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum pixel value in the dataset: 107.49169921875\n"
]
}
],
"source": [
"max_pixel_value = 107.49169921875\n",
"\n",
"print(f\"Maximum pixel value in the dataset: {max_pixel_value}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3fc0918e-103c-40a3-93bc-6171e934a7e4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint before Generator is OK\n"
]
}
],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = np.random.choice(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, mask_idx)\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / max_pixel_value # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)\n",
"\n",
"# 实例化数据集和数据加载器\n",
"train_dir = './out_mat/96/train/'\n",
"valid_dir = './out_mat/96/valid/'\n",
"test_dir = './out_mat/96/test/'\n",
"mask_dir = './out_mat/96/mask/20/'\n",
"\n",
"print(f\"checkpoint before Generator is OK\")\n",
"\n",
"dataset = NO2Dataset(train_dir, mask_dir)\n",
"train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)\n",
"\n",
"validset = NO2Dataset(valid_dir, mask_dir)\n",
"val_loader = DataLoader(validset, batch_size=64, shuffle=False, num_workers=8)\n",
"\n",
"testset = NO2Dataset(test_dir, mask_dir)\n",
"test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=8)"
]
},
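{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000001",
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check sketch (not executed as part of the original run): pull one batch\n",
"# and confirm the tensor shapes the generator and discriminator below expect.\n",
"X_b, y_b, m_b = next(iter(train_loader))\n",
"print(X_b.shape, y_b.shape, m_b.shape)  # expected: torch.Size([64, 1, 96, 96]) each\n",
"assert X_b.shape[1:] == (1, 96, 96)"
]
},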
{
"cell_type": "code",
"execution_count": 16,
"id": "a60b7019-f231-4ccb-9195-c459f3a1521d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generator is on: cpu\n",
"Discriminator is on: cpu\n"
]
}
],
"source": [
"# 生成器模型\n",
"class Generator(nn.Module):\n",
" \n",
" def __init__(self):\n",
" super(Generator, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.BatchNorm2d(64),\n",
" nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.BatchNorm2d(128),\n",
" )\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.BatchNorm2d(64),\n",
" nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),\n",
" nn.Tanh(),\n",
" )\n",
"\n",
" def forward(self, x, mask):\n",
" x_encoded = self.encoder(x)\n",
" x_decoded = self.decoder(x_encoded)\n",
"\n",
"# x_decoded = (x_decoded + 1) / 2\n",
"\n",
"# x_output = (1 - mask) * x_decoded + mask * x[:, :1, :, :]\n",
" return x_decoded\n",
"\n",
"# 判别器模型\n",
"class Discriminator(nn.Module):\n",
" \n",
" def __init__(self):\n",
" super(Discriminator, self).__init__()\n",
" self.model = nn.Sequential(\n",
" nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.BatchNorm2d(64),\n",
" nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.BatchNorm2d(128),\n",
" nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),\n",
" nn.Sigmoid(),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)\n",
"\n",
"# 将模型加载到GPU\n",
"generator = Generator().to(device)\n",
"discriminator = Discriminator().to(device)\n",
"\n",
"# 定义优化器和损失函数\n",
"optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
"optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
"adversarial_loss = nn.BCELoss().to(device)\n",
"\n",
"# 确认模型是否在GPU上\n",
"print(f\"Generator is on: {next(generator.parameters()).device}\")\n",
"print(f\"Discriminator is on: {next(discriminator.parameters()).device}\")"
]
},
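{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000002",
"metadata": {},
"outputs": [],
"source": [
"# Shape smoke test (sketch): the generator's two stride-2 convs and two stride-2\n",
"# deconvs keep 96x96 inputs at 96x96, and the discriminator's three stride-2 convs\n",
"# map 96 -> 48 -> 24 -> 12, matching the (1, 12, 12) valid/fake labels used below.\n",
"with torch.no_grad():\n",
"    probe = torch.randn(2, 1, 96, 96, device=device)\n",
"    probe_mask = torch.ones(2, 1, 96, 96, device=device)\n",
"    gen_out = generator(probe, probe_mask)\n",
"    disc_out = discriminator(torch.cat((gen_out, probe), dim=1))\n",
"print(gen_out.shape, disc_out.shape)  # expected: (2, 1, 96, 96) and (2, 1, 12, 12)"
]
},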
{
"cell_type": "code",
"execution_count": 17,
"id": "fd2f0816-7301-4381-99e4-05905ce8b093",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * (1-mask)).sum() / (1-mask).sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
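{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000003",
"metadata": {},
"outputs": [],
"source": [
"# Quick numeric check (sketch): with two of four pixels masked out (mask == 0)\n",
"# and a constant error of 1, the masked MSE should be exactly 1.\n",
"p = torch.zeros(1, 1, 2, 2)\n",
"t = torch.ones(1, 1, 2, 2)\n",
"m = torch.tensor([[[[1., 0.], [0., 1.]]]])  # 1 = observed, 0 = masked out\n",
"print(masked_mse_loss(p, t, m))  # expected: tensor(1.)"
]
},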
{
"cell_type": "code",
"execution_count": 18,
"id": "edf26d05-7fd6-404b-8f9d-3d2054593f7b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generator.load_state_dict(torch.load('./models/GAN/generator-1d.pth'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "645dd325-fc70-4234-8279-bc8cbc4c5dde",
"metadata": {},
"outputs": [],
"source": [
"# 开始训练\n",
"epochs = 100\n",
"for epoch in range(epochs):\n",
" for i, (X, y, mask) in enumerate(train_loader):\n",
" # 将数据移到 GPU 上\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" # print(f\"X is on: {X.device}, y is on: {y.device}, mask is on: {mask.device}, i = {i}\") #checkpoint\n",
" \n",
" valid = torch.ones((X.size(0), 1, 12, 12)).to(device)\n",
" fake = torch.zeros((X.size(0), 1, 12, 12)).to(device)\n",
"\n",
" # 生成器生成图像\n",
" optimizer_G.zero_grad()\n",
" generated_images = generator(X, mask)\n",
" g_loss = adversarial_loss(discriminator(torch.cat((generated_images, X), dim=1)), valid) + 100 * masked_mse_loss(\n",
" generated_images, y, mask)\n",
" g_loss.backward()\n",
" optimizer_G.step()\n",
"\n",
" # 判别器训练\n",
" optimizer_D.zero_grad()\n",
" real_loss = adversarial_loss(discriminator(torch.cat((y, X), dim=1)), valid)\n",
" fake_loss = adversarial_loss(discriminator(torch.cat((generated_images.detach(), X), dim=1)), fake)\n",
" d_loss = 0.5 * (real_loss + fake_loss)\n",
" d_loss.backward()\n",
" optimizer_D.step()\n",
"\n",
" print(f\"Epoch [{epoch}/{epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]\")\n",
"\n",
"# 保存训练好的模型\n",
"torch.save(generator.state_dict(), './models/GAN/generator-1d.pth')\n",
"torch.save(discriminator.state_dict(), './models/GAN/discriminator-1d.pth')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9ab6849f-740a-49e4-9afd-23aafdc16725",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8997686a-5812-4a92-972c-11376dfc1686",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
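{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000004",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: sanity-check cal_ioa on toy data. A perfect prediction gives IoA = 1;\n",
"# a fully anti-correlated prediction drives it to 0.\n",
"yt = np.array([1.0, 2.0, 3.0, 4.0])\n",
"print(cal_ioa(yt, yt.copy()))                        # expected: 1.0\n",
"print(cal_ioa(yt, np.array([4.0, 3.0, 2.0, 1.0])))   # expected: 0.0"
]
},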
{
"cell_type": "code",
"execution_count": 19,
"id": "78db8f70-3cdb-444f-8bb5-49126957a0b6",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"generator = generator.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = generator(X, mask)\n",
" rev_data = torch.squeeze(y * max_pixel_value, dim=1)\n",
" rev_recon = torch.squeeze(reconstructed * max_pixel_value, dim=1)\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = rev_data * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = rev_recon * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" y_true = rev_data.flatten()\n",
" y_pred = rev_recon.flatten()\n",
" mae = mean_absolute_error(y_true, y_pred)\n",
" rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
" mape = mean_absolute_percentage_error(y_true, y_pred)\n",
" r2 = r2_score(y_true, y_pred)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" eva_list.append([mae, rmse, mape, r2, ioa])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "ffc83338-9b0f-4934-9cba-406a5e1fb0ca",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "33eb7a86-242a-4488-9d9b-93548e72c98b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" <td>75.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>1.685609</td>\n",
" <td>2.824579</td>\n",
" <td>0.223852</td>\n",
" <td>0.807483</td>\n",
" <td>0.894409</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.520285</td>\n",
" <td>0.613299</td>\n",
" <td>0.066827</td>\n",
" <td>0.107566</td>\n",
" <td>0.024969</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>1.108756</td>\n",
" <td>2.040964</td>\n",
" <td>0.143461</td>\n",
" <td>0.336193</td>\n",
" <td>0.812887</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>1.338143</td>\n",
" <td>2.462648</td>\n",
" <td>0.176170</td>\n",
" <td>0.780906</td>\n",
" <td>0.883027</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>1.509821</td>\n",
" <td>2.608227</td>\n",
" <td>0.206274</td>\n",
" <td>0.850417</td>\n",
" <td>0.900165</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.963103</td>\n",
" <td>3.067560</td>\n",
" <td>0.257667</td>\n",
" <td>0.866705</td>\n",
" <td>0.910917</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>3.729434</td>\n",
" <td>5.363288</td>\n",
" <td>0.461465</td>\n",
" <td>0.912240</td>\n",
" <td>0.935183</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa\n",
"count 75.000000 75.000000 75.000000 75.000000 75.000000\n",
"mean 1.685609 2.824579 0.223852 0.807483 0.894409\n",
"std 0.520285 0.613299 0.066827 0.107566 0.024969\n",
"min 1.108756 2.040964 0.143461 0.336193 0.812887\n",
"25% 1.338143 2.462648 0.176170 0.780906 0.883027\n",
"50% 1.509821 2.608227 0.206274 0.850417 0.900165\n",
"75% 1.963103 3.067560 0.257667 0.866705 0.910917\n",
"max 3.729434 5.363288 0.461465 0.912240 0.935183"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "3844152b-c853-4311-a0a4-f57275ea78fc",
"metadata": {},
"outputs": [],
"source": [
"rst = pd.DataFrame(eva_list_frame, columns=['mae', 'rmse', 'mape', 'r2', 'ioa', 'r']).sort_values(by='mape', ascending=True)"
]
},
{
"cell_type": "code",
"execution_count": 140,
"id": "8f98d947-5a20-49f5-84a1-5ae5a0f4693e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" <th>r</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>3523</th>\n",
" <td>1.834664</td>\n",
" <td>2.579296</td>\n",
" <td>0.103911</td>\n",
" <td>0.855056</td>\n",
" <td>0.960841</td>\n",
" <td>0.931334</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3544</th>\n",
" <td>1.500194</td>\n",
" <td>1.962885</td>\n",
" <td>0.106816</td>\n",
" <td>0.849688</td>\n",
" <td>0.960970</td>\n",
" <td>0.924187</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1952</th>\n",
" <td>1.786639</td>\n",
" <td>2.290560</td>\n",
" <td>0.109383</td>\n",
" <td>0.704122</td>\n",
" <td>0.928829</td>\n",
" <td>0.869446</td>\n",
" </tr>\n",
" <tr>\n",
" <th>602</th>\n",
" <td>2.222957</td>\n",
" <td>2.934734</td>\n",
" <td>0.112751</td>\n",
" <td>0.735178</td>\n",
" <td>0.933639</td>\n",
" <td>0.877028</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3531</th>\n",
" <td>2.093165</td>\n",
" <td>2.726698</td>\n",
" <td>0.115755</td>\n",
" <td>0.760530</td>\n",
" <td>0.937662</td>\n",
" <td>0.889606</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1114</th>\n",
" <td>1.951748</td>\n",
" <td>2.591448</td>\n",
" <td>0.116578</td>\n",
" <td>0.696970</td>\n",
" <td>0.914501</td>\n",
" <td>0.843026</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1979</th>\n",
" <td>2.083001</td>\n",
" <td>2.686231</td>\n",
" <td>0.116762</td>\n",
" <td>0.597512</td>\n",
" <td>0.886877</td>\n",
" <td>0.791842</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2568</th>\n",
" <td>2.630587</td>\n",
" <td>3.636890</td>\n",
" <td>0.117044</td>\n",
" <td>0.491952</td>\n",
" <td>0.893928</td>\n",
" <td>0.833221</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa r\n",
"3523 1.834664 2.579296 0.103911 0.855056 0.960841 0.931334\n",
"3544 1.500194 1.962885 0.106816 0.849688 0.960970 0.924187\n",
"1952 1.786639 2.290560 0.109383 0.704122 0.928829 0.869446\n",
"602 2.222957 2.934734 0.112751 0.735178 0.933639 0.877028\n",
"3531 2.093165 2.726698 0.115755 0.760530 0.937662 0.889606\n",
"1114 1.951748 2.591448 0.116578 0.696970 0.914501 0.843026\n",
"1979 2.083001 2.686231 0.116762 0.597512 0.886877 0.791842\n",
"2568 2.630587 3.636890 0.117044 0.491952 0.893928 0.833221"
]
},
"execution_count": 140,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rst.head(8)"
]
},
{
"cell_type": "code",
"execution_count": 141,
"id": "a72b4ec5-b72d-43d0-8178-954f8edcb0dd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'1114', '1952', '2568', '3523', '602'}"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"find_ex = set([x.split('-')[0].strip() for x in os.listdir('./test_img/') if 'npy' in x])\n",
"find_ex"
]
},
{
"cell_type": "code",
"execution_count": 159,
"id": "819950e1-d16d-42e9-a0f9-57878ebc8f89",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for j in find_ex:\n",
" ori = np.load(f'./test_img/{j}-real.npy')[0]\n",
" pred = np.load(f'./test_img/{j}-gan-recom.npy')[0]\n",
" mask = np.load(f'./test_img/{j}-mask.npy')\n",
" plt.imshow(ori, cmap='RdYlGn_r')\n",
" plt.gca().axis('off')\n",
" plt.savefig(f'./test_img/out_fig/{j}-truth.png', bbox_inches='tight')\n",
" plt.clf()\n",
" \n",
" plt.imshow(mask, cmap='gray')\n",
" plt.gca().axis('off')\n",
" plt.savefig(f'./test_img/out_fig/{j}-mask.png', bbox_inches='tight')\n",
" plt.clf()\n",
" \n",
" mask_cp = np.where((1-mask) == 0, np.nan, (1-mask))\n",
" plt.imshow(ori * mask_cp, cmap='RdYlGn_r')\n",
" plt.gca().axis('off')\n",
" plt.savefig(f'./test_img/out_fig/{j}-masked_ori.png', bbox_inches='tight')\n",
" plt.clf()\n",
" \n",
" out = ori * mask + pred * (1 - mask)\n",
" plt.imshow(out, cmap='RdYlGn_r')\n",
" plt.gca().axis('off')\n",
" plt.savefig(f'./test_img/out_fig/{j}-gan_out.png', bbox_inches='tight')\n",
" plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2926f4b3-0a15-48ae-bdaa-4e290e11b49d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

869
build_gan.ipynb Normal file

File diff suppressed because one or more lines are too long

277
build_ppt.ipynb Normal file

File diff suppressed because one or more lines are too long

1010
torch_GAN_1d_baseline.ipynb Normal file

File diff suppressed because one or more lines are too long

589
torch_MAE.ipynb Normal file

@@ -0,0 +1,589 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split\n",
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "adf69eb9-bedb-4db7-87c4-04c23752a7c3",
"metadata": {},
"outputs": [],
"source": [
"def load_data(pix, use_type='train'):\n",
" datasets = list()\n",
" file_list = [x for x in os.listdir(f\"./out_mat/{pix}/{use_type}/\") if x.endswith('.npy')]\n",
" for file in file_list:\n",
" file_img = np.load(f\"./out_mat/{pix}/{use_type}/{file}\")[:,:,:7]\n",
" datasets.append(file_img)\n",
" return np.asarray(datasets)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e0aa628f-37b7-498a-94d7-81241c20b305",
"metadata": {},
"outputs": [],
"source": [
"train_set = load_data(96, 'train')\n",
"val_set = load_data(96, 'valid')\n",
"test_set = load_data(96, 'test')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5d5f95cb-f40c-4ead-96fe-241068408b98",
"metadata": {},
"outputs": [],
"source": [
"def load_mask(mask_rate):\n",
" mask_files = os.listdir(f'./out_mat/96/mask/{mask_rate}')\n",
" masks = list()\n",
" for file in mask_files:\n",
" d = cv2.imread(f'./out_mat/96/mask/{mask_rate}/{file}', cv2.IMREAD_GRAYSCALE)\n",
" d = (d > 0) * 1\n",
" masks.append(d)\n",
" return np.asarray(masks)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "71452a77-8158-46b2-aecf-400ad7b72df5",
"metadata": {},
"outputs": [],
"source": [
"masks = load_mask(20)"
]
},
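{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000005",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: the loaded masks should be binary 96x96 arrays (1 = keep, 0 = mask out).\n",
"print(masks.shape)       # (num_masks, 96, 96)\n",
"print(np.unique(masks))  # expected: [0 1]"
]
},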
{
"cell_type": "code",
"execution_count": 7,
"id": "1902e0f8-32bb-4376-8238-334260b12623",
"metadata": {},
"outputs": [],
"source": [
"maxs = train_set.max(axis=0)\n",
"mins = train_set.min(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8df9f3c3-ced8-4640-af30-b2f147dbdc96",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"26749"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_set)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "53664b12-fd95-4dd0-b61d-20682f8f14f4",
"metadata": {},
"outputs": [],
"source": [
"norm_train = (train_set - mins) / (maxs-mins)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "05cb9dc8-c1df-48bf-a9dd-d084ce1d2068",
"metadata": {},
"outputs": [],
"source": [
"del train_set"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4ae39364-4cf6-49e9-b99f-6723520943b5",
"metadata": {},
"outputs": [],
"source": [
"norm_valid = (val_set - mins) / (maxs-mins)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7f78b981-d079-4000-ba9f-d862e34903b1",
"metadata": {},
"outputs": [],
"source": [
"del val_set"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f54eede6-e95a-4476-b822-79846c0b1079",
"metadata": {},
"outputs": [],
"source": [
"norm_test = (test_set - mins) / (maxs-mins)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e66887eb-df5e-46d3-b9c5-73af1272b27a",
"metadata": {},
"outputs": [],
"source": [
"del test_set"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "00afa8cd-18b4-4d71-8cab-fd140058dca3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(26749, 96, 96)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"norm_train.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31d91072-3878-4e3c-b6f1-09f597faf60d",
"metadata": {},
"outputs": [],
"source": [
"trans_train = np.transpose(norm_train, (0, 3, 1, 2))\n",
"trans_val = np.transpose(norm_valid, (0, 3, 1, 2))\n",
"trans_test = np.transpose(norm_test, (0, 3, 1, 2))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(input_feature[0].cpu().numpy())\n",
" plt.title(title + \" Input\")\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(masked_feature[0].cpu().numpy())\n",
" plt.title(title + \" Masked\")\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(output_feature[0].detach().cpu().numpy())\n",
" plt.title(title + \" Recovery\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aeda3567-4c4d-496b-9570-9ae757b45e72",
"metadata": {},
"outputs": [],
"source": [
"# 设置随机种子以确保结果的可重复性\n",
"torch.manual_seed(0)\n",
"np.random.seed(0)\n",
"\n",
"# 数据准备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)\n",
"# 将numpy数组转换为PyTorch张量\n",
"tensor_train = torch.tensor(trans_train.astype(np.float32), device=device)\n",
"tensor_valid = torch.tensor(trans_val.astype(np.float32), device=device)\n",
"tensor_test = torch.tensor(trans_test.astype(np.float32), device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1569baeb-5a9e-48c1-a735-82d0cba8ad29",
"metadata": {},
"outputs": [],
"source": [
"# 创建一个数据集和数据加载器\n",
"train_set = TensorDataset(tensor_train, tensor_train) # 输出和标签相同,因为我们是自编码器\n",
"val_set = TensorDataset(tensor_valid, tensor_valid)\n",
"test_set = TensorDataset(tensor_test, tensor_test)\n",
"batch_size = 64\n",
"train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
"val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)\n",
"test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c81785d-f0e6-486f-8aad-dba81d2ec146",
"metadata": {},
"outputs": [],
"source": [
"def mask_data(data, device, masks):\n",
" mask_inds = np.random.choice(masks.shape[0], data.shape[0])\n",
" mask = torch.from_numpy(masks[mask_inds]).to(device)\n",
" tmp_first_channel = data[:, 0, :, :] * mask\n",
" masked_data = torch.clone(data)\n",
" masked_data[:, 0, :, :] = tmp_first_channel\n",
" return masked_data"
]
},
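{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000006",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: mask_data zeroes randomly chosen mask patterns into channel 0 only;\n",
"# the other six channels must pass through untouched.\n",
"demo = torch.ones(4, 7, 96, 96, device=device)\n",
"demo_masked = mask_data(demo, device, masks)\n",
"print((demo_masked[:, 0] == 0).float().mean())       # fraction of NO2 pixels masked out\n",
"print(torch.equal(demo_masked[:, 1:], demo[:, 1:]))  # expected: True"
]
},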
{
"cell_type": "code",
"execution_count": null,
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
" def __init__(self, in_channels, reduced_dim):\n",
" super(SEBlock, self).__init__()\n",
" self.se = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return x * self.se(x)"
]
},
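{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000007",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: SEBlock is squeeze-and-excitation channel attention; it rescales each\n",
"# channel by a learned weight in (0, 1) and never changes the tensor shape.\n",
"se = SEBlock(128, 16)\n",
"feat = torch.randn(2, 128, 12, 12)\n",
"print(se(feat).shape)  # expected: torch.Size([2, 128, 12, 12])"
]
},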
{
"cell_type": "code",
"execution_count": null,
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 定义Masked Autoencoder模型\n",
"class MaskedAutoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(MaskedAutoencoder, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(7, 32, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" SEBlock(128, 128)\n",
" )\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(16, 7, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为输入数据是0-1之间的\n",
" )\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded\n",
"\n",
"# 实例化模型、损失函数和优化器\n",
"model = MaskedAutoencoder()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
]
},
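{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000008",
"metadata": {},
"outputs": [],
"source": [
"# Shape smoke test (sketch): the three stride-2 convs take 96 -> 48 -> 24 -> 12\n",
"# and the three transposed convs mirror the path back, so output equals input shape.\n",
"with torch.no_grad():\n",
"    probe = torch.randn(2, 7, 96, 96)\n",
"    print(model(probe).shape)  # expected: torch.Size([2, 7, 96, 96])"
]
},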
{
"cell_type": "code",
"execution_count": null,
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
"metadata": {},
"outputs": [],
"source": [
"# 训练函数\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
" model.train()\n",
" running_loss = 0.0\n",
" for batch_idx, (data, _) in enumerate(data_loader):\n",
" masked_data = mask_data(data, device, masks)\n",
" optimizer.zero_grad()\n",
" reconstructed = model(masked_data)\n",
" loss = criterion(reconstructed, data)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
"metadata": {},
"outputs": [],
"source": [
"# 评估函数\n",
"def evaluate(model, device, data_loader, criterion):\n",
" model.eval()\n",
" running_loss = 0.0\n",
" with torch.no_grad():\n",
" for batch_idx, (data, _) in enumerate(data_loader):\n",
" data = data.to(device)\n",
" masked_data = mask_data(data, device, masks)\n",
" reconstructed = model(masked_data)\n",
" if batch_idx == 8:\n",
" rand_ind = np.random.randint(0, len(data))\n",
" visualize_feature(data[rand_ind], masked_data[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
" loss = criterion(reconstructed, data)\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1847c78-cbc6-4560-bb49-4dc6e9b8bbd0",
"metadata": {},
"outputs": [],
"source": [
"# 测试函数\n",
"def test(model, device, data_loader):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for batch_idx, (data, _) in enumerate(data_loader):\n",
" data = data.to(device)\n",
" masked_data = mask_data(data, device, masks)\n",
" masked_ind = np.argwhere(masked_data[0][0]==0)\n",
" reconstructed = model(masked_data)\n",
" recon_no2 = reconstructed[0][0]\n",
" ori_no2 = data[0][0]\n",
" return"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "743d1000-561e-4444-8b49-88346c14f28b",
"metadata": {},
"outputs": [],
"source": [
"model = model.to(device)\n",
"\n",
"num_epochs = 100\n",
"train_losses = list()\n",
"val_losses = list()\n",
"for epoch in range(num_epochs):\n",
" train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
" train_losses.append(train_loss)\n",
" val_loss = evaluate(model, device, val_loader, criterion)\n",
" val_losses.append(val_loss)\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# 测试模型\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
"metadata": {},
"outputs": [],
"source": [
"tr_ind = list(range(len(train_losses)))\n",
"val_ind = list(range(len(val_losses)))\n",
"plt.plot(train_losses, label='train_loss')\n",
"plt.plot(val_losses, label='val_loss')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cff8cba9-aba9-4347-8e1a-f169df8313c2",
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad():\n",
" device = 'cpu'\n",
" for batch_idx, (data, _) in enumerate(test_loader):\n",
" model = model.to(device)\n",
" data = data.to(device)\n",
" masked_data = mask_data(data, device, masks)\n",
" reconstructed = model(masked_data)\n",
" tr_maxs = np.transpose(maxs, (2, 0, 1))\n",
" tr_mins = np.transpose(mins, (2, 0, 1))\n",
" rev_data = data * (tr_maxs - tr_mins) + tr_mins\n",
" rev_recon = reconstructed * (tr_maxs - tr_mins) + tr_mins\n",
" data_label = ((rev_data!=0) * (masked_data==0) * rev_data)[:, 0]\n",
" recon_no2 = ((rev_data!=0) * (masked_data==0) * rev_recon)[:, 0]\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "565f2e0a-1689-4a03-9fc1-15519b1cdaee",
"metadata": {},
"outputs": [],
"source": [
"real = data_label.flatten()\n",
"pred = recon_no2.flatten()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1e8f71e-855a-41ea-b62f-095514af66a3",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a6eea29-cd3e-4712-ad73-589bcf7b88be",
"metadata": {},
"outputs": [],
"source": [
"mean_squared_error(real, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a85feefb-aa3a-4bb9-86ac-7cc6938a47e8",
"metadata": {},
"outputs": [],
"source": [
"mean_absolute_percentage_error(real, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2bfda87-3de8-4a06-969f-d346f4447cf6",
"metadata": {},
"outputs": [],
"source": [
"r2_score(real, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1955fc0-490a-40d5-8b3c-dd6e5beed235",
"metadata": {},
"outputs": [],
"source": [
"mean_absolute_error(real, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acee2abc-2f3f-4d19-a6e4-85ad4d1aaacf",
"metadata": {},
"outputs": [],
"source": [
"visualize_feature(data[5], masked_data[5], reconstructed[5], 'NO2')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"62"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len('The total $R^2$ for under 40\\% missing data test set was 0.88.')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "215938c7-d514-48e7-a460-088dcd7927ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

1276
torch_MAE_1d_ViT-Copy1.ipynb Normal file

File diff suppressed because it is too large

608
torch_MAE_1d_ViT.ipynb Normal file

@@ -0,0 +1,608 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "fa295d87-946f-402b-9d97-1127ee9a33a0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset, random_split\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"from PIL import Image\n",
"\n",
"MAX_VALUE = 107.49169921875"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c6dd8e35-02e3-491c-b4be-a874cf1054ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2f151caf-43d1-4d59-a111-96ad5e6bc38b",
"metadata": {},
"outputs": [],
"source": [
"class GrayScaleDataset(Dataset):\n",
" def __init__(self, data_dir):\n",
" self.data_dir = data_dir\n",
" self.file_list = [x for x in os.listdir(data_dir) if x.endswith('npy')]\n",
"\n",
" def __len__(self):\n",
" return len(self.file_list)\n",
"\n",
" def __getitem__(self, idx):\n",
" file_path = os.path.join(self.data_dir, self.file_list[idx])\n",
" data = np.load(file_path)[:,:,0] / MAX_VALUE\n",
" return torch.tensor(data, dtype=torch.float32).unsqueeze(0)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3ecd7bd0-15a0-4420-95e1-066e4d023cd3",
"metadata": {},
"outputs": [],
"source": [
"class NO2Dataset(Dataset):\n",
" \n",
" def __init__(self, image_dir, mask_dir):\n",
" \n",
" self.image_dir = image_dir\n",
" self.mask_dir = mask_dir\n",
" self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.npy')] # 仅加载 .npy 文件\n",
" self.mask_filenames = [f for f in os.listdir(mask_dir) if f.endswith('.jpg')] # 仅加载 .jpg 文件\n",
" \n",
" def __len__(self):\n",
" \n",
" return len(self.image_filenames)\n",
" \n",
" def __getitem__(self, idx):\n",
" \n",
" image_path = os.path.join(self.image_dir, self.image_filenames[idx])\n",
" mask_idx = idx % len(self.mask_filenames)\n",
" mask_path = os.path.join(self.mask_dir, self.mask_filenames[mask_idx])\n",
"\n",
" # 加载图像数据 (.npy 文件)\n",
" image = np.load(image_path).astype(np.float32)[:,:,:1] / MAX_VALUE # 形状为 (96, 96, 1)\n",
"\n",
" # 加载掩码数据 (.jpg 文件)\n",
" mask = np.array(Image.open(mask_path).convert('L')).astype(np.float32)\n",
"\n",
" # 将掩码数据中非0值设为10值保持不变\n",
" mask = np.where(mask != 0, 1.0, 0.0)\n",
"\n",
" # 保持掩码数据形状为 (96, 96, 1)\n",
" mask = mask[:, :, np.newaxis] # 将形状调整为 (96, 96, 1)\n",
"\n",
" # 应用掩码\n",
" masked_image = image.copy()\n",
" masked_image[:, :, 0] = image[:, :, 0] * mask.squeeze() # 遮盖NO2数据\n",
"\n",
" # cGAN的输入和目标\n",
" X = masked_image[:, :, :1] # 形状为 (96, 96, 8)\n",
" y = image[:, :, 0:1] # 目标输出为NO2数据形状为 (96, 96, 1)\n",
"\n",
" # 转换形状为 (channels, height, width)\n",
" X = np.transpose(X, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" y = np.transpose(y, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
" mask = np.transpose(mask, (2, 0, 1)) # 转换为 (1, 96, 96)\n",
"\n",
" return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "36752a6d-329a-464d-a329-f02206bf63b0",
"metadata": {},
"outputs": [],
"source": [
"class PatchMasking:\n",
" def __init__(self, patch_size, mask_ratio):\n",
" self.patch_size = patch_size\n",
" self.mask_ratio = mask_ratio\n",
"\n",
" def __call__(self, x):\n",
" batch_size, C, H, W = x.shape\n",
" num_patches = (H // self.patch_size) * (W // self.patch_size)\n",
" num_masked = int(num_patches * self.mask_ratio)\n",
" \n",
" # 为每个样本生成独立的mask\n",
" masks = []\n",
" for _ in range(batch_size):\n",
" mask = torch.zeros(num_patches, dtype=torch.bool, device=x.device)\n",
" mask[:num_masked] = 1\n",
" mask = mask[torch.randperm(num_patches)]\n",
" mask = mask.view(H // self.patch_size, W // self.patch_size)\n",
" mask = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(self.patch_size, dim=1)\n",
" masks.append(mask)\n",
" \n",
" # 将所有mask堆叠成一个批量张量\n",
" masks = torch.stack(masks, dim=0)\n",
" masks = torch.unsqueeze(masks, dim=1)\n",
" \n",
" # 应用mask到输入x上\n",
" masked_x = x * (1 - masks.float())\n",
" return masked_x, masks"
]
},
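{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000009",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: apply PatchMasking to a dummy batch. With patch_size=8 on 96x96 inputs\n",
"# there are 144 patches, so mask_ratio=0.2 masks int(144 * 0.2) = 28 patches exactly.\n",
"pm = PatchMasking(patch_size=8, mask_ratio=0.2)\n",
"dummy = torch.ones(2, 1, 96, 96)\n",
"dummy_masked, dummy_masks = pm(dummy)\n",
"print(dummy_masks.float().mean().item())      # expected: 28/144 ~= 0.194\n",
"print(dummy_masked.shape, dummy_masks.shape)  # (2, 1, 96, 96) for both"
]
},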
{
"cell_type": "code",
"execution_count": 12,
"id": "0db0d920-8de2-4bad-9b99-67eed152644d",
"metadata": {},
"outputs": [],
"source": [
"class Mlp(nn.Module):\n",
" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
" super().__init__()\n",
" out_features = out_features or in_features\n",
" hidden_features = hidden_features or in_features\n",
" self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
"\n",
" self.act = act_layer()\n",
" self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
" self.drop = nn.Dropout(drop, inplace=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = self.act(x)\n",
" x = self.drop(x)\n",
" x = self.fc2(x)\n",
" x = self.drop(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cb27d3a7-77ed-4110-96bd-bcc4880964d2",
"metadata": {},
"outputs": [],
"source": [
"class ViTEncoder(nn.Module):\n",
" def __init__(self, img_size=96, patch_size=8, dim=128, depth=4, heads=4, mlp_dim=256):\n",
" super(ViTEncoder, self).__init__()\n",
" self.patch_size = patch_size\n",
" self.dim = dim\n",
" self.patch_embedding = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)\n",
" \n",
" # 定义 Transformer 编码器层\n",
" encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim)\n",
" self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)\n",
"\n",
" def forward(self, x):\n",
" x = self.patch_embedding(x)\n",
" x = x.flatten(2).transpose(1, 2) # 形状变为 (batch_size, num_patches, dim)\n",
" x = self.transformer_encoder(x)\n",
" return x\n",
"\n",
"class ConvDecoder(nn.Module):\n",
" def __init__(self, dim=128, patch_size=8, img_size=96):\n",
" super(ConvDecoder, self).__init__()\n",
" self.dim = dim\n",
" self.patch_size = patch_size\n",
" self.img_size = img_size\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(dim, 128, kernel_size=patch_size, stride=patch_size),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(128, 1, kernel_size=3, stride=1, padding=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" # x = x.transpose(1, 2).view(-1, self.dim, self.img_size // self.patch_size, self.img_size // self.patch_size)\n",
" x = self.decoder(x)\n",
" return x\n",
"\n",
"class MAEModel(nn.Module):\n",
" def __init__(self, encoder, decoder):\n",
" super(MAEModel, self).__init__()\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "783e62af-7f6a-40bd-a423-be63fe98a655",
"metadata": {},
"outputs": [],
"source": [
"def masked_mse_loss(preds, target, mask):\n",
" loss = (preds - target) ** 2\n",
" loss = loss.mean(dim=-1) # 对每个像素点求平均\n",
" loss = (loss * mask).sum() / mask.sum() # 只计算被mask的像素点的损失\n",
" return loss"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "baeffdf0-cdc2-44c4-972a-e2e671635d6a",
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):\n",
" model.to(device)\n",
" for epoch in range(epochs):\n",
" model.train()\n",
" train_loss = 0\n",
" for data in train_loader:\n",
" data = data.to(device)\n",
" optimizer.zero_grad()\n",
" masked_data, mask = PatchMasking(patch_size=16, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" loss = masked_mse_loss(output, data, mask)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss.item()\n",
" train_loss /= len(train_loader)\n",
"\n",
" model.eval()\n",
" val_loss = 0\n",
" with torch.no_grad():\n",
" for data in val_loader:\n",
" data = data.to(device)\n",
" masked_data, mask = PatchMasking(patch_size=16, mask_ratio=0.2)(data)\n",
" output = model(masked_data)\n",
" loss = masked_mse_loss(output, data, mask)\n",
" val_loss += loss.item()\n",
" val_loss /= len(val_loader)\n",
"\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "bb524f86-aa7d-44ee-b13e-b9ba4e5b3a0b",
"metadata": {},
"outputs": [],
"source": [
"train_dir = './out_mat/96/train/'\n",
"train_dataset = GrayScaleDataset(train_dir)\n",
"\n",
"val_dir = './out_mat/96/valid/'\n",
"val_dataset = GrayScaleDataset(val_dir)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
"val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "7d6d07a4-31f1-4350-a487-b583db979381",
"metadata": {},
"outputs": [],
"source": [
"encoder = ViTEncoder()\n",
"decoder = ConvDecoder()\n",
"model = MAEModel(encoder, decoder)\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
]
},
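{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee00-aaaa-4bbb-8ccc-000000000010",
"metadata": {},
"outputs": [],
"source": [
"# Shape smoke test (sketch; assumes the reshape in ConvDecoder.forward above):\n",
"# 96/8 = 12, so the encoder yields 144 patch tokens of dim 128, and the decoder\n",
"# folds them back into a 12x12 map before upsampling to 96x96.\n",
"with torch.no_grad():\n",
"    toks = encoder(torch.randn(2, 1, 96, 96))\n",
"    print(toks.shape)            # expected: torch.Size([2, 144, 128])\n",
"    print(decoder(toks).shape)   # expected: torch.Size([2, 1, 96, 96])"
]
},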
{
"cell_type": "code",
"execution_count": 19,
"id": "8ee33651-f5f0-4b92-96e9-a84e32725b44",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([128, 128, 6, 6])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.transpose(1, 2).reshape(-1, 128, 6, 6).shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a5684758-bc6d-45b0-b885-da37820ca5ac",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Given groups=1, weight of size [256, 128, 1, 1], expected input[1, 32, 144, 128] to have 128 channels, but got 32 channels instead",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[15], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m train_loader:\n\u001b[1;32m 2\u001b[0m a \u001b[38;5;241m=\u001b[39m encoder(i)\n\u001b[0;32m----> 3\u001b[0m b \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m c \u001b[38;5;241m=\u001b[39m decoder(b)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"Cell \u001b[0;32mIn[12], line 13\u001b[0m, in \u001b[0;36mMlp.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 13\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfc1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact(x)\n\u001b[1;32m 15\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdrop(x)\n",
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:460\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 460\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/conv.py:456\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv2d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 454\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 455\u001b[0m _pair(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 456\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 457\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Given groups=1, weight of size [256, 128, 1, 1], expected input[1, 32, 144, 128] to have 128 channels, but got 32 channels instead"
]
}
],
"source": [
"for i in train_loader:\n",
" a = encoder(i)\n",
" b = model.mlp(a)\n",
" c = decoder(b)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09b04e16-3257-4890-b736-a6c7274561e0",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"train_model(model, train_loader, val_loader, epochs=100, criterion=criterion, optimizer=optimizer, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b0c5cf4b-aca2-4781-8b47-bf2a46269635",
"metadata": {},
"outputs": [],
"source": [
"test_set = NO2Dataset('./out_mat/96/test/', './out_mat/96/mask/20/')\n",
"test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "56653f37-899a-47d6-8d50-e456b4ad1835",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "f1ecbd05-7aa3-43ae-8bc2-aa44d19689b9",
"metadata": {},
"outputs": [],
"source": [
"def cal_ioa(y_true, y_pred):\n",
" # 计算平均值\n",
" mean_observed = np.mean(y_true)\n",
" mean_predicted = np.mean(y_pred)\n",
"\n",
" # 计算IoA\n",
" numerator = np.sum((y_true - y_pred) ** 2)\n",
" denominator = 2 * np.sum((np.abs(y_true - mean_observed) + np.abs(y_pred - mean_predicted)) ** 2)\n",
" IoA = 1 - (numerator / denominator)\n",
"\n",
" return IoA"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "e840b789-bf68-4b4d-a8d3-c5362c310349",
"metadata": {},
"outputs": [],
"source": [
"eva_list = list()\n",
"device = 'cpu'\n",
"model = model.to(device)\n",
"with torch.no_grad():\n",
" for batch_idx, (X, y, mask) in enumerate(test_loader):\n",
" X, y, mask = X.to(device), y.to(device), mask.to(device)\n",
" mask_rev = (torch.squeeze(mask, dim=1)==0) * 1 # mask取反获得修复区域\n",
" reconstructed = model(X)\n",
" rev_data = y * MAX_VALUE\n",
" rev_recon = reconstructed * MAX_VALUE\n",
" # todo: 这里需要只评估修补出来的模块\n",
" data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
" data_label = data_label[mask_rev==1]\n",
" recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
" recon_no2 = recon_no2[mask_rev==1]\n",
" mae = mean_absolute_error(data_label, recon_no2)\n",
" rmse = np.sqrt(mean_squared_error(data_label, recon_no2))\n",
" mape = mean_absolute_percentage_error(data_label, recon_no2)\n",
" r2 = r2_score(data_label, recon_no2)\n",
" ioa = cal_ioa(data_label.detach().numpy(), recon_no2.detach().numpy())\n",
" eva_list.append([mae, rmse, mape, r2, ioa])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "41fa754d-1eee-43a2-9e39-a0254719be30",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mae</th>\n",
" <th>rmse</th>\n",
" <th>mape</th>\n",
" <th>r2</th>\n",
" <th>ioa</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>149.000000</td>\n",
" <td>149.000000</td>\n",
" <td>149.000000</td>\n",
" <td>149.000000</td>\n",
" <td>149.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>7.068207</td>\n",
" <td>9.016465</td>\n",
" <td>0.814727</td>\n",
" <td>-0.952793</td>\n",
" <td>0.564749</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.659118</td>\n",
" <td>0.774556</td>\n",
" <td>0.054147</td>\n",
" <td>0.162851</td>\n",
" <td>0.033048</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>5.609327</td>\n",
" <td>7.113544</td>\n",
" <td>0.599120</td>\n",
" <td>-1.402735</td>\n",
" <td>0.461420</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>6.613351</td>\n",
" <td>8.499699</td>\n",
" <td>0.782008</td>\n",
" <td>-1.049951</td>\n",
" <td>0.544980</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>7.086443</td>\n",
" <td>9.045812</td>\n",
" <td>0.811261</td>\n",
" <td>-0.938765</td>\n",
" <td>0.567080</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>7.495309</td>\n",
" <td>9.530408</td>\n",
" <td>0.848900</td>\n",
" <td>-0.849266</td>\n",
" <td>0.586134</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>8.663801</td>\n",
" <td>10.995004</td>\n",
" <td>0.984343</td>\n",
" <td>-0.591799</td>\n",
" <td>0.630479</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mae rmse mape r2 ioa\n",
"count 149.000000 149.000000 149.000000 149.000000 149.000000\n",
"mean 7.068207 9.016465 0.814727 -0.952793 0.564749\n",
"std 0.659118 0.774556 0.054147 0.162851 0.033048\n",
"min 5.609327 7.113544 0.599120 -1.402735 0.461420\n",
"25% 6.613351 8.499699 0.782008 -1.049951 0.544980\n",
"50% 7.086443 9.045812 0.811261 -0.938765 0.567080\n",
"75% 7.495309 9.530408 0.848900 -0.849266 0.586134\n",
"max 8.663801 10.995004 0.984343 -0.591799 0.630479"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(eva_list, columns=['mae', 'rmse', 'mape', 'r2', 'ioa']).describe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b15bbdc-cb87-4648-b22f-72917b8c1e6b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

895
torch_MAE_1d_baseline.ipynb Normal file

File diff suppressed because one or more lines are too long

1048
torch_MAE_1d_decoder.ipynb Normal file

File diff suppressed because one or more lines are too long

982
torch_MAE_1d_encoder.ipynb Normal file

File diff suppressed because one or more lines are too long

1068
torch_MAE_1d_final.ipynb Normal file

File diff suppressed because one or more lines are too long

1093
torch_MAE_1d_final_10.ipynb Normal file

File diff suppressed because one or more lines are too long

1297
torch_MAE_1d_final_20.ipynb Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

943
torch_MAE_1d_final_30.ipynb Normal file

File diff suppressed because one or more lines are too long

957
torch_MAE_1d_final_40.ipynb Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

425
未命名1.ipynb Normal file

File diff suppressed because one or more lines are too long

1064
论文绘图.ipynb Normal file

File diff suppressed because one or more lines are too long