{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "adf69eb9-bedb-4db7-87c4-04c23752a7c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(pix, use_type='train'):\n",
    "    \"\"\"Load every .npy tile of one split, keeping the first 7 channels.\n",
    "\n",
    "    Args:\n",
    "        pix: tile size / directory name under ./out_mat/.\n",
    "        use_type: split name ('train', 'valid' or 'test').\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray stacking all tiles, shape (n_files, H, W, 7).\n",
    "    \"\"\"\n",
    "    datasets = list()\n",
    "    # sorted() makes sample order deterministic across filesystems\n",
    "    file_list = sorted(x for x in os.listdir(f\"./out_mat/{pix}/{use_type}/\") if x.endswith('.npy'))\n",
    "    for file in file_list:\n",
    "        file_img = np.load(f\"./out_mat/{pix}/{use_type}/{file}\")[:, :, :7]\n",
    "        datasets.append(file_img)\n",
    "    return np.asarray(datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e0aa628f-37b7-498a-94d7-81241c20b305",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set = load_data(96, 'train')\n",
    "val_set = load_data(96, 'valid')\n",
    "test_set = load_data(96, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5d5f95cb-f40c-4ead-96fe-241068408b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_mask(mask_rate):\n",
    "    \"\"\"Load the binary masks for one missing-data rate.\n",
    "\n",
    "    Reads every grayscale image under ./out_mat/96/mask/{mask_rate}/ and\n",
    "    binarizes it: any positive pixel -> 1 (kept), 0 -> masked out.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of shape (n_masks, H, W) with values in {0, 1}.\n",
    "    \"\"\"\n",
    "    # sorted() makes mask order deterministic across filesystems\n",
    "    mask_files = sorted(os.listdir(f'./out_mat/96/mask/{mask_rate}'))\n",
    "    masks = list()\n",
    "    for file in mask_files:\n",
    "        d = cv2.imread(f'./out_mat/96/mask/{mask_rate}/{file}', cv2.IMREAD_GRAYSCALE)\n",
    "        d = (d > 0) * 1\n",
    "        masks.append(d)\n",
    "    return np.asarray(masks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "71452a77-8158-46b2-aecf-400ad7b72df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks = load_mask(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1902e0f8-32bb-4376-8238-334260b12623",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-pixel, per-channel extrema over the training samples, shape (H, W, C);\n",
    "# used for min-max normalization of all three splits below.\n",
    "maxs = train_set.max(axis=0)\n",
    "mins = train_set.min(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8df9f3c3-ced8-4640-af30-b2f147dbdc96",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "26749"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "53664b12-fd95-4dd0-b61d-20682f8f14f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): where maxs == mins this divides by zero and yields NaN/inf;\n",
    "# assumed not to occur in this data — confirm.\n",
    "norm_train = (train_set - mins) / (maxs-mins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "05cb9dc8-c1df-48bf-a9dd-d084ce1d2068",
   "metadata": {},
   "outputs": [],
   "source": [
    "# free the raw array; kernel memory persists across cells\n",
    "del train_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4ae39364-4cf6-49e9-b99f-6723520943b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "norm_valid = (val_set - mins) / (maxs-mins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7f78b981-d079-4000-ba9f-d862e34903b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "del val_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f54eede6-e95a-4476-b822-79846c0b1079",
   "metadata": {},
   "outputs": [],
   "source": [
    "norm_test = (test_set - mins) / (maxs-mins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e66887eb-df5e-46d3-b9c5-73af1272b27a",
   "metadata": {},
   "outputs": [],
   "source": [
    "del test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "00afa8cd-18b4-4d71-8cab-fd140058dca3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(26749, 96, 96)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "norm_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31d91072-3878-4e3c-b6f1-09f597faf60d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (N, H, W, C) -> (N, C, H, W), the channel layout PyTorch conv layers expect\n",
    "trans_train = np.transpose(norm_train, (0, 3, 1, 2))\n",
    "trans_val = np.transpose(norm_valid, (0, 3, 1, 2))\n",
    "trans_test = np.transpose(norm_test, (0, 3, 1, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "70797703-1619-4be7-b965-5506b3d1e775",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize channel 0 of one sample: input vs. masked input vs. reconstruction\n",
    "def visualize_feature(input_feature, masked_feature, output_feature, title):\n",
    "    \"\"\"Plot a 1x3 panel of input / masked / reconstructed first channel.\"\"\"\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    plt.subplot(1, 3, 1)\n",
    "    plt.imshow(input_feature[0].cpu().numpy())\n",
    "    plt.title(title + \" Input\")\n",
    "    plt.subplot(1, 3, 2)\n",
    "    plt.imshow(masked_feature[0].cpu().numpy())\n",
    "    plt.title(title + \" Masked\")\n",
    "    plt.subplot(1, 3, 3)\n",
    "    # detach(): the reconstruction may still carry gradients during training\n",
    "    plt.imshow(output_feature[0].detach().cpu().numpy())\n",
    "    plt.title(title + \" Recovery\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeda3567-4c4d-496b-9570-9ae757b45e72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set random seeds for reproducibility\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "# Data preparation\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "# Convert numpy arrays to PyTorch tensors (kept resident on `device`)\n",
    "tensor_train = torch.tensor(trans_train.astype(np.float32), device=device)\n",
    "tensor_valid = torch.tensor(trans_val.astype(np.float32), device=device)\n",
    "tensor_test = torch.tensor(trans_test.astype(np.float32), device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1569baeb-5a9e-48c1-a735-82d0cba8ad29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create datasets and data loaders\n",
    "train_set = TensorDataset(tensor_train, tensor_train)  # targets equal inputs (autoencoder)\n",
    "val_set = TensorDataset(tensor_valid, tensor_valid)\n",
    "test_set = TensorDataset(tensor_test, tensor_test)\n",
    "batch_size = 64\n",
    "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)\n",
    "test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c81785d-f0e6-486f-8aad-dba81d2ec146",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_data(data, device, masks):\n",
    "    \"\"\"Zero out random mask patterns on channel 0 of each sample.\n",
    "\n",
    "    Args:\n",
    "        data: (B, C, H, W) tensor; channel 0 is the target (NO2) channel.\n",
    "        device: device the mask tensor is moved to (must match `data`).\n",
    "        masks: (N, H, W) numpy array of {0, 1} masks.\n",
    "\n",
    "    Returns:\n",
    "        A copy of `data` whose channel 0 is multiplied by one randomly\n",
    "        drawn mask per sample; other channels are untouched.\n",
    "    \"\"\"\n",
    "    # draw one mask index per sample, with replacement\n",
    "    mask_inds = np.random.choice(masks.shape[0], data.shape[0])\n",
    "    mask = torch.from_numpy(masks[mask_inds]).to(device)\n",
    "    tmp_first_channel = data[:, 0, :, :] * mask\n",
    "    masked_data = torch.clone(data)\n",
    "    masked_data[:, 0, :, :] = tmp_first_channel\n",
    "    return masked_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SEBlock(nn.Module):\n",
    "    \"\"\"Squeeze-and-Excitation block: channel-wise attention reweighting.\"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, reduced_dim):\n",
    "        super(SEBlock, self).__init__()\n",
    "        self.se = nn.Sequential(\n",
    "            nn.AdaptiveAvgPool2d(1),  # global average pooling -> (B, C, 1, 1)\n",
    "            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
    "            nn.Sigmoid()  # Sigmoid yields per-channel weights in [0, 1]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x * self.se(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Masked autoencoder model\n",
    "class MaskedAutoencoder(nn.Module):\n",
    "    \"\"\"Conv encoder (7->32->64->128, stride 2) + SE block + deconv decoder back to 7 channels.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MaskedAutoencoder, self).__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(7, 32, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            SEBlock(128, 128)\n",
    "        )\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(16, 7, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.Sigmoid()  # Sigmoid because the inputs are normalized to [0, 1]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded\n",
    "\n",
    "# Instantiate model, loss and optimizer\n",
    "model = MaskedAutoencoder()\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training function\n",
    "def train_epoch(model, device, data_loader, criterion, optimizer):\n",
    "    \"\"\"Run one training epoch and return the mean batch loss.\"\"\"\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for data, _ in data_loader:\n",
    "        data = data.to(device)\n",
    "        masked_data = mask_data(data, device, masks)\n",
    "        optimizer.zero_grad()\n",
    "        reconstructed = model(masked_data)\n",
    "        # loss against the full unmasked input, so the model learns to inpaint\n",
    "        loss = criterion(reconstructed, data)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "    # len(data_loader) == number of batches; max() guards an empty loader\n",
    "    return running_loss / max(len(data_loader), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation function\n",
    "def evaluate(model, device, data_loader, criterion):\n",
    "    \"\"\"Return the mean masked-reconstruction loss; also plots one sample.\"\"\"\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, _) in enumerate(data_loader):\n",
    "            data = data.to(device)\n",
    "            masked_data = mask_data(data, device, masks)\n",
    "            reconstructed = model(masked_data)\n",
    "            if batch_idx == 8:\n",
    "                # visualize one random sample from a fixed batch each call\n",
    "                rand_ind = np.random.randint(0, len(data))\n",
    "                visualize_feature(data[rand_ind], masked_data[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
    "            loss = criterion(reconstructed, data)\n",
    "            running_loss += loss.item()\n",
    "    return running_loss / max(len(data_loader), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1847c78-cbc6-4560-bb49-4dc6e9b8bbd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test function\n",
    "def test(model, device, data_loader):\n",
    "    \"\"\"Reconstruct the first test batch (exploration helper; returns nothing).\"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, _) in enumerate(data_loader):\n",
    "            data = data.to(device)\n",
    "            masked_data = mask_data(data, device, masks)\n",
    "            # np.argwhere cannot consume a CUDA tensor; move to CPU first\n",
    "            masked_ind = np.argwhere(masked_data[0][0].cpu().numpy() == 0)\n",
    "            reconstructed = model(masked_data)\n",
    "            recon_no2 = reconstructed[0][0]\n",
    "            ori_no2 = data[0][0]\n",
    "            return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "743d1000-561e-4444-8b49-88346c14f28b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "\n",
    "num_epochs = 100\n",
    "train_losses = list()\n",
    "val_losses = list()\n",
    "for epoch in range(num_epochs):\n",
    "    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
    "    train_losses.append(train_loss)\n",
    "    val_loss = evaluate(model, device, val_loader, criterion)\n",
    "    val_losses.append(val_loss)\n",
    "    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
    "\n",
    "# Evaluate on the test set\n",
    "test_loss = evaluate(model, device, test_loader, criterion)\n",
    "print(f'Test Loss: {test_loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_ind = list(range(len(train_losses)))\n",
    "val_ind = list(range(len(val_losses)))\n",
    "plt.plot(train_losses, label='train_loss')\n",
    "plt.plot(val_losses, label='val_loss')\n",
    "plt.legend(loc='best')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cff8cba9-aba9-4347-8e1a-f169df8313c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Denormalize one test batch and keep only the masked-out NO2 pixels\n",
    "with torch.no_grad():\n",
    "    device = 'cpu'\n",
    "    for batch_idx, (data, _) in enumerate(test_loader):\n",
    "        model = model.to(device)\n",
    "        data = data.to(device)\n",
    "        masked_data = mask_data(data, device, masks)\n",
    "        reconstructed = model(masked_data)\n",
    "        # bring the (H, W, C) extrema into (C, H, W) to match the tensors\n",
    "        tr_maxs = np.transpose(maxs, (2, 0, 1))\n",
    "        tr_mins = np.transpose(mins, (2, 0, 1))\n",
    "        rev_data = data * (tr_maxs - tr_mins) + tr_mins\n",
    "        rev_recon = reconstructed * (tr_maxs - tr_mins) + tr_mins\n",
    "        # keep only pixels that were masked out but had real (non-zero) data\n",
    "        data_label = ((rev_data!=0) * (masked_data==0) * rev_data)[:, 0]\n",
    "        recon_no2 = ((rev_data!=0) * (masked_data==0) * rev_recon)[:, 0]\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "565f2e0a-1689-4a03-9fc1-15519b1cdaee",
   "metadata": {},
   "outputs": [],
   "source": [
    "real = data_label.flatten()\n",
    "pred = recon_no2.flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1e8f71e-855a-41ea-b62f-095514af66a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a6eea29-cd3e-4712-ad73-589bcf7b88be",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_squared_error(real, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a85feefb-aa3a-4bb9-86ac-7cc6938a47e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_absolute_percentage_error(real, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2bfda87-3de8-4a06-969f-d346f4447cf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "r2_score(real, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1955fc0-490a-40d5-8b3c-dd6e5beed235",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_absolute_error(real, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acee2abc-2f3f-4d19-a6e4-85ad4d1aaacf",
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_feature(data[5], masked_data[5], reconstructed[5], 'NO2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "62"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len('The total $R^2$ for under 40\\% missing data test set was 0.88.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "215938c7-d514-48e7-a460-088dcd7927ae",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}