import torch
from torch import nn


class EMA(nn.Module):
    """Efficient Multi-scale Attention (EMA).

    Splits the channels into `factor` groups, gates each group with
    directional (height/width) attention, then fuses a 1x1 and a 3x3 branch
    via cross-spatial attention. The output shape equals the input shape.
    """

    def __init__(self, channels, factor=8):
        super().__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # average over width
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # average over height
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups,
                                 kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups,
                                 kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        # Fold the channel groups into the batch dimension: b*g, c//g, h, w.
        group_x = x.reshape(b * self.groups, -1, h, w)
        # Directional descriptors: pool over width and over height.
        x_h = self.pool_h(group_x)                      # b*g, c//g, h, 1
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)  # b*g, c//g, w, 1
        # Shared 1x1 conv over the concatenated descriptors, then split back.
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        # 1x1 branch: sigmoid-gate the input along h and w, then GroupNorm.
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        # 3x3 branch: local spatial context.
        x2 = self.conv3x3(group_x)
        # Cross-spatial attention: each branch's softmaxed channel descriptor
        # attends over the other branch's flattened spatial map.
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        # Gate the grouped input and restore the original layout.
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)
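

# Quick sanity check (not part of the original snippet): the grouped reshape
# used in forward() is lossless -- folding channels into `factor` batch groups
# and merging back reproduces the tensor exactly. The helper name and default
# shapes below are illustrative assumptions.
def _check_group_reshape(b=2, c=64, h=8, w=8, groups=8):
    x = torch.rand(b, c, h, w)
    grouped = x.reshape(b * groups, c // groups, h, w)  # b*g, c//g, h, w
    restored = grouped.reshape(b, c, h, w)              # back to b, c, h, w
    assert torch.equal(x, restored)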
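

# A minimal integration sketch (not from the original snippet): since EMA is
# shape-preserving, it can be dropped in after any convolutional stage. The
# ConvEMABlock name and the Conv-BN-SiLU layout are illustrative assumptions,
# not a prescribed usage.
class ConvEMABlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU()
        self.attn = EMA(out_channels)  # shape-preserving attention

    def forward(self, x):
        return self.attn(self.act(self.bn(self.conv(x))))
# Example: ConvEMABlock(3, 64)(torch.rand(1, 3, 32, 32)) -> shape (1, 64, 32, 32)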


# Input: (N, C, H, W) -> Output: (N, C, H, W)
if __name__ == '__main__':
    # Run on the GPU when available; fall back to CPU otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block = EMA(64).to(device)
    x = torch.rand(1, 64, 64, 64, device=device)
    output = block(x)
    print(x.size(), output.size())