''' Description: Date: 2023-07-21 14:36:27 LastEditTime: 2023-07-27 18:41:47 FilePath: /chengdongzhou/ScConv.py ''' import torch import torch.nn.functional as F import torch.nn as nn class GroupBatchnorm2d(nn.Module): def __init__(self, c_num: int, group_num: int = 16, eps: float = 1e-10 ): super(GroupBatchnorm2d, self).__init__() assert c_num >= group_num self.group_num = group_num self.weight = nn.Parameter(torch.randn(c_num, 1, 1)) self.bias = nn.Parameter(torch.zeros(c_num, 1, 1)) self.eps = eps def forward(self, x): N, C, H, W = x.size() x = x.view(N, self.group_num, -1) mean = x.mean(dim=2, keepdim=True) std = x.std(dim=2, keepdim=True) x = (x - mean) / (std + self.eps) x = x.view(N, C, H, W) return x * self.weight + self.bias class SRU(nn.Module): def __init__(self, oup_channels: int, group_num: int = 16, gate_treshold: float = 0.5, torch_gn: bool = False ): super().__init__() self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d( c_num=oup_channels, group_num=group_num) self.gate_treshold = gate_treshold self.sigomid = nn.Sigmoid() def forward(self, x): gn_x = self.gn(x) w_gamma = self.gn.weight / torch.sum(self.gn.weight) w_gamma = w_gamma.view(1, -1, 1, 1) reweigts = self.sigomid(gn_x * w_gamma) # Gate info_mask = reweigts >= self.gate_treshold noninfo_mask = reweigts < self.gate_treshold x_1 = info_mask * gn_x x_2 = noninfo_mask * gn_x x = self.reconstruct(x_1, x_2) return x def reconstruct(self, x_1, x_2): x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1) x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1) return torch.cat([x_11 + x_22, x_12 + x_21], dim=1) class CRU(nn.Module): ''' alpha: 0