129 lines
5.0 KiB
Python
129 lines
5.0 KiB
Python
# FcaNet: Frequency Channel Attention Networks (ICCV 2021)
|
|
import math
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
# Frequency index tables on a 7x7 DCT grid, one per selection strategy.
# Each value is (indices along height, indices along width), 32 entries each.
# The 'top'/'bot' orderings come from the FcaNet paper (presumably the
# best/worst-performing single frequencies -- confirm against the paper);
# 'low' appears to enumerate low frequencies first.
_FREQ_INDEX_TABLES = {
    'top': (
        (0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0,
         0, 3, 2, 4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2, 6, 1),
        (0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6,
         3, 5, 2, 6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0, 5, 3),
    ),
    'low': (
        (0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0,
         1, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4),
        (0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4,
         3, 1, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3),
    ),
    'bot': (
        (6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4,
         6, 2, 5, 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5, 3, 6),
        (6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2,
         2, 5, 1, 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3, 3, 3),
    ),
}

# Every accepted selector: a strategy name followed by a power-of-two count.
_FREQ_METHODS = frozenset(
    strategy + str(count)
    for strategy in _FREQ_INDEX_TABLES
    for count in (1, 2, 4, 8, 16, 32)
)


def get_freq_indices(method):
    """Return the DCT frequency indices selected by ``method``.

    Args:
        method: selector of the form ``'<strategy><k>'`` where strategy is
            ``'top'``, ``'bot'`` or ``'low'`` and ``k`` is one of
            1, 2, 4, 8, 16, 32 (e.g. ``'top16'``).

    Returns:
        ``(mapper_x, mapper_y)``: two new lists holding the first ``k``
        indices (on a 7x7 grid) along the height and width axes.

    Raises:
        ValueError: if ``method`` is not one of the supported selectors.
            This is an explicit check rather than an ``assert`` so it still
            fires when Python runs with ``-O``.
    """
    if not isinstance(method, str) or method not in _FREQ_METHODS:
        raise ValueError(
            f'unsupported freq_sel_method {method!r}; '
            f'expected one of {sorted(_FREQ_METHODS)}')

    strategy, num_freq = method[:3], int(method[3:])
    all_x, all_y = _FREQ_INDEX_TABLES[strategy]
    # Return fresh lists so callers can never mutate the shared tables.
    return list(all_x[:num_freq]), list(all_y[:num_freq])
|
|
|
|
|
|
class MultiSpectralAttentionLayer(nn.Module):
    """Frequency channel attention (FcaNet) block.

    Each channel is reduced to a scalar by a ``MultiSpectralDCTLayer`` (a
    bank of fixed DCT filters). A bottleneck MLP ending in a sigmoid turns
    those scalars into per-channel gates, and the gates rescale the input.

    Args:
        channel: number of input channels.
        dct_h, dct_w: spatial size of the DCT filters. Inputs of a different
            spatial size are first adaptive-average-pooled to this size.
        reduction: bottleneck ratio; the MLP's hidden width is
            ``channel // reduction``.
        freq_sel_method: frequency selector passed to ``get_freq_indices``
            (e.g. ``'top16'``).
    """

    def __init__(self, channel, dct_h, dct_w, reduction=16, freq_sel_method='top16'):
        super().__init__()
        self.reduction = reduction
        self.dct_h = dct_h
        self.dct_w = dct_w

        freq_x, freq_y = get_freq_indices(freq_sel_method)
        self.num_split = len(freq_x)

        # The frequency tables are expressed on a 7x7 grid; scale them so the
        # same relative frequency is used for other DCT sizes, e.g. (1, 1) on
        # 7x7 becomes (2, 2) on 14x14. Note: with integer division every index
        # becomes 0 when dct_h or dct_w is smaller than 7.
        step_h = dct_h // 7
        step_w = dct_w // 7
        freq_x = [idx * step_h for idx in freq_x]
        freq_y = [idx * step_w for idx in freq_y]

        self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, freq_x, freq_y, channel)

        hidden = channel // reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((self.dct_h, self.dct_w))

    def forward(self, x):
        """Return ``x`` (4-D, presumably N, C, H, W) scaled by its channel gates."""
        batch, chans, height, width = x.shape
        # Inputs whose spatial size differs from the DCT size (e.g. in
        # detection/segmentation backbones) are pooled down first; otherwise
        # x is used directly.
        matches_dct = height == self.dct_h and width == self.dct_w
        squeezed = x if matches_dct else self.avgpool(x)

        gate = self.fc(self.dct_layer(squeezed)).view(batch, chans, 1, 1)
        return x * gate.expand_as(x)
