import torch
import torch.nn as nn


class GAM_Attention(nn.Module):
    """Channel-then-spatial attention gate for 4-D feature maps.

    The module first rescales the input with a per-position channel gate
    (a two-layer MLP applied along the channel axis, followed by a
    sigmoid), then rescales the result with a spatial gate (two 7x7
    convolutions with batch norm, followed by a sigmoid).

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
        rate: Reduction factor; the hidden width of both sub-networks is
            ``int(in_channels / rate)``.
    """

    def __init__(self, in_channels, rate=4):
        super().__init__()
        # Bottleneck width shared by the channel MLP and the spatial convs.
        hidden_channels = int(in_channels / rate)

        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels),
        )

        # kernel_size=7 with padding=3 keeps the spatial size unchanged.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels),
        )

    def forward(self, x):
        """Apply channel attention followed by spatial attention.

        Args:
            x: Tensor of shape ``(N, C, H, W)`` with ``C == in_channels``.

        Returns:
            Tensor of shape ``(N, C, H, W)``: the input scaled by both gates.
        """
        b, c, h, w = x.shape

        # Move channels last and flatten the spatial dims so the Linear
        # layers act on the channel axis. ``reshape`` (not ``view``) is
        # used because the permuted tensor is not guaranteed to have
        # view-compatible strides, e.g. when ``x`` is a transposed or
        # sliced tensor; ``view`` would raise a RuntimeError there.
        x_permute = x.permute(0, 2, 3, 1).reshape(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).reshape(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()

        x = x * x_channel_att

        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att

        return out


if __name__ == '__main__':
    # Smoke test: the output shape must equal the input shape (N, C, H, W).
    x = torch.randn(1, 64, 20, 20)
    net = GAM_Attention(in_channels=x.shape[1])
    y = net(x)
    print(y.size())