334 KiB
334 KiB
In [1]:
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset, random_split import os import numpy as np import matplotlib.pyplot as plt import cv2
In [2]:
def load_data(pix, use_type='train', root='./out_mat', channels=1):
    """Load every ``.npy`` file of one data split into a single stacked array.

    Files are read from ``{root}/{pix}/{use_type}/`` in sorted filename order.
    ``os.listdir`` alone returns entries in arbitrary order, which made the sample
    order (and anything seeded downstream) differ between machines.

    Args:
        pix: resolution sub-directory name (e.g. ``96``).
        use_type: split sub-directory name (``'train'``, ``'valid'``, ``'test'``).
        root: base data directory; the default is the previously hard-coded
            ``./out_mat``.
        channels: how many leading entries of the last axis to keep per file;
            ``1`` reproduces the original ``[:, :, :1]`` slice.

    Returns:
        ``np.ndarray`` stacking the sliced per-file arrays along a new first axis.
        Each file is sliced as ``[:, :, :channels]``, so files are presumably
        ``(H, W, C)`` — TODO confirm.
    """
    split_dir = os.path.join(root, str(pix), use_type)
    file_list = sorted(x for x in os.listdir(split_dir) if x.endswith('.npy'))
    datasets = [np.load(os.path.join(split_dir, name))[:, :, :channels] for name in file_list]
    return np.asarray(datasets)
In [3]:
# Load the three 96-pixel splits, in train / valid / test order.
train_set, val_set, test_set = (load_data(96, split) for split in ('train', 'valid', 'test'))
In [4]:
def load_mask(mask_rate, pix=96, root='./out_mat'):
    """Load every mask image of one mask-rate directory as a 0/1 integer array.

    Images are read from ``{root}/{pix}/mask/{mask_rate}/`` in sorted filename
    order. ``os.listdir`` order is arbitrary, which made the mask bank, and
    therefore seeded random picks from it, machine dependent. Each image is read
    as grayscale; pixels > 0 become 1 and all others 0.

    Args:
        mask_rate: name of the mask-rate sub-directory (e.g. ``20``).
        pix: resolution sub-directory; the default is the previously hard-coded ``96``.
        root: base data directory; defaults to ``./out_mat``.

    Returns:
        ``np.ndarray`` stacking one binary mask per file.

    Raises:
        ValueError: if a file cannot be decoded as an image (``cv2.imread``
            returned ``None``). Previously this surfaced as an opaque ``TypeError``
            from ``None > 0``.
    """
    mask_dir = os.path.join(root, str(pix), 'mask', str(mask_rate))
    masks = list()
    for file in sorted(os.listdir(mask_dir)):
        path = os.path.join(mask_dir, file)
        d = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if d is None:
            raise ValueError(f"could not read mask image: {path}")
        masks.append((d > 0) * 1)
    return np.asarray(masks)