|
|
|
|
|
|
class MultiSpectralDCTLayer(nn.Module):
    """Project channel groups onto fixed 2-D DCT basis functions.

    The ``channel`` channels are split into ``len(mapper_x)`` equal,
    consecutive groups; group ``i`` is paired with the DCT basis of frequency
    ``(mapper_x[i], mapper_y[i])``. ``forward`` multiplies the input by these
    filters and sums over the last two axes, giving one value per
    (sample, channel).

    Args:
        height, width: spatial size of the filters; the input's last two
            dimensions must broadcast against them.
        mapper_x, mapper_y: equal-length sequences of frequency indices along
            the height and width axes.
        channel: number of channels; must be divisible by ``len(mapper_x)``.

    Raises:
        ValueError: if the mappers differ in length or are empty, or if
            ``channel`` is not divisible by their length.
    """

    def __init__(self, height, width, mapper_x, mapper_y, channel):
        super().__init__()

        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if len(mapper_x) != len(mapper_y):
            raise ValueError(
                f'mapper_x and mapper_y must have the same length, '
                f'got {len(mapper_x)} and {len(mapper_y)}')
        if len(mapper_x) == 0:
            raise ValueError('at least one DCT frequency is required')
        if channel % len(mapper_x) != 0:
            raise ValueError(
                f'channel ({channel}) must be divisible by the number of '
                f'frequencies ({len(mapper_x)})')

        self.num_freq = len(mapper_x)

        # Fixed (non-learnable) DCT filters. Registering them as a buffer makes
        # .to()/.cuda()/.half()/.double() move and cast them with the module;
        # a plain tensor attribute would stay on the CPU as float32 and break
        # forward() for inputs on another device or dtype. persistent=False
        # keeps state_dict() unchanged, so existing checkpoints still load.
        self.register_buffer(
            'weight',
            self.get_dct_filter(height, width, mapper_x, mapper_y, channel),
            persistent=False)

    def forward(self, x):
        """Return ``sum over the last two axes of x * weight``, shape (N, C)."""
        assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + \
            str(len(x.shape))
        return torch.sum(x * self.weight, dim=[2, 3])

    def build_filter(self, pos, freq, POS):
        """Return the orthonormal DCT-II basis value for frequency ``freq``
        at sample ``pos`` of a length-``POS`` signal."""
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
        if freq == 0:
            return result
        return result * math.sqrt(2)

    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
        """Build the ``(channel, tile_size_x, tile_size_y)`` filter bank.

        Group ``i`` of ``channel // len(mapper_x)`` consecutive channels holds
        ``build_filter(t_x, u_x, tile_size_x) * build_filter(t_y, v_y, tile_size_y)``
        with ``(u_x, v_y) = (mapper_x[i], mapper_y[i])``.
        """
        dct_filter = torch.zeros((channel, tile_size_x, tile_size_y))
        c_part = channel // len(mapper_x)

        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
            # The 2-D basis is separable: evaluate each 1-D factor once and
            # take the outer product (in Python floats, exactly as before),
            # instead of one tensor slice assignment per (t_x, t_y) pair.
            basis_x = [self.build_filter(t_x, u_x, tile_size_x)
                       for t_x in range(tile_size_x)]
            basis_y = [self.build_filter(t_y, v_y, tile_size_y)
                       for t_y in range(tile_size_y)]
            basis_2d = torch.tensor(
                [[bx * by for by in basis_y] for bx in basis_x],
                dtype=dct_filter.dtype,
            ).reshape(tile_size_x, tile_size_y)
            dct_filter[i * c_part: (i + 1) * c_part] = basis_2d

        return dct_filter
|
|
|
|
|
|
# Smoke test: the input is N C H W and the output has the same N C H W shape.
if __name__ == '__main__':
    block = MultiSpectralAttentionLayer(64, 16, 16)
    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.rand(1, 64, 64, 64)
    output = block(x)
    print(x.size(), output.size())
|