138 lines
4.6 KiB
Python
138 lines
4.6 KiB
Python
'''
Description:
Date: 2023-07-21 14:36:27
LastEditTime: 2023-07-27 18:41:47
FilePath: /chengdongzhou/ScConv.py
'''
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
|
|
|
|
class GroupBatchnorm2d(nn.Module):
    """Group-wise normalisation of a 4-D tensor with a per-channel affine.

    Every sample is flattened into ``group_num`` groups. Each group is shifted
    by its own mean and divided by ``std + eps``; the result is restored to
    the input shape, multiplied by ``weight`` and offset by ``bias``.

    Args:
        c_num: Number of channels; sizes ``weight`` and ``bias``.
        group_num: Number of groups each sample is flattened into.
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self, c_num: int,
                 group_num: int = 16,
                 eps: float = 1e-10
                 ):
        super(GroupBatchnorm2d, self).__init__()
        assert c_num >= group_num, (
            f"c_num ({c_num}) must be >= group_num ({group_num})"
        )
        self.group_num = group_num
        # Scale starts from a standard-normal draw, offset starts at zero;
        # both have shape (c_num, 1, 1) so they broadcast over H and W.
        self.weight = nn.Parameter(torch.randn(c_num, 1, 1))
        self.bias = nn.Parameter(torch.zeros(c_num, 1, 1))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` of shape (N, C, H, W) and return the same shape."""
        N, C, H, W = x.size()
        # reshape (not view) so non-contiguous inputs, e.g. channels_last or
        # transposed tensors, are accepted; contiguous inputs are unaffected.
        x = x.reshape(N, self.group_num, -1)
        mean = x.mean(dim=2, keepdim=True)
        # NOTE(review): torch.std defaults to the Bessel-corrected estimate,
        # so a group holding a single element would give NaN - confirm that
        # callers never feed such shapes.
        std = x.std(dim=2, keepdim=True)
        x = (x - mean) / (std + self.eps)
        x = x.reshape(N, C, H, W)
        return x * self.weight + self.bias
|
|
|
|
|
|
class SRU(nn.Module):
    """Gate a group-normalised tensor and cross-add its two gated halves.

    The input is normalised by ``self.gn``. A sigmoid of the normalised
    tensor, scaled by the normalisation weights divided by their sum, is
    compared against ``gate_treshold`` to build two complementary masks.
    The two masked tensors are then recombined by ``reconstruct``.

    Args:
        oup_channels: Number of channels of the input tensor.
        group_num: Number of groups used by the normalisation layer.
        gate_treshold: Sigmoid values at or above this count as informative.
        torch_gn: Use ``nn.GroupNorm`` when true, else ``GroupBatchnorm2d``.
    """

    def __init__(self,
                 oup_channels: int,
                 group_num: int = 16,
                 gate_treshold: float = 0.5,
                 torch_gn: bool = False
                 ):
        super().__init__()

        if torch_gn:
            self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num)
        else:
            self.gn = GroupBatchnorm2d(c_num=oup_channels, group_num=group_num)
        self.gate_treshold = gate_treshold
        self.sigomid = nn.Sigmoid()

    def forward(self, x):
        """Return the gated and recombined tensor, same shape as ``x``."""
        normed = self.gn(x)
        # Per-channel importance: each scale weight over the sum of all.
        gamma = self.gn.weight
        importance = (gamma / torch.sum(gamma)).view(1, -1, 1, 1)
        gate = self.sigomid(normed * importance)
        # Gate: complementary masks around the threshold.
        informative = gate >= self.gate_treshold
        redundant = gate < self.gate_treshold
        return self.reconstruct(informative * normed, redundant * normed)

    def reconstruct(self, x_1, x_2):
        """Halve both inputs along dim 1 and add opposite halves together."""
        first_a, first_b = torch.split(x_1, x_1.size(1) // 2, dim=1)
        second_a, second_b = torch.split(x_2, x_2.size(1) // 2, dim=1)
        crossed = [first_a + second_b, first_b + second_a]
        return torch.cat(crossed, dim=1)
|
|
|
|
|
|
class CRU(nn.Module):
    '''
    Split the channels into two branches, transform each and fuse them.

    alpha: 0<alpha<1; int(alpha * op_channel) channels go to the "up"
        branch and the remainder to the "low" branch.
    squeeze_radio: divisor applied to each branch's channel count by the
        1x1 squeeze convolutions.
    group_size: ``groups`` argument of the group-wise convolution (GWC).
    group_kernel_size: kernel size of GWC; its padding is half of it.
    '''

    def __init__(self,
                 op_channel: int,
                 alpha: float = 1 / 2,
                 squeeze_radio: int = 2,
                 group_size: int = 2,
                 group_kernel_size: int = 3,
                 ):
        super().__init__()
        up_channel = int(alpha * op_channel)
        low_channel = op_channel - up_channel
        self.up_channel = up_channel
        self.low_channel = low_channel

        up_squeezed = up_channel // squeeze_radio
        low_squeezed = low_channel // squeeze_radio

        # Layers are created in a fixed order so parameter order and seeded
        # initialisation stay stable.
        self.squeeze1 = nn.Conv2d(up_channel, up_squeezed, kernel_size=1, bias=False)
        self.squeeze2 = nn.Conv2d(low_channel, low_squeezed, kernel_size=1, bias=False)
        # up
        self.GWC = nn.Conv2d(
            up_squeezed,
            op_channel,
            kernel_size=group_kernel_size,
            stride=1,
            padding=group_kernel_size // 2,
            groups=group_size,
        )
        self.PWC1 = nn.Conv2d(up_squeezed, op_channel, kernel_size=1, bias=False)
        # low
        self.PWC2 = nn.Conv2d(
            low_squeezed,
            op_channel - low_squeezed,
            kernel_size=1,
            bias=False,
        )
        self.advavg = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return a tensor with ``op_channel`` channels computed from ``x``."""
        # Split
        sizes = [self.up_channel, self.low_channel]
        up_part, low_part = torch.split(x, sizes, dim=1)
        up_part = self.squeeze1(up_part)
        low_part = self.squeeze2(low_part)
        # Transform
        upper = self.GWC(up_part) + self.PWC1(up_part)
        # The squeezed low features are reused next to their 1x1 projection,
        # bringing the low branch back to op_channel channels.
        lower = torch.cat([self.PWC2(low_part), low_part], dim=1)
        # Fuse
        stacked = torch.cat([upper, lower], dim=1)
        # Channel weights: softmax over the globally average-pooled features.
        stacked = F.softmax(self.advavg(stacked), dim=1) * stacked
        head, tail = torch.split(stacked, stacked.size(1) // 2, dim=1)
        return head + tail
|
|
|
|
|
|
class ScConv(nn.Module):
    """Apply an ``SRU`` followed by a ``CRU`` to the input tensor.

    Args:
        op_channel: Channel count handed to both sub-modules.
        group_num: Forwarded to ``SRU``.
        gate_treshold: Forwarded to ``SRU``.
        alpha: Forwarded to ``CRU``.
        squeeze_radio: Forwarded to ``CRU``.
        group_size: Forwarded to ``CRU``.
        group_kernel_size: Forwarded to ``CRU``.
    """

    def __init__(self,
                 op_channel: int,
                 group_num: int = 4,
                 gate_treshold: float = 0.5,
                 alpha: float = 1 / 2,
                 squeeze_radio: int = 2,
                 group_size: int = 2,
                 group_kernel_size: int = 3,
                 ):
        super().__init__()
        sru_options = dict(group_num=group_num, gate_treshold=gate_treshold)
        cru_options = dict(
            alpha=alpha,
            squeeze_radio=squeeze_radio,
            group_size=group_size,
            group_kernel_size=group_kernel_size,
        )
        # SRU is built before CRU so the registration order is unchanged.
        self.SRU = SRU(op_channel, **sru_options)
        self.CRU = CRU(op_channel, **cru_options)

    def forward(self, x):
        """Run ``x`` through the SRU and then the CRU."""
        return self.CRU(self.SRU(x))
|
|
|
|
|
|
# Input: N C H W, output: N C H W
if __name__ == '__main__':
    # Smoke test: push one random batch through the module and print the
    # shape of the result.
    sample = torch.randn(1, 32, 16, 16)
    net = ScConv(32)
    result = net(sample)
    print(result.shape)
|