In [5]:
# Bank of binary masks from the mask-rate directory "20"; mask_data() samples from it.
masks = load_mask(mask_rate=20)
In [6]:
# Per-pixel maximum / minimum over all training samples (axis 0), used for min-max scaling.
maxs, mins = np.max(train_set, axis=0), np.min(train_set, axis=0)
In [7]:
# Number of training samples (size of the first axis).
train_set.shape[0]
Out[7]:
26749
In [ ]:
# Min-max scale the training split per pixel. Pixels whose training range is zero
# (max == min) used to divide by zero and become NaN, which poisons the loss. Their
# denominator is replaced by 1, so those pixels scale to 0 instead.
norm_train = (train_set - mins) / np.where(maxs > mins, maxs - mins, 1)
In [ ]:
# Release the raw training array; the following cells work from `norm_train`.
del train_set
In [ ]:
# Scale the validation split with the *training* per-pixel min/max. Where the training
# range is zero (max == min), the denominator is replaced by 1 to avoid inf/NaN. Those
# pixels then hold their raw offset from the training minimum.
norm_valid = (val_set - mins) / np.where(maxs > mins, maxs - mins, 1)
In [ ]:
# Release the raw validation array; the following cells work from `norm_valid`.
del val_set
In [ ]:
# Scale the test split with the *training* per-pixel min/max. Where the training range
# is zero (max == min), the denominator is replaced by 1 to avoid inf/NaN. Those pixels
# then hold their raw offset from the training minimum.
norm_test = (test_set - mins) / np.where(maxs > mins, maxs - mins, 1)
In [ ]:
# Release the raw test array; the following cells work from `norm_test`.
del test_set
In [ ]:
# Inspect the shape of the normalized training array.
np.shape(norm_train)
In [ ]:
# Move the last axis to position 1 ((N, H, W, C) -> (N, C, H, W)), the channel-first
# layout nn.Conv2d expects.
trans_train, trans_val, trans_test = (
    split.transpose(0, 3, 1, 2) for split in (norm_train, norm_valid, norm_test)
)
In [ ]:
# Helper that plots one sample: original, masked and reconstructed.
def visualize_feature(input_feature, masked_feature, output_feature, title):
    """Show the original, masked and reconstructed views of one sample side by side.

    Channel ``[0]`` of each tensor is drawn with the ``RdYlGn_r`` colormap in a
    1x3 grid. Panels are titled ``title + " Input"``, ``title + " Masked"`` and
    ``title + " Recovery"``.

    Args:
        input_feature: original sample tensor (channel-first).
        masked_feature: the same sample after masking.
        output_feature: model reconstruction; detached before conversion.
        title: prefix for the three panel titles.
    """
    panels = (
        (input_feature, " Input", False),
        (masked_feature, " Masked", False),
        (output_feature, " Recovery", True),  # the reconstruction may still carry autograd history
    )
    plt.figure(figsize=(12, 6))
    for slot, (feature, suffix, needs_detach) in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        channel = feature[0].detach() if needs_detach else feature[0]
        plt.imshow(channel.cpu().numpy(), cmap='RdYlGn_r')
        plt.title(title + suffix)
    plt.show()
In [ ]:
# Seed both torch and numpy so that masking and weight initialisation are reproducible.
torch.manual_seed(0)
np.random.seed(0)

# Data preparation: use the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Convert the numpy arrays into float32 tensors living on `device`.
tensor_train, tensor_valid, tensor_test = (
    torch.tensor(arr.astype(np.float32), device=device)
    for arr in (trans_train, trans_val, trans_test)
)
In [ ]:
# The tensors now hold copies of the data, so the transposed numpy views can go.
del trans_train, trans_val, trans_test
In [ ]:
# Datasets and loaders. Input and target are the same tensor because this is an autoencoder.
train_set, val_set, test_set = (
    TensorDataset(split, split) for split in (tensor_train, tensor_valid, tensor_test)
)
batch_size = 64
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False) for ds in (val_set, test_set)
)
In [ ]:
def mask_data(data, device, masks):
    """Apply a randomly chosen mask to channel 0 of every sample in a batch.

    One mask per sample is drawn (with replacement) from ``masks`` using
    ``np.random.choice``. Channel 0 of a copy of ``data`` is multiplied by it;
    the other channels and ``data`` itself are left untouched.

    Args:
        data: batch tensor indexed as ``[batch, channel, H, W]``.
        device: device the chosen masks are moved to.
        masks: numpy array of masks, stacked along axis 0.

    Returns:
        ``(masked_data, mask)``: the masked copy, and the per-sample masks as a tensor.
    """
    picks = np.random.choice(masks.shape[0], data.shape[0])
    mask = torch.from_numpy(masks[picks]).to(device)
    masked_data = data.clone()
    masked_data[:, 0, :, :] = data[:, 0, :, :] * mask
    return masked_data, mask
In [ ]:
class Conv(nn.Sequential):
    """A single ``nn.Conv2d`` (bias-free by default) packaged as an ``nn.Sequential``.

    Padding is ``((stride - 1) + dilation * (kernel_size - 1)) // 2``. For odd
    kernels this keeps the spatial size at stride 1 and halves even sizes at stride 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        pad = (dilation * (kernel_size - 1) + (stride - 1)) // 2
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            bias=bias,
            dilation=dilation,
            stride=stride,
            padding=pad,
        )
        super().__init__(conv)
In [ ]:
class ConvBNReLU(nn.Sequential):
    """Conv2d -> normalization -> ReLU, as an ``nn.Sequential`` of three modules.

    Uses the same padding rule as ``Conv``:
    ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.

    Args:
        norm_layer: factory called with ``out_channels``; defaults to ``nn.BatchNorm2d``.
        bias: whether the convolution has a bias (off by default, since a norm follows).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 norm_layer=nn.BatchNorm2d, bias=False):
        pad = (dilation * (kernel_size - 1) + (stride - 1)) // 2
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=pad),
            norm_layer(out_channels),
            nn.ReLU(),
        ]
        super().__init__(*layers)
In [ ]:
class SeparableBNReLU(nn.Sequential):
    """Depthwise conv -> norm -> pointwise (1x1) conv -> ReLU6.

    The depthwise conv (``groups=in_channels``) only mixes spatial information and
    carries the stride/dilation. The norm is applied to its ``in_channels`` outputs.
    The bias-free 1x1 conv then maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        pad = (dilation * (kernel_size - 1) + (stride - 1)) // 2
        depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                              dilation=dilation, padding=pad, groups=in_channels, bias=False)
        norm = norm_layer(in_channels)
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        super().__init__(depthwise, norm, pointwise, nn.ReLU6())
In [ ]:
class ResidualBlock(nn.Module):
    """Basic residual block: conv3x3-BN-ReLU, conv3x3-BN, add shortcut, ReLU.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (and of the default projection).
        downsample: optional module applied to the input to build the shortcut.
            When omitted and the shape changes (``in_channels != out_channels`` or
            ``stride != 1``), a 1x1 conv + BatchNorm projection is created. When
            omitted and the shape is unchanged, the identity shortcut is used.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Build the default projection only when the caller did not pass one.
        # Previously a caller-supplied `downsample` was silently discarded whenever
        # the shape changed.
        self.downsample = downsample
        if downsample is None and (in_channels != out_channels or stride != 1):
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out
In [ ]:
class Mlp(nn.Module):
    """Pointwise MLP built from 1x1 convolutions: fc1 -> act -> dropout -> fc2 -> dropout.

    Args:
        in_features: input channels.
        hidden_features: hidden channels (defaults to ``in_features``).
        out_features: output channels (defaults to ``in_features``).
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability used after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)
        # Not in-place: in-place dropout overwrites the activation's output. Activations
        # that save that output for backward (e.g. nn.ReLU, nn.Sigmoid) then fail with
        # an autograd "modified by an inplace operation" RuntimeError.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
In [ ]:
class MultiHeadAttentionBlock(nn.Module):
    """Self-attention over all spatial positions of a feature map.

    The ``(B, C, H, W)`` input becomes ``H*W`` tokens of width ``C``, so
    ``embed_dim`` must equal ``C``. The tokens go through ``nn.MultiheadAttention``
    (sequence-first layout), then LayerNorm and dropout, and are reshaped back to
    ``(B, C, H, W)``. There is no residual connection: the attention output replaces
    the input.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttentionBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, C, H*W) -> (H*W, B, C): one token per pixel.
        tokens = x.view(batch, channels, height * width).permute(2, 0, 1)
        attended = self.attention(tokens, tokens, tokens)[0]
        attended = self.dropout(self.norm(attended))
        # (H*W, B, C) -> (B, C, H*W) -> (B, C, H, W)
        return attended.permute(1, 2, 0).view(batch, channels, height, width)
In [8]:
# Scratch: print the documentation of torch.mean (its dim/keepdim overload is what
# SpatialAttentionBlock uses for the channel-wise average).
help(torch.mean)
Help on built-in function mean in module torch: mean(...) mean(input, *, dtype=None) -> Tensor Returns the mean value of all elements in the :attr:`input` tensor. Args: input (Tensor): the input tensor. Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. Example:: >>> a = torch.randn(1, 3) >>> a tensor([[ 0.2294, -0.5481, 1.3288]]) >>> torch.mean(a) tensor(0.3367) .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor :noindex: Returns the mean value of each row of the :attr:`input` tensor in the given dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, reduce over all of them. If :attr:`keepdim` is ``True``, the output tensor is of the same size as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the output tensor having 1 (or ``len(dim)``) fewer dimension(s). Args: input (Tensor): the input tensor. dim (int or tuple of ints): the dimension or dimensions to reduce. keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. out (Tensor, optional): the output tensor. .. seealso:: :func:`torch.nanmean` computes the mean value of `non-NaN` elements. 
Example:: >>> a = torch.randn(4, 4) >>> a tensor([[-0.3841, 0.6320, 0.4254, -0.7384], [-0.9644, 1.0131, -0.6549, -1.4279], [-0.2951, -1.3350, -0.7694, 0.5600], [ 1.0842, -0.9580, 0.3623, 0.2343]]) >>> torch.mean(a, 1) tensor([-0.0163, -0.5085, -0.4599, 0.1807]) >>> torch.mean(a, 1, True) tensor([[-0.0163], [-0.5085], [-0.4599], [ 0.1807]])
In [ ]:
class SpatialAttentionBlock(nn.Module):
    """CBAM-style spatial attention: scale every pixel by a learned gate in (0, 1).

    The channel-wise mean and max maps are concatenated into 2 channels, passed
    through a bias-free 7x7 conv and a sigmoid. The resulting ``(B, 1, H, W)`` gate
    is broadcast over all channels of the input.
    """

    def __init__(self):
        super(SpatialAttentionBlock, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)

    def forward(self, x):
        # Channel statistics, each (B, 1, H, W), stacked into (B, 2, H, W).
        pooled = torch.cat(
            [torch.mean(x, dim=1, keepdim=True), torch.max(x, dim=1, keepdim=True)[0]],
            dim=1,
        )
        gate = torch.sigmoid(self.conv(pooled))  # (B, 1, H, W)
        return x * gate  # broadcast over channels -> (B, C, H, W)
In [ ]:
class DecoderAttentionBlock(nn.Module):
    """Channel gate followed by ``SpatialAttentionBlock``, applied to decoder features.

    Channel gate: the global average- and max-pooled descriptors each pass through
    the shared 1x1 conv ``conv1``. They are summed, mapped back to ``in_channels`` by
    ``conv2`` and squashed with a sigmoid, and the input is scaled per channel. The
    result then goes through spatial attention.

    Args:
        in_channels: channels of the input (and of the output).
        reduction: bottleneck divisor. The default ``2`` reproduces the original
            ``in_channels // 2``. The hidden width is clamped to at least 1, so small
            inputs (e.g. ``in_channels == 1``) no longer get a zero-width bottleneck.

    NOTE(review): unlike the usual CBAM channel MLP, there is no nonlinearity between
    ``conv1`` and ``conv2``, so the two 1x1 convs compose to a single affine map.
    Left unchanged to keep trained behaviour; confirm whether a ReLU was intended.
    """

    def __init__(self, in_channels, reduction=2):
        super(DecoderAttentionBlock, self).__init__()
        hidden = max(1, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1)
        self.conv2 = nn.Conv2d(hidden, in_channels, kernel_size=1)
        self.spatial_attention = SpatialAttentionBlock()

    def forward(self, x):
        # Channel attention from pooled (B, C, 1, 1) descriptors.
        avg_out = self.conv1(F.adaptive_avg_pool2d(x, 1))
        max_out = self.conv1(F.adaptive_max_pool2d(x, 1))
        channel_gate = torch.sigmoid(self.conv2(avg_out + max_out))
        # Spatial attention on the channel-reweighted features.
        return self.spatial_attention(x * channel_gate)
In [ ]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate: rescale each channel by a learned weight in (0, 1).

    Global average pooling -> 1x1 conv to ``reduced_dim`` -> ReLU -> 1x1 conv back to
    ``in_channels`` -> sigmoid. The resulting ``(B, C, 1, 1)`` weights multiply the input.
    """

    def __init__(self, in_channels, reduced_dim):
        super(SEBlock, self).__init__()
        squeeze = nn.AdaptiveAvgPool2d(1)  # global average pooling
        reduce = nn.Conv2d(in_channels, reduced_dim, kernel_size=1)
        relu = nn.ReLU()
        expand = nn.Conv2d(reduced_dim, in_channels, kernel_size=1)
        gate = nn.Sigmoid()  # squashes the channel weights into (0, 1)
        self.se = nn.Sequential(squeeze, reduce, relu, expand, gate)

    def forward(self, x):
        weights = self.se(x)
        return x * weights
In [ ]:
# Masked autoencoder with SE / residual / attention blocks
class MaskedAutoencoder(nn.Module):
    """Convolutional masked autoencoder for single-channel inputs.

    The encoder has three stride-2 stages (1 -> 32 -> 64 -> 128 channels),
    interleaved with SE, residual and multi-head self-attention blocks. The decoder
    has three stride-2 transposed convolutions (128 -> 32 -> 16 -> 1) with
    ``DecoderAttentionBlock``s and a final sigmoid, so outputs lie in (0, 1). Each
    transposed conv doubles the spatial size, so inputs whose sides are divisible
    by 8 come back at their original size.
    """

    def __init__(self):
        super(MaskedAutoencoder, self).__init__()
        encoder_layers = [
            Conv(1, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            SEBlock(32, 32),
            ConvBNReLU(32, 64, kernel_size=3, stride=2),
            ResidualBlock(64, 64),
            SeparableBNReLU(64, 128, kernel_size=3, stride=2),
            MultiHeadAttentionBlock(embed_dim=128, num_heads=4),
            SEBlock(128, 128),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        # NOTE(review): `self.mlp` is registered (so it appears in state_dict and
        # model.parameters()) but forward() never calls it. Kept to preserve
        # checkpoints and initialisation order; confirm whether it was meant to sit
        # between encoder and decoder.
        self.mlp = Mlp(in_features=128, hidden_features=256, out_features=128, act_layer=nn.ReLU6, drop=0.1)
        decoder_layers = [
            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            DecoderAttentionBlock(32),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            DecoderAttentionBlock(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Instantiate the model, loss and optimizer
model = MaskedAutoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
In [ ]:
# Baseline masked autoencoder (plain conv encoder/decoder)
class MaskedAutoencoderBase(nn.Module):
    """Plain convolutional masked autoencoder used as a baseline.

    The encoder uses three stride-2 3x3 convs with ReLU (1 -> 32 -> 64 -> 128
    channels), followed by an SE gate. The decoder uses three stride-2 transposed
    convs (128 -> 32 -> 16 -> 1). A final sigmoid keeps outputs in (0, 1), matching
    the min-max scaled inputs.
    """

    def __init__(self):
        super(MaskedAutoencoderBase, self).__init__()
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            SEBlock(128, 128),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        decoder_layers = [
            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Instantiate the model, loss and optimizer
model = MaskedAutoencoderBase()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
In [ ]:
# Training function
def train_epoch(model, device, data_loader, criterion, optimizer, mask_bank=None):
    """Run one training epoch and return the mean per-batch loss.

    Each batch is masked via ``mask_data``. The model reconstructs the masked input
    and is optimised against the unmasked batch.

    Args:
        model: network mapping a masked batch to a reconstruction.
        device: device the batch (and the sampled masks) are placed on.
        data_loader: yields ``(data, target)`` pairs; only ``data`` is used.
        criterion: loss between reconstruction and original batch.
        optimizer: optimizer over ``model``'s parameters.
        mask_bank: array of candidate masks; defaults to the module-level ``masks``.

    Returns:
        Mean of the per-batch ``loss.item()`` values.

    Raises:
        ValueError: if ``data_loader`` yields no batches (previously an
            ``UnboundLocalError`` on ``batch_idx``).
    """
    if mask_bank is None:
        mask_bank = masks
    model.train()
    running_loss = 0.0
    num_batches = 0
    for data, _ in data_loader:
        # Same as evaluate(): mask_data() puts the mask on `device`, so the batch must be there too.
        data = data.to(device)
        masked_data, _mask = mask_data(data, device, mask_bank)
        optimizer.zero_grad()
        reconstructed = model(masked_data)
        loss = criterion(reconstructed, data)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")
    return running_loss / num_batches
In [ ]:
# Evaluation function
def evaluate(model, device, data_loader, criterion, mask_bank=None, visualize_batch=8):
    """Return the mean masked-reconstruction loss over ``data_loader`` (no gradients).

    Args:
        model: network mapping a masked batch to a reconstruction.
        device: device the batches and masks are moved to.
        data_loader: yields ``(data, target)`` pairs; only ``data`` is used.
        criterion: loss between reconstruction and original batch.
        mask_bank: array of candidate masks; defaults to the module-level ``masks``.
        visualize_batch: index of the batch from which one random sample is plotted
            with ``visualize_feature``. ``None`` disables plotting. The default ``8``
            matches the previous behaviour; nothing is plotted if the loader has
            fewer batches.

    Returns:
        Mean of the per-batch ``loss.item()`` values.

    Raises:
        ValueError: if ``data_loader`` yields no batches (previously an
            ``UnboundLocalError`` on ``batch_idx``).
    """
    if mask_bank is None:
        mask_bank = masks
    model.eval()
    running_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(data_loader):
            data = data.to(device)
            masked_data, _mask = mask_data(data, device, mask_bank)
            reconstructed = model(masked_data)
            if visualize_batch is not None and batch_idx == visualize_batch:
                rand_ind = np.random.randint(0, len(data))
                visualize_feature(data[rand_ind], masked_data[rand_ind], reconstructed[rand_ind], title='NO_2')
            loss = criterion(reconstructed, data)
            running_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")
    return running_loss / num_batches
In [ ]:
# Test function
def test(model, device, data_loader):
    """Run the model on masked batches of ``data_loader`` without producing a metric.

    For each batch this masks the data, locates the zeroed pixels of the first
    sample's channel 0, and reconstructs. The per-sample values are computed but
    neither stored nor returned, so the function returns ``None``.

    NOTE(review): this looks like an unfinished stub. The original ``return``
    position (inside vs. after the loop) cannot be recovered from this export;
    here every batch is processed.
    """
    model.eval()
    with torch.no_grad():
        for data, _ in data_loader:
            data = data.to(device)
            masked_data, mask = mask_data(data, device, masks)
            # torch.nonzero works on any device; np.argwhere raised TypeError for CUDA
            # tensors because they cannot be converted to NumPy.
            masked_ind = torch.nonzero(masked_data[0][0] == 0)
            reconstructed = model(masked_data)
            recon_no2 = reconstructed[0][0]
            ori_no2 = data[0][0]
    return None
In [ ]:
model = model.to(device)
num_epochs = 100
train_losses, val_losses = [], []
for epoch in range(num_epochs):
    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)
    train_losses.append(train_loss)
    val_loss = evaluate(model, device, val_loader, criterion)
    val_losses.append(val_loss)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}')

# Final evaluation on the test split
test_loss = evaluate(model, device, test_loader, criterion)
print(f'Test Loss: {test_loss}')
In [ ]:
# Learning curves: mean train / validation loss per epoch (epoch index on the x-axis).
tr_ind = [i for i in range(len(train_losses))]
val_ind = [i for i in range(len(val_losses))]
plt.plot(tr_ind, train_losses, label='train_loss')
plt.plot(val_ind, val_losses, label='val_loss')
plt.legend(loc='best')
In [ ]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, mean_absolute_error
In [ ]:
# Collect de-normalised ground truth and predictions for the masked-out (inpainted)
# pixels of the test set. The lists hold plain Python floats.
real_list = list()
pred_list = list()
# Evaluate on CPU through a local device name, so the global `device` used by the
# training cells is not silently switched to 'cpu'.
eval_device = torch.device('cpu')
model = model.to(eval_device)
# Put Dropout/BatchNorm into inference mode regardless of which cell ran last.
model.eval()
# Per-pixel min / max in channel-first layout, converted to tensors once (not per batch).
tr_maxs = np.transpose(maxs, (2, 0, 1))
tr_mins = np.transpose(mins, (2, 0, 1))
span = torch.from_numpy(tr_maxs - tr_mins)
offset = torch.from_numpy(tr_mins)
with torch.no_grad():
    for batch_idx, (data, _) in enumerate(test_loader):
        data = data.to(eval_device)
        masked_data, mask = mask_data(data, eval_device, masks)
        mask_rev = (mask == 0) * 1  # invert the mask: 1 marks the pixels the model had to fill in
        reconstructed = model(masked_data)
        rev_data = data * span + offset
        rev_recon = reconstructed * span + offset
        # Score only the inpainted pixels.
        hole = mask_rev == 1
        data_label = torch.squeeze(rev_data, dim=1)[hole]
        recon_no2 = torch.squeeze(rev_recon, dim=1)[hole]
        # .tolist() stores floats instead of one 0-d tensor object per pixel.
        real_list.extend(data_label.tolist())
        pred_list.extend(recon_no2.tolist())
In [ ]:
# Absolute error of every inpainted pixel.
np.abs(np.subtract(real_list, pred_list))
In [84]:
# Plot one sample from the last evaluated test batch. That batch can hold fewer than
# six samples, so fall back to its last sample instead of raising IndexError.
sample_idx = min(5, len(data) - 1)
visualize_feature(data[sample_idx], masked_data[sample_idx], reconstructed[sample_idx], 'NO2')
In [ ]:
# Disabled (commented-out) version of the test-set evaluation loop that fills
# real_list / pred_list; nothing here executes.
# real_list = list()
# pred_list = list()
# with torch.no_grad():
#     device = 'cpu'
#     for batch_idx, (data, _) in enumerate(test_loader):
#         model = model.to(device)
#         data = data.to(device)
#         masked_data, mask = mask_data(data, device, masks)
#         mask_rev = (mask==0) * 1  # invert the mask to get the inpainted region
#         reconstructed = model(masked_data)
#         tr_maxs = np.transpose(maxs, (2, 0, 1))
#         tr_mins = np.transpose(mins, (2, 0, 1))
#         rev_data = data * (tr_maxs - tr_mins) + tr_mins
#         rev_recon = reconstructed * (tr_maxs - tr_mins) + tr_mins
#         # todo: only the inpainted pixels should be evaluated here
#         data_label = torch.squeeze(rev_data, dim=1) * mask_rev
#         data_label = data_label[mask_rev==1]
#         recon_no2 = torch.squeeze(rev_recon, dim=1) * mask_rev
#         recon_no2 = recon_no2[mask_rev==1]
#         real_list.extend(data_label)
#         pred_list.extend(recon_no2)
In [20]:
# Scratch tensor shaped (B, C, H, W) = (1, 1, 4, 4) for checking the channel-mean behaviour.
a = torch.randn((1, 1, 4, 4))
In [21]:
# Channel-wise mean with the channel axis kept: (B, 1, H, W).
avg_out = a.mean(dim=1, keepdim=True)
In [22]:
# Shape of the channel mean.
avg_out.size()
Out[22]:
torch.Size([1, 1, 4, 4])
In [23]:
# The single (H, W) channel of the scratch input.
a[0, 0]
Out[23]:
tensor([[-0.0919, 1.9463, -0.6934, 0.1982], [ 0.1241, 0.5442, 0.4565, 0.3567], [ 0.8672, -0.8656, -0.4287, -0.4634], [ 1.8194, 0.3727, 1.1409, 0.6761]])
In [24]:
# The (H, W) slice of the channel mean; with one input channel it equals a[0, 0].
avg_out[0, 0]
Out[24]:
tensor([[-0.0919, 1.9463, -0.6934, 0.1982], [ 0.1241, 0.5442, 0.4565, 0.3567], [ 0.8672, -0.8656, -0.4287, -0.4634], [ 1.8194, 0.3727, 1.1409, 0.6761]])
In [25]:
import numpy as np
In [26]:
# Display the full (1, 1, 4, 4) channel mean; with one input channel it matches `a` exactly.
avg_out
Out[26]:
tensor([[[[-0.0919, 1.9463, -0.6934, 0.1982], [ 0.1241, 0.5442, 0.4565, 0.3567], [ 0.8672, -0.8656, -0.4287, -0.4634], [ 1.8194, 0.3727, 1.1409, 0.6761]]]])
In [27]:
# Display avg_out once more (unchanged since it was computed).
avg_out
Out[27]:
tensor([[[[-0.0919, 1.9463, -0.6934, 0.1982], [ 0.1241, 0.5442, 0.4565, 0.3567], [ 0.8672, -0.8656, -0.4287, -0.4634], [ 1.8194, 0.3727, 1.1409, 0.6761]]]])
In [ ]: