{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6603a8fc-d9da-4037-b845-d9c38bae4ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "adf69eb9-bedb-4db7-87c4-04c23752a7c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(pix, use_type='train'):\n",
    "    datasets = list()\n",
    "    file_list = [x for x in os.listdir(f\"./out_mat/{pix}/{use_type}/\") if x.endswith('.npy')]\n",
    "    for file in file_list:\n",
    "        file_img = np.load(f\"./out_mat/{pix}/{use_type}/{file}\")[:, :, :1]\n",
    "        datasets.append(file_img)\n",
    "    return np.asarray(datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e0aa628f-37b7-498a-94d7-81241c20b305",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set = load_data(96, 'train')\n",
    "val_set = load_data(96, 'valid')\n",
    "test_set = load_data(96, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5d5f95cb-f40c-4ead-96fe-241068408b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_mask(mask_rate):\n",
    "    mask_files = os.listdir(f'./out_mat/96/mask/{mask_rate}')\n",
    "    masks = list()\n",
    "    for file in mask_files:\n",
    "        d = cv2.imread(f'./out_mat/96/mask/{mask_rate}/{file}', cv2.IMREAD_GRAYSCALE)\n",
    "        d = (d > 0) * 1\n",
    "        masks.append(d)\n",
    "    return np.asarray(masks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "71452a77-8158-46b2-aecf-400ad7b72df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks = load_mask(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1902e0f8-32bb-4376-8238-334260b12623",
   "metadata": {},
   "outputs": [],
   "source": [
    "maxs = train_set.max(axis=0)\n",
    "mins = train_set.min(axis=0)"
   ]
  },
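  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-minmax-range-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): sanity-check the per-pixel\n",
    "# min-max range before normalizing. If a pixel is constant across the training\n",
    "# set, maxs - mins is 0 there and the normalization below divides by zero.\n",
    "# The epsilon guard is one possible workaround, shown here as an assumption.\n",
    "ranges = maxs - mins\n",
    "print('zero-range pixels:', int((ranges == 0).sum()))\n",
    "eps = 1e-8  # hypothetical guard value\n",
    "safe_ranges = np.where(ranges == 0, eps, ranges)"
   ]
  },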
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8df9f3c3-ced8-4640-af30-b2f147dbdc96",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "26749"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53664b12-fd95-4dd0-b61d-20682f8f14f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "norm_train = (train_set - mins) / (maxs - mins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05cb9dc8-c1df-48bf-a9dd-d084ce1d2068",
   "metadata": {},
   "outputs": [],
   "source": [
    "del train_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ae39364-4cf6-49e9-b99f-6723520943b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "norm_valid = (val_set - mins) / (maxs - mins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f78b981-d079-4000-ba9f-d862e34903b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "del val_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54eede6-e95a-4476-b822-79846c0b1079",
   "metadata": {},
   "outputs": [],
   "source": [
    "norm_test = (test_set - mins) / (maxs - mins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e66887eb-df5e-46d3-b9c5-73af1272b27a",
   "metadata": {},
   "outputs": [],
   "source": [
    "del test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00afa8cd-18b4-4d71-8cab-fd140058dca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "norm_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31d91072-3878-4e3c-b6f1-09f597faf60d",
   "metadata": {},
   "outputs": [],
   "source": [
    "trans_train = np.transpose(norm_train, (0, 3, 1, 2))\n",
    "trans_val = np.transpose(norm_valid, (0, 3, 1, 2))\n",
    "trans_test = np.transpose(norm_test, (0, 3, 1, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70797703-1619-4be7-b965-5506b3d1e775",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize one feature: input, masked input, and reconstruction side by side\n",
    "def visualize_feature(input_feature, masked_feature, output_feature, title):\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    plt.subplot(1, 3, 1)\n",
    "    plt.imshow(input_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
    "    plt.title(title + \" Input\")\n",
    "    plt.subplot(1, 3, 2)\n",
    "    plt.imshow(masked_feature[0].cpu().numpy(), cmap='RdYlGn_r')\n",
    "    plt.title(title + \" Masked\")\n",
    "    plt.subplot(1, 3, 3)\n",
    "    plt.imshow(output_feature[0].detach().cpu().numpy(), cmap='RdYlGn_r')\n",
    "    plt.title(title + \" Recovery\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeda3567-4c4d-496b-9570-9ae757b45e72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set random seeds for reproducibility\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "# Data preparation\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "# Convert the numpy arrays to PyTorch tensors\n",
    "tensor_train = torch.tensor(trans_train.astype(np.float32), device=device)\n",
    "tensor_valid = torch.tensor(trans_val.astype(np.float32), device=device)\n",
    "tensor_test = torch.tensor(trans_test.astype(np.float32), device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2353265d-91ef-4a84-b582-ea969d2ee252",
   "metadata": {},
   "outputs": [],
   "source": [
    "del trans_train\n",
    "del trans_val\n",
    "del trans_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1569baeb-5a9e-48c1-a735-82d0cba8ad29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build datasets and data loaders\n",
    "train_set = TensorDataset(tensor_train, tensor_train)  # inputs double as targets: this is an autoencoder\n",
    "val_set = TensorDataset(tensor_valid, tensor_valid)\n",
    "test_set = TensorDataset(tensor_test, tensor_test)\n",
    "batch_size = 64\n",
    "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)\n",
    "test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c81785d-f0e6-486f-8aad-dba81d2ec146",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_data(data, device, masks):\n",
    "    # Draw one random pre-generated mask per sample and zero out the masked\n",
    "    # pixels of the first (NO2) channel only.\n",
    "    mask_inds = np.random.choice(masks.shape[0], data.shape[0])\n",
    "    mask = torch.from_numpy(masks[mask_inds]).to(device)\n",
    "    tmp_first_channel = data[:, 0, :, :] * mask\n",
    "    masked_data = torch.clone(data)\n",
    "    masked_data[:, 0, :, :] = tmp_first_channel\n",
    "    return masked_data, mask"
   ]
  },
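  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-mask-data-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): exercise mask_data on a dummy\n",
    "# batch to confirm shapes and the kept fraction. Uses the `masks` array loaded\n",
    "# above; the batch size and values here are illustrative only.\n",
    "dummy = torch.rand(4, 1, 96, 96)\n",
    "masked_dummy, m = mask_data(dummy, 'cpu', masks)\n",
    "print(masked_dummy.shape, m.shape)  # (4, 1, 96, 96) and (4, 96, 96)\n",
    "print('kept fraction:', m.float().mean().item())  # ~0.8 for the 20% masks"
   ]
  },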
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "645114e8-65a4-4867-b3fe-23395288e855",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv(nn.Sequential):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):\n",
    "        super(Conv, self).__init__(\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
    "                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af52d0e-b785-4a84-838c-6fcfe2568722",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvBNReLU(nn.Sequential):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,\n",
    "                 bias=False):\n",
    "        super(ConvBNReLU, self).__init__(\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,\n",
    "                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),\n",
    "            norm_layer(out_channels),\n",
    "            nn.ReLU()\n",
    "        )"
   ]
  },
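  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-padding-rule-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): verify the padding rule used by\n",
    "# Conv/ConvBNReLU, p = ((stride - 1) + dilation * (kernel_size - 1)) // 2,\n",
    "# which gives 'same'-style sizes for odd kernels: H is kept at stride 1 and\n",
    "# halved at stride 2.\n",
    "x = torch.randn(1, 1, 96, 96)\n",
    "print(Conv(1, 8, kernel_size=3, stride=1)(x).shape)  # torch.Size([1, 8, 96, 96])\n",
    "print(Conv(1, 8, kernel_size=3, stride=2)(x).shape)  # torch.Size([1, 8, 48, 48])"
   ]
  },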
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ecf247-e98b-4977-a145-782914a042bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SeparableBNReLU(nn.Sequential):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):\n",
    "        super(SeparableBNReLU, self).__init__(\n",
    "            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, dilation=dilation,\n",
    "                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2, groups=in_channels, bias=False),\n",
    "            # depthwise convolution: mixes spatial information only\n",
    "            norm_layer(in_channels),  # normalize over the input channels\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),  # pointwise conv: projects to out_channels\n",
    "            nn.ReLU6()\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7827bee2-74f7-4e47-b8c6-e41d5670e8b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualBlock(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "\n",
    "        # Downsample the identity when the channel count or stride changes\n",
    "        self.downsample = downsample\n",
    "        if in_channels != out_channels or stride != 1:\n",
    "            self.downsample = nn.Sequential(\n",
    "                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_channels)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        identity = x\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7853bf62-02f5-4917-b950-6fdfe467df4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mlp(nn.Module):\n",
    "    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):\n",
    "        super().__init__()\n",
    "        out_features = out_features or in_features\n",
    "        hidden_features = hidden_features or in_features\n",
    "        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n",
    "\n",
    "        self.act = act_layer()\n",
    "        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n",
    "        self.drop = nn.Dropout(drop, inplace=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.act(x)\n",
    "        x = self.drop(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.drop(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2375881-a11b-47a7-8f56-2eadb25010b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadAttentionBlock(nn.Module):\n",
    "    def __init__(self, embed_dim, num_heads, dropout=0.1):\n",
    "        super(MultiHeadAttentionBlock, self).__init__()\n",
    "        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)\n",
    "        self.norm = nn.LayerNorm(embed_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # (B, C, H, W) -> (HW, B, C) for MultiheadAttention compatibility\n",
    "        B, C, H, W = x.shape\n",
    "        x = x.view(B, C, H * W).permute(2, 0, 1)\n",
    "\n",
    "        # Apply multihead attention\n",
    "        attn_output, _ = self.attention(x, x, x)\n",
    "\n",
    "        # Apply normalization and dropout\n",
    "        attn_output = self.norm(attn_output)\n",
    "        attn_output = self.dropout(attn_output)\n",
    "\n",
    "        # Reshape back to (B, C, H, W)\n",
    "        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)\n",
    "\n",
    "        return attn_output"
   ]
  },
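  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-mha-shape-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): MultiHeadAttentionBlock flattens\n",
    "# (B, C, H, W) to (HW, B, C), attends, then reshapes back, so the output shape\n",
    "# must match the input. embed_dim must equal C and be divisible by num_heads.\n",
    "blk = MultiHeadAttentionBlock(embed_dim=128, num_heads=4)\n",
    "x = torch.randn(2, 128, 12, 12)\n",
    "print(blk(x).shape)  # torch.Size([2, 128, 12, 12])"
   ]
  },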
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "eea9678d-e170-4dd5-bf96-d20af4d40184",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on built-in function mean in module torch:\n",
      "\n",
      "mean(...)\n",
      "    mean(input, *, dtype=None) -> Tensor\n",
      "    \n",
      "    Returns the mean value of all elements in the :attr:`input` tensor.\n",
      "    \n",
      "    Args:\n",
      "        input (Tensor): the input tensor.\n",
      "    \n",
      "    Keyword args:\n",
      "        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.\n",
      "            If specified, the input tensor is casted to :attr:`dtype` before the operation\n",
      "            is performed. This is useful for preventing data type overflows. Default: None.\n",
      "    \n",
      "    Example::\n",
      "    \n",
      "        >>> a = torch.randn(1, 3)\n",
      "        >>> a\n",
      "        tensor([[ 0.2294, -0.5481,  1.3288]])\n",
      "        >>> torch.mean(a)\n",
      "        tensor(0.3367)\n",
      "    \n",
      "    .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor\n",
      "       :noindex:\n",
      "    \n",
      "    Returns the mean value of each row of the :attr:`input` tensor in the given\n",
      "    dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,\n",
      "    reduce over all of them.\n",
      "    \n",
      "    \n",
      "    If :attr:`keepdim` is ``True``, the output tensor is of the same size\n",
      "    as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.\n",
      "    Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the\n",
      "    output tensor having 1 (or ``len(dim)``) fewer dimension(s).\n",
      "    \n",
      "    \n",
      "    Args:\n",
      "        input (Tensor): the input tensor.\n",
      "        dim (int or tuple of ints): the dimension or dimensions to reduce.\n",
      "        keepdim (bool): whether the output tensor has :attr:`dim` retained or not.\n",
      "    \n",
      "    Keyword args:\n",
      "        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.\n",
      "            If specified, the input tensor is casted to :attr:`dtype` before the operation\n",
      "            is performed. This is useful for preventing data type overflows. Default: None.\n",
      "        out (Tensor, optional): the output tensor.\n",
      "    \n",
      "    .. seealso::\n",
      "    \n",
      "        :func:`torch.nanmean` computes the mean value of `non-NaN` elements.\n",
      "    \n",
      "    Example::\n",
      "    \n",
      "        >>> a = torch.randn(4, 4)\n",
      "        >>> a\n",
      "        tensor([[-0.3841,  0.6320,  0.4254, -0.7384],\n",
      "                [-0.9644,  1.0131, -0.6549, -1.4279],\n",
      "                [-0.2951, -1.3350, -0.7694,  0.5600],\n",
      "                [ 1.0842, -0.9580,  0.3623,  0.2343]])\n",
      "        >>> torch.mean(a, 1)\n",
      "        tensor([-0.0163, -0.5085, -0.4599,  0.1807])\n",
      "        >>> torch.mean(a, 1, True)\n",
      "        tensor([[-0.0163],\n",
      "                [-0.5085],\n",
      "                [-0.4599],\n",
      "                [ 0.1807]])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "help(torch.mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82a15d3d-2f8d-42ec-9146-87c8a4abe384",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpatialAttentionBlock(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SpatialAttentionBlock, self).__init__()\n",
    "        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)\n",
    "\n",
    "    def forward(self, x):                               # (B, C, H, W)\n",
    "        avg_out = torch.mean(x, dim=1, keepdim=True)    # (B, 1, H, W)\n",
    "        max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B, 1, H, W)\n",
    "        out = torch.cat([avg_out, max_out], dim=1)      # (B, 2, H, W)\n",
    "        out = torch.sigmoid(self.conv(out))             # (B, 1, H, W)\n",
    "        return x * out                                  # (B, C, H, W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "497bb9f1-1ac5-4d7f-a930-0ea222b9d1d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderAttentionBlock(nn.Module):\n",
    "    def __init__(self, in_channels):\n",
    "        super(DecoderAttentionBlock, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)\n",
    "        self.conv2 = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)\n",
    "        self.spatial_attention = SpatialAttentionBlock()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Channel attention\n",
    "        b, c, h, w = x.size()\n",
    "        avg_pool = F.adaptive_avg_pool2d(x, 1)\n",
    "        max_pool = F.adaptive_max_pool2d(x, 1)\n",
    "\n",
    "        avg_out = self.conv1(avg_pool)\n",
    "        max_out = self.conv1(max_pool)\n",
    "\n",
    "        out = avg_out + max_out\n",
    "        out = torch.sigmoid(self.conv2(out))\n",
    "\n",
    "        # Apply spatial attention on top\n",
    "        out = x * out\n",
    "        out = self.spatial_attention(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15b9d453-d8d9-43b8-aca2-904735fb3a99",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SEBlock(nn.Module):\n",
    "    def __init__(self, in_channels, reduced_dim):\n",
    "        super(SEBlock, self).__init__()\n",
    "        self.se = nn.Sequential(\n",
    "            nn.AdaptiveAvgPool2d(1),  # global average pooling\n",
    "            nn.Conv2d(in_channels, reduced_dim, kernel_size=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(reduced_dim, in_channels, kernel_size=1),\n",
    "            nn.Sigmoid()  # sigmoid normalizes the per-channel weights into (0, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x * self.se(x)"
   ]
  },
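  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-seblock-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): SEBlock gates each channel with\n",
    "# a weight in (0, 1), so the output keeps the input shape and is bounded by the\n",
    "# input magnitudes. The reduction to 8 channels is an illustrative choice.\n",
    "se = SEBlock(32, 8)\n",
    "x = torch.randn(2, 32, 24, 24)\n",
    "y = se(x)\n",
    "print(y.shape, bool((y.abs() <= x.abs() + 1e-6).all()))  # shape kept, True"
   ]
  },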
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6379adb7-8a87-4dd8-a695-4013a7b37830",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Masked autoencoder model (attention variant)\n",
    "class MaskedAutoencoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MaskedAutoencoder, self).__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            Conv(1, 32, kernel_size=3, stride=2),\n",
    "            nn.ReLU(),\n",
    "            SEBlock(32, 32),\n",
    "            ConvBNReLU(32, 64, kernel_size=3, stride=2),\n",
    "            ResidualBlock(64, 64),\n",
    "            SeparableBNReLU(64, 128, kernel_size=3, stride=2),\n",
    "            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),\n",
    "            SEBlock(128, 128)\n",
    "        )\n",
    "        # note: defined here but never called in forward()\n",
    "        self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            DecoderAttentionBlock(32),\n",
    "            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            DecoderAttentionBlock(16),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # output_padding=1 restores the input size\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded\n",
    "\n",
    "# Instantiate the model, loss function, and optimizer\n",
    "model = MaskedAutoencoder()\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea37b5f-817d-4850-8393-36910cf64eb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline masked autoencoder (plain convolutional encoder/decoder)\n",
    "class MaskedAutoencoderBase(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MaskedAutoencoderBase, self).__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            SEBlock(128, 128)\n",
    "        )\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.Sigmoid()  # sigmoid because the inputs are normalized to [0, 1]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded\n",
    "\n",
    "# Instantiate the model, loss function, and optimizer\n",
    "# (this overwrites the attention variant defined above)\n",
    "model = MaskedAutoencoderBase()\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "404a8bfb-4976-4cce-b989-c5e401bce0d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for one epoch\n",
    "def train_epoch(model, device, data_loader, criterion, optimizer):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for batch_idx, (data, _) in enumerate(data_loader):\n",
    "        masked_data, mask = mask_data(data, device, masks)\n",
    "        optimizer.zero_grad()\n",
    "        reconstructed = model(masked_data)\n",
    "        loss = criterion(reconstructed, data)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "    return running_loss / (batch_idx + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94457c6b-4c6e-4aff-946d-fe4c670bfe16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation loop\n",
    "def evaluate(model, device, data_loader, criterion):\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, _) in enumerate(data_loader):\n",
    "            data = data.to(device)\n",
    "            masked_data, mask = mask_data(data, device, masks)\n",
    "            reconstructed = model(masked_data)\n",
    "            if batch_idx == 8:\n",
    "                rand_ind = np.random.randint(0, len(data))\n",
    "                visualize_feature(data[rand_ind], masked_data[rand_ind], reconstructed[rand_ind], title='NO_2')\n",
    "            loss = criterion(reconstructed, data)\n",
    "            running_loss += loss.item()\n",
    "    return running_loss / (batch_idx + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1847c78-cbc6-4560-bb49-4dc6e9b8bbd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test helper: inspects one batch (left unfinished in the original)\n",
    "def test(model, device, data_loader):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, _) in enumerate(data_loader):\n",
    "            data = data.to(device)\n",
    "            masked_data, mask = mask_data(data, device, masks)\n",
    "            # move to CPU before np.argwhere, which cannot read CUDA tensors\n",
    "            masked_ind = np.argwhere((masked_data[0][0] == 0).cpu().numpy())\n",
    "            reconstructed = model(masked_data)\n",
    "            recon_no2 = reconstructed[0][0]\n",
    "            ori_no2 = data[0][0]\n",
    "            return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "743d1000-561e-4444-8b49-88346c14f28b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "\n",
    "num_epochs = 100\n",
    "train_losses = list()\n",
    "val_losses = list()\n",
    "for epoch in range(num_epochs):\n",
    "    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)\n",
    "    train_losses.append(train_loss)\n",
    "    val_loss = evaluate(model, device, val_loader, criterion)\n",
    "    val_losses.append(val_loss)\n",
    "    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')\n",
    "\n",
    "# Evaluate on the test set\n",
    "test_loss = evaluate(model, device, test_loader, criterion)\n",
    "print(f'Test Loss: {test_loss}')"
   ]
  },
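  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-checkpoint-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not part of the original run): report the best validation epoch\n",
    "# and save the final weights. The filename is hypothetical; keeping the best\n",
    "# epoch's weights would need a torch.save inside the training loop instead.\n",
    "best_epoch = int(np.argmin(val_losses)) + 1\n",
    "print(f'best val loss {min(val_losses):.6f} at epoch {best_epoch}')\n",
    "torch.save(model.state_dict(), 'mae_final.pth')  # hypothetical checkpoint path"
   ]
  },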
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc0d608-6f0a-43dc-8cc1-8acf68215d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_ind = list(range(len(train_losses)))\n",
    "val_ind = list(range(len(val_losses)))\n",
    "plt.plot(train_losses, label='train_loss')\n",
    "plt.plot(val_losses, label='val_loss')\n",
    "plt.legend(loc='best')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8467686-0655-4056-8e01-56299eb89d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae7427e-548e-4276-a4ea-bc9b279d44e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "real_list = list()\n",
    "pred_list = list()\n",
    "with torch.no_grad():\n",
    "    device = 'cpu'\n",
    "    for batch_idx, (data, _) in enumerate(test_loader):\n",
    "        model = model.to(device)\n",
    "        data = data.to(device)\n",
    "        masked_data, mask = mask_data(data, device, masks)\n",
    "        mask_rev = (mask == 0) * 1  # invert the mask to select the inpainted region\n",
    "        reconstructed = model(masked_data)\n",
    "        tr_maxs = np.transpose(maxs, (2, 0, 1))\n",
    "        tr_mins = np.transpose(mins, (2, 0, 1))\n",
    "        rev_data = data * (tr_maxs - tr_mins) + tr_mins\n",
    "        rev_recon = reconstructed * (tr_maxs - tr_mins) + tr_mins\n",
    "        # todo: evaluate only the inpainted pixels here\n",
    "        data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
    "        data_label = data_label[mask_rev == 1]\n",
    "        recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
    "        recon_no2 = recon_no2[mask_rev == 1]\n",
    "        real_list.extend(data_label)\n",
    "        pred_list.extend(recon_no2)"
   ]
  },
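  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-inpainting-metrics",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not in the original notebook): score the inpainted pixels with\n",
    "# the sklearn metrics imported above. real_list/pred_list hold 0-dim tensors,\n",
    "# so convert them to flat float arrays first.\n",
    "y_true = np.asarray([float(v) for v in real_list])\n",
    "y_pred = np.asarray([float(v) for v in pred_list])\n",
    "print('RMSE:', np.sqrt(mean_squared_error(y_true, y_pred)))\n",
    "print('MAE :', mean_absolute_error(y_true, y_pred))\n",
    "print('R2  :', r2_score(y_true, y_pred))"
   ]
  },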
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94e58640-42a9-4d54-a851-c7fc3a6e06ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "abs(np.asarray(real_list) - np.asarray(pred_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "acee2abc-2f3f-4d19-a6e4-85ad4d1aaacf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "<base64 PNG data truncated in the source; the figure showed three panels: NO2 Input, NO2 Masked, NO2 Recovery>",
      "text/plain": [
       "<Figure size 1200x600 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "visualize_feature(data[5], masked_data[5], reconstructed[5], 'NO2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ada99bf-6bea-4e46-a3bd-f62510517c8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# real_list = list()\n",
    "# pred_list = list()\n",
    "# with torch.no_grad():\n",
    "#     device = 'cpu'\n",
    "#     for batch_idx, (data, _) in enumerate(test_loader):\n",
    "#         model = model.to(device)\n",
    "#         data = data.to(device)\n",
    "#         masked_data, mask = mask_data(data, device, masks)\n",
    "#         mask_rev = (mask == 0) * 1  # invert the mask to select the inpainted region\n",
    "#         reconstructed = model(masked_data)\n",
    "#         tr_maxs = np.transpose(maxs, (2, 0, 1))\n",
    "#         tr_mins = np.transpose(mins, (2, 0, 1))\n",
    "#         rev_data = data * (tr_maxs - tr_mins) + tr_mins\n",
    "#         rev_recon = reconstructed * (tr_maxs - tr_mins) + tr_mins\n",
    "#         # todo: evaluate only the inpainted pixels here\n",
    "#         data_label = torch.squeeze(rev_data, dim=1) * mask_rev\n",
    "#         data_label = data_label[mask_rev == 1]\n",
    "#         recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev\n",
    "#         recon_no2 = recon_no2[mask_rev == 1]\n",
    "#         real_list.extend(data_label)\n",
    "#         pred_list.extend(recon_no2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8332744e-5b90-4702-a3b7-66309ffb1956",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.randn(1, 1, 4, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "216bce16-246e-4431-95e7-2c3a9d894fe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_out = torch.mean(a, dim=1, keepdim=True)  # (B, 1, H, W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0b954708-269f-4b5a-ad65-03ecf58a9549",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1, 4, 4])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "avg_out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "31ac2d6d-79c6-4ed8-a9e5-0ec37a6a9e4a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0919,  1.9463, -0.6934,  0.1982],\n",
       "        [ 0.1241,  0.5442,  0.4565,  0.3567],\n",
       "        [ 0.8672, -0.8656, -0.4287, -0.4634],\n",
       "        [ 1.8194,  0.3727,  1.1409,  0.6761]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "4c660fa7-851b-456c-9881-88f81079121c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0919,  1.9463, -0.6934,  0.1982],\n",
       "        [ 0.1241,  0.5442,  0.4565,  0.3567],\n",
       "        [ 0.8672, -0.8656, -0.4287, -0.4634],\n",
       "        [ 1.8194,  0.3727,  1.1409,  0.6761]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "avg_out[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5905c4ff-613b-4f08-a7a1-2bafb4fc0ba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "279cb531-aefc-4be2-8d98-b09c3c595a9a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[-0.0919,  1.9463, -0.6934,  0.1982],\n",
       "          [ 0.1241,  0.5442,  0.4565,  0.3567],\n",
       "          [ 0.8672, -0.8656, -0.4287, -0.4634],\n",
       "          [ 1.8194,  0.3727,  1.1409,  0.6761]]]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "avg_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "19ae1030-1d4d-4a0b-b307-412456f27f47",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[-0.0919,  1.9463, -0.6934,  0.1982],\n",
       "          [ 0.1241,  0.5442,  0.4565,  0.3567],\n",
       "          [ 0.8672, -0.8656, -0.4287, -0.4634],\n",
       "          [ 1.8194,  0.3727,  1.1409,  0.6761]]]])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "avg_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e10712cd-45fc-44a3-b359-5a62cae1c33c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}