MAE_ATMO/torch_MAE.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split\n",
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "adf69eb9-bedb-4db7-87c4-04c23752a7c3",
"metadata": {},
"outputs": [],
"source": [
"def load_data(pix, use_type='train'):\n",
" datasets = list()\n",
" file_list = [x for x in os.listdir(f\"./out_mat/{pix}/{use_type}/\") if x.endswith('.npy')]\n",
" for file in file_list:\n",
" file_img = np.load(f\"./out_mat/{pix}/{use_type}/{file}\")[:,:,:7]\n",
" datasets.append(file_img)\n",
" return np.asarray(datasets)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e0aa628f-37b7-498a-94d7-81241c20b305",
"metadata": {},
"outputs": [],
"source": [
"train_set = load_data(96, 'train')\n",
"val_set = load_data(96, 'valid')\n",
"test_set = load_data(96, 'test')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5d5f95cb-f40c-4ead-96fe-241068408b98",
"metadata": {},
"outputs": [],
"source": [
"def load_mask(mask_rate):\n",
" mask_files = os.listdir(f'./out_mat/96/mask/{mask_rate}')\n",
" masks = list()\n",
" for file in mask_files:\n",
" d = cv2.imread(f'./out_mat/96/mask/{mask_rate}/{file}', cv2.IMREAD_GRAYSCALE)\n",
" d = (d > 0) * 1\n",
" masks.append(d)\n",
" return np.asarray(masks)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "71452a77-8158-46b2-aecf-400ad7b72df5",
"metadata": {},
"outputs": [],
"source": [
"masks = load_mask(20)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1902e0f8-32bb-4376-8238-334260b12623",
"metadata": {},
"outputs": [],
"source": [
"maxs = train_set.max(axis=0)\n",
"mins = train_set.min(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8df9f3c3-ced8-4640-af30-b2f147dbdc96",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"26749"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_set)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "53664b12-fd95-4dd0-b61d-20682f8f14f4",
"metadata": {},
"outputs": [],
"source": [
"norm_train = (train_set - mins) / (maxs-mins)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "05cb9dc8-c1df-48bf-a9dd-d084ce1d2068",
"metadata": {},
"outputs": [],
"source": [
"del train_set"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4ae39364-4cf6-49e9-b99f-6723520943b5",
"metadata": {},
"outputs": [],
"source": [
"norm_valid = (val_set - mins) / (maxs-mins)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7f78b981-d079-4000-ba9f-d862e34903b1",
"metadata": {},
"outputs": [],
"source": [
"del val_set"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f54eede6-e95a-4476-b822-79846c0b1079",
"metadata": {},
"outputs": [],
"source": [
"norm_test = (test_set - mins) / (maxs-mins)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e66887eb-df5e-46d3-b9c5-73af1272b27a",
"metadata": {},
"outputs": [],
"source": [
"del test_set"
]
},
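{
"cell_type": "markdown",
"id": "f1a2b3c4-0001-4e6e-9a55-aaaa00000001",
"metadata": {},
"source": [
"Note: the per-pixel min-max scaling above divides by `maxs - mins`, which is zero wherever a pixel is constant across the training set and would produce NaN/inf values. The next cell is a minimal optional guard, not part of the original pipeline; `safe_minmax` and its `eps` argument are illustrative names."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a2b3c4-0002-4e6e-9a55-aaaa00000002",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: min-max scaling with a guarded denominator (assumes numpy inputs).\n",
"def safe_minmax(x, mins, maxs, eps=1e-8):\n",
"    rng = np.maximum(maxs - mins, eps)  # avoid division by zero for constant pixels\n",
"    return (x - mins) / rng\n",
"\n",
"# Hypothetical usage (the normalized arrays above were already computed the original way):\n",
"# norm_train = safe_minmax(train_set, mins, maxs)"
]
},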
{
"cell_type": "code",
"execution_count": 15,
"id": "00afa8cd-18b4-4d71-8cab-fd140058dca3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(26749, 96, 96)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"norm_train.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31d91072-3878-4e3c-b6f1-09f597faf60d",
"metadata": {},
"outputs": [],
"source": [
"trans_train = np.transpose(norm_train, (0, 3, 1, 2))\n",
"trans_val = np.transpose(norm_valid, (0, 3, 1, 2))\n",
"trans_test = np.transpose(norm_test, (0, 3, 1, 2))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "70797703-1619-4be7-b965-5506b3d1e775",
"metadata": {},
"outputs": [],
"source": [
"# 可视化特定特征的函数\n",
"def visualize_feature(input_feature,masked_feature, output_feature, title):\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(input_feature[0].cpu().numpy())\n",
" plt.title(title + \" Input\")\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(masked_feature[0].cpu().numpy())\n",
" plt.title(title + \" Masked\")\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(output_feature[0].detach().cpu().numpy())\n",
" plt.title(title + \" Recovery\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aeda3567-4c4d-496b-9570-9ae757b45e72",
"metadata": {},
"outputs": [],
"source": [
"# 设置随机种子以确保结果的可重复性\n",
"torch.manual_seed(0)\n",
"np.random.seed(0)\n",
"\n",
"# 数据准备\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)\n",
"# 将numpy数组转换为PyTorch张量\n",
"tensor_train = torch.tensor(trans_train.astype(np.float32), device=device)\n",
"tensor_valid = torch.tensor(trans_val.astype(np.float32), device=device)\n",
"tensor_test = torch.tensor(trans_test.astype(np.float32), device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1569baeb-5a9e-48c1-a735-82d0cba8ad29",
"metadata": {},
"outputs": [],
"source": [
"# 创建一个数据集和数据加载器\n",
"train_set = TensorDataset(tensor_train, tensor_train) # 输出和标签相同,因为我们是自编码器\n",
"val_set = TensorDataset(tensor_valid, tensor_valid)\n",
"test_set = TensorDataset(tensor_test, tensor_test)\n",
"batch_size = 64\n",
"train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
"val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)\n",
"test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c81785d-f0e6-486f-8aad-dba81d2ec146",
"metadata": {},
"outputs": [],
"source": [
"def mask_data(data, device, masks):\n",
" mask_inds = np.random.choice(masks.shape[0], data.shape[0])\n",
" mask = torch.from_numpy(masks[mask_inds]).to(device)\n",
" tmp_first_channel = data[:, 0, :, :] * mask\n",
" masked_data = torch.clone(data)\n",
" masked_data[:, 0, :, :] = tmp_first_channel\n",
" return masked_data"
]
},
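{
"cell_type": "markdown",
"id": "f1a2b3c4-0003-4e6e-9a55-aaaa00000003",
"metadata": {},
"source": [
"A quick sanity check, not part of the original notebook: `mask_data` should alter only channel 0 and leave the remaining channels untouched. It assumes the `tensor_train`, `masks`, and `device` objects defined above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a2b3c4-0004-4e6e-9a55-aaaa00000004",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: verify that mask_data only alters the first channel of a small batch.\n",
"sample = tensor_train[:4]\n",
"masked_sample = mask_data(sample, device, masks)\n",
"assert torch.equal(masked_sample[:, 1:], sample[:, 1:])  # channels 1..6 unchanged\n",
"masked_frac = (masked_sample[:, 0] == 0).float().mean().item()  # includes pixels that were already zero\n",
"print(f'fraction of zeroed NO2 pixels: {masked_frac:.3f}')"
]
},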
{
"cell_type": "code",
"execution_count": null,
"id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
" def __init__(self, in_channels, reduced_dim):\n",
" super(SEBlock, self).__init__()\n",
" self.se = nn.Sequential(\n",
" nn.AdaptiveAvgPool2d(1), # 全局平均池化\n",
" nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为我们要对通道进行权重归一化\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return x * self.se(x)"
]
},
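{
"cell_type": "markdown",
"id": "f1a2b3c4-0005-4e6e-9a55-aaaa00000005",
"metadata": {},
"source": [
"A small shape check for the squeeze-and-excitation block, not in the original notebook: the block returns a tensor with the same shape as its input, only re-weighted per channel. `reduced_dim=16` is chosen purely for illustration (the model below uses `SEBlock(128, 128)`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a2b3c4-0006-4e6e-9a55-aaaa00000006",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: SEBlock preserves the input shape; only the channel weighting changes.\n",
"_se = SEBlock(in_channels=128, reduced_dim=16)  # reduced_dim chosen for illustration only\n",
"_x = torch.randn(2, 128, 12, 12)\n",
"print(_se(_x).shape)  # expected: torch.Size([2, 128, 12, 12])"
]
},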
{
"cell_type": "code",
"execution_count": null,
"id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 定义Masked Autoencoder模型\n",
"class MaskedAutoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(MaskedAutoencoder, self).__init__()\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(7, 32, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" SEBlock(128, 128)\n",
" )\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(16, 7, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.Sigmoid() # 使用Sigmoid是因为输入数据是0-1之间的\n",
" )\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
" return decoded\n",
"\n",
"# 实例化模型、损失函数和优化器\n",
"model = MaskedAutoencoder()\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
]
},
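{
"cell_type": "markdown",
"id": "f1a2b3c4-0007-4e6e-9a55-aaaa00000007",
"metadata": {},
"source": [
"A quick sanity check, not in the original notebook: with three stride-2 convolutions down and three stride-2 transposed convolutions (each with `output_padding=1`) back up, a 7 x 96 x 96 input returns as 7 x 96 x 96, so the MSE against the unmasked target is well defined. The probe instance below is created on the CPU just for this check."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a2b3c4-0008-4e6e-9a55-aaaa00000008",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: confirm the encoder/decoder round trip preserves the 7 x 96 x 96 shape.\n",
"with torch.no_grad():\n",
"    _probe = MaskedAutoencoder()\n",
"    _out = _probe(torch.randn(2, 7, 96, 96))\n",
"print(_out.shape)  # expected: torch.Size([2, 7, 96, 96])\n",
"print(sum(p.numel() for p in _probe.parameters()))  # total parameter count"
]
},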
{
"cell_type": "code",
"execution_count": null,
"id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
"metadata": {},
"outputs": [],
"source": [
"# 训练函数\n",
"def train_epoch(model, device, data_loader, criterion, optimizer):\n",
" model.train()\n",
" running_loss = 0.0\n",
" for batch_idx, (data, _) in enumerate(data_loader):\n",
" masked_data = mask_data(data, device, masks)\n",
" optimizer.zero_grad()\n",
" reconstructed = model(masked_data)\n",
" loss = criterion(reconstructed, data)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
"metadata": {},
"outputs": [],
"source": [
"# 评估函数\n",
"def evaluate(model, device, data_loader, criterion):\n",
" model.eval()\n",
" running_loss = 0.0\n",
" with torch.no_grad():\n",
" for batch_idx, (data, _) in enumerate(data_loader):\n",
" data = data.to(device)\n",
" masked_data = mask_data(data, device, masks)\n",
" reconstructed = model(masked_data)\n",
" if batch_idx == 8:\n",
" rand_ind = np.random.randint(0, len(data))\n",
" visualize_feature(data[rand_ind], masked_data[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
" loss = criterion(reconstructed, data)\n",
" running_loss += loss.item()\n",
" return running_loss / (batch_idx + 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1847c78-cbc6-4560-bb49-4dc6e9b8bbd0",
"metadata": {},
"outputs": [],
"source": [
"# 测试函数\n",
"def test(model, device, data_loader):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for batch_idx, (data, _) in enumerate(data_loader):\n",
" data = data.to(device)\n",
" masked_data = mask_data(data, device, masks)\n",
" masked_ind = np.argwhere(masked_data[0][0]==0)\n",
" reconstructed = model(masked_data)\n",
" recon_no2 = reconstructed[0][0]\n",
" ori_no2 = data[0][0]\n",
" return"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "743d1000-561e-4444-8b49-88346c14f28b",
"metadata": {},
"outputs": [],
"source": [
"model = model.to(device)\n",
"\n",
"num_epochs = 100\n",
"train_losses = list()\n",
"val_losses = list()\n",
"for epoch in range(num_epochs):\n",
" train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
" train_losses.append(train_loss)\n",
" val_loss = evaluate(model, device, val_loader, criterion)\n",
" val_losses.append(val_loss)\n",
" print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
"\n",
"# 测试模型\n",
"test_loss = evaluate(model, device, test_loader, criterion)\n",
"print(f'Test Loss: {test_loss}')"
]
},
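{
"cell_type": "markdown",
"id": "f1a2b3c4-0009-4e6e-9a55-aaaa00000009",
"metadata": {},
"source": [
"The trained weights only live in memory above. The cell below is an optional sketch, not part of the original notebook, for persisting and restoring them; the file name `mae_atmo.pt` is illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a2b3c4-0010-4e6e-9a55-aaaa00000010",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: save the trained weights and reload them later.\n",
"torch.save(model.state_dict(), 'mae_atmo.pt')  # illustrative file name\n",
"# restored = MaskedAutoencoder()\n",
"# restored.load_state_dict(torch.load('mae_atmo.pt', map_location=device))\n",
"# restored.to(device).eval()"
]
},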
{
"cell_type": "code",
"execution_count": null,
"id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
"metadata": {},
"outputs": [],
"source": [
"tr_ind = list(range(len(train_losses)))\n",
"val_ind = list(range(len(val_losses)))\n",
"plt.plot(train_losses, label='train_loss')\n",
"plt.plot(val_losses, label='val_loss')\n",
"plt.legend(loc='best')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cff8cba9-aba9-4347-8e1a-f169df8313c2",
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad():\n",
" device = 'cpu'\n",
" for batch_idx, (data, _) in enumerate(test_loader):\n",
" model = model.to(device)\n",
" data = data.to(device)\n",
" masked_data = mask_data(data, device, masks)\n",
" reconstructed = model(masked_data)\n",
" tr_maxs = np.transpose(maxs, (2, 0, 1))\n",
" tr_mins = np.transpose(mins, (2, 0, 1))\n",
" rev_data = data * (tr_maxs - tr_mins) + tr_mins\n",
" rev_recon = reconstructed * (tr_maxs - tr_mins) + tr_mins\n",
" data_label = ((rev_data!=0) * (masked_data==0) * rev_data)[:, 0]\n",
" recon_no2 = ((rev_data!=0) * (masked_data==0) * rev_recon)[:, 0]\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "565f2e0a-1689-4a03-9fc1-15519b1cdaee",
"metadata": {},
"outputs": [],
"source": [
"real = data_label.flatten()\n",
"pred = recon_no2.flatten()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1e8f71e-855a-41ea-b62f-095514af66a3",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a6eea29-cd3e-4712-ad73-589bcf7b88be",
"metadata": {},
"outputs": [],
"source": [
"mean_squared_error(real, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a85feefb-aa3a-4bb9-86ac-7cc6938a47e8",
"metadata": {},
"outputs": [],
"source": [
"mean_absolute_percentage_error(real, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2bfda87-3de8-4a06-969f-d346f4447cf6",
"metadata": {},
"outputs": [],
"source": [
"r2_score(real, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1955fc0-490a-40d5-8b3c-dd6e5beed235",
"metadata": {},
"outputs": [],
"source": [
"mean_absolute_error(real, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acee2abc-2f3f-4d19-a6e4-85ad4d1aaacf",
"metadata": {},
"outputs": [],
"source": [
"visualize_feature(data[5], masked_data[5], reconstructed[5], 'NO2')"
]
},
{
"cell_type": "markdown",
"id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
"metadata": {},
"source": [
"The total $R^2$ on the test set with under 40% missing data was 0.88."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "215938c7-d514-48e7-a460-088dcd7927ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}