import torch
import torch.nn as nn
class GAM_Attention(nn.Module):
    def __init__(self, in_channels, rate=4):
        super(GAM_Attention, self).__init__()

        # Channel attention: a two-layer MLP applied along the channel
        # dimension, squeezing by `rate` and expanding back to `in_channels`.
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )

        # Spatial attention: two 7x7 convolutions with a channel reduction in
        # between, each followed by batch normalization.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )

    def forward(self, x):
        b, c, h, w = x.shape

        # Channel attention: move channels last and flatten the spatial dims so
        # the Linear layers act on each position's channel vector.
        # (.reshape is used because the permuted tensor is non-contiguous,
        # so .view would raise an error here.)
        x_permute = x.permute(0, 2, 3, 1).reshape(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()

        x = x * x_channel_att

        # Spatial attention re-weights the channel-refined feature map.
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att

        return out
# Input: (N, C, H, W); output: (N, C, H, W)
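
# A minimal usage sketch (not part of the original snippet): placing the
# attention module after a convolution, as one might do inside a backbone
# block. The `ConvGAMBlock` name and layer sizes are illustrative assumptions.
class ConvGAMBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.attn = GAM_Attention(out_channels)

    def forward(self, x):
        # Convolve, normalize, activate, then re-weight the features with GAM.
        return self.attn(self.act(self.bn(self.conv(x))))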

if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c)
    y = net(x)
    print(y.size())