import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from timm.models.efficientnet_blocks import SqueezeExcite as SE
from timm.models.layers import DropPath
from timm.models.layers.activations import *

inplace = True


# ========== For Common ==========
class LayerNorm2d(nn.Module):

    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    def forward(self, x):
        x = rearrange(x, 'b c h w -> b h w c').contiguous()
        x = self.norm(x)
        x = rearrange(x, 'b h w c -> b c h w').contiguous()
        return x


def get_norm(norm_layer='in_1d'):
    eps = 1e-6
    norm_dict = {
        'none': nn.Identity,
        'in_1d': partial(nn.InstanceNorm1d, eps=eps),
        'in_2d': partial(nn.InstanceNorm2d, eps=eps),
        'in_3d': partial(nn.InstanceNorm3d, eps=eps),
        'bn_1d': partial(nn.BatchNorm1d, eps=eps),
        'bn_2d': partial(nn.BatchNorm2d, eps=eps),
        # 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
        'bn_3d': partial(nn.BatchNorm3d, eps=eps),
        'gn': partial(nn.GroupNorm, eps=eps),
        'ln_1d': partial(nn.LayerNorm, eps=eps),
        'ln_2d': partial(LayerNorm2d, eps=eps),
    }
    return norm_dict[norm_layer]


def get_act(act_layer='relu'):
    act_dict = {
        'none': nn.Identity,
        'sigmoid': Sigmoid,
        'swish': Swish,
        'mish': Mish,
        'hsigmoid': HardSigmoid,
        'hswish': HardSwish,
        'hmish': HardMish,
        'tanh': Tanh,
        'relu': nn.ReLU,
        'relu6': nn.ReLU6,
        'prelu': PReLU,
        'gelu': GELU,
        'silu': nn.SiLU
    }
    return act_dict[act_layer]


class LayerScale(nn.Module):

    def __init__(self, dim, init_values=1e-5, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(1, 1, dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class LayerScale2D(nn.Module):

    def __init__(self, dim, init_values=1e-5, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(1, dim, 1, 1))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class ConvNormAct(nn.Module):

    def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
                 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
        super(ConvNormAct, self).__init__()
        self.has_skip = skip and dim_in == dim_out
        padding = math.ceil((kernel_size - stride) / 2)
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
        self.norm = get_norm(norm_layer)(dim_out)
        self.act = get_act(act_layer)(inplace=inplace)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
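# Usage sketch (illustrative values, not part of the original file): ConvNormAct fuses
# Conv2d + norm + activation and adds a DropPath residual when `skip=True` and the
# channel count is unchanged, e.g.
#   blk = ConvNormAct(64, 64, kernel_size=3, skip=True, norm_layer='bn_2d', act_layer='silu')
#   y = blk(torch.randn(2, 64, 32, 32))  # -> torch.Size([2, 64, 32, 32])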
# ========== Multi-Scale Populations, for down-sampling and inductive bias ==========
class MSPatchEmb(nn.Module):

    def __init__(self, dim_in, emb_dim, kernel_size=2, c_group=-1, stride=1, dilations=[1, 2, 3],
                 norm_layer='bn_2d', act_layer='silu'):
        super().__init__()
        self.dilation_num = len(dilations)
        assert dim_in % c_group == 0
        c_group = math.gcd(dim_in, emb_dim) if c_group == -1 else c_group
        self.convs = nn.ModuleList()
        for i in range(len(dilations)):
            padding = math.ceil(((kernel_size - 1) * dilations[i] + 1 - stride) / 2)
            self.convs.append(nn.Sequential(
                nn.Conv2d(dim_in, emb_dim, kernel_size, stride, padding, dilations[i], groups=c_group),
                get_norm(norm_layer)(emb_dim),
                get_act(act_layer)(emb_dim)))

    def forward(self, x):
        if self.dilation_num == 1:
            x = self.convs[0](x)
        else:
            x = torch.cat([self.convs[i](x).unsqueeze(dim=-1) for i in range(self.dilation_num)], dim=-1)
            x = reduce(x, 'b c h w n -> b c h w', 'mean').contiguous()
        return x
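# Usage sketch (illustrative values, not part of the original file): MSPatchEmb as a
# stride-2 stem that averages three dilated-convolution branches with matching output sizes, e.g.
#   stem = MSPatchEmb(3, 48, kernel_size=3, stride=2, dilations=[1, 2, 3])
#   y = stem(torch.randn(1, 3, 224, 224))  # -> torch.Size([1, 48, 112, 112])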
class iRMB(nn.Module):

    def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',
                 act_layer='relu', v_proj=True, dw_ks=3, stride=1, dilation=1, se_ratio=0.0, dim_head=64,
                 window_size=7, attn_s=True, qkv_bias=False, attn_drop=0., drop=0., drop_path=0.,
                 v_group=False, attn_pre=False):
        super().__init__()
        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        dim_mid = int(dim_in * exp_ratio)
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
        self.attn_s = attn_s
        if self.attn_s:
            assert dim_in % dim_head == 0, 'dim should be divisible by num_heads'
            self.dim_head = dim_head
            self.window_size = window_size
            self.num_head = dim_in // dim_head
            self.scale = self.dim_head ** -0.5
            self.attn_pre = attn_pre
            self.qk = ConvNormAct(dim_in, int(dim_in * 2), kernel_size=1, bias=qkv_bias, norm_layer='none',
                                  act_layer='none')
            self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1, groups=self.num_head if v_group else 1,
                                 bias=qkv_bias, norm_layer='none', act_layer=act_layer, inplace=inplace)
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            if v_proj:
                self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1, bias=qkv_bias, norm_layer='none',
                                     act_layer=act_layer, inplace=inplace)
            else:
                self.v = nn.Identity()
        self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation,
                                      groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace)
        self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()

        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none',
                                inplace=inplace)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        B, C, H, W = x.shape
        if self.attn_s:
            # padding: pad H and W up to a multiple of the window size
            if self.window_size <= 0:
                window_size_W, window_size_H = W, H
            else:
                window_size_W, window_size_H = self.window_size, self.window_size
            pad_l, pad_t = 0, 0
            pad_r = (window_size_W - W % window_size_W) % window_size_W
            pad_b = (window_size_H - H % window_size_H) % window_size_H
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b, 0, 0,))
            n1, n2 = (H + pad_b) // window_size_H, (W + pad_r) // window_size_W
            x = rearrange(x, 'b c (h1 n1) (w1 n2) -> (b n1 n2) c h1 w1', n1=n1, n2=n2).contiguous()
            # attention: window-wise multi-head self-attention over the partitioned feature map
            b, c, h, w = x.shape
            qk = self.qk(x)
            qk = rearrange(qk, 'b (qk heads dim_head) h w -> qk b heads (h w) dim_head', qk=2,
                           heads=self.num_head, dim_head=self.dim_head).contiguous()
            q, k = qk[0], qk[1]
            attn_spa = (q @ k.transpose(-2, -1)) * self.scale
            attn_spa = attn_spa.softmax(dim=-1)
            attn_spa = self.attn_drop(attn_spa)
            if self.attn_pre:
                # apply attention before the value projection
                x = rearrange(x, 'b (heads dim_head) h w -> b heads (h w) dim_head',
                              heads=self.num_head).contiguous()
                x_spa = attn_spa @ x
                x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head,
                                  h=h, w=w).contiguous()
                x_spa = self.v(x_spa)
            else:
                v = self.v(x)
                v = rearrange(v, 'b (heads dim_head) h w -> b heads (h w) dim_head',
                              heads=self.num_head).contiguous()
                x_spa = attn_spa @ v
                x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head,
                                  h=h, w=w).contiguous()
            # unpadding: merge windows back and crop to the original resolution
            x = rearrange(x_spa, '(b n1 n2) c h1 w1 -> b c (h1 n1) (w1 n2)', n1=n1, n2=n2).contiguous()
            if pad_r > 0 or pad_b > 0:
                x = x[:, :, :H, :W].contiguous()
        else:
            x = self.v(x)

        x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))

        x = self.proj_drop(x)
        x = self.proj(x)
        x = (shortcut + self.drop_path(x)) if self.has_skip else x
        return x


# Input: N C H W, Output: N C H W
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input = torch.randn(3, 64, 64, 64).to(device)
    model = iRMB(64, 64).to(device)
    output = model(input)
    print(output.shape)
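# Additional sanity checks (illustrative, assumed configurations; kept under the same
# guard so they only run when the file is executed directly):
if __name__ == '__main__':
    # Local-only variant: window attention disabled, depth-wise conv + SE path only.
    local_only = iRMB(64, 64, attn_s=False, exp_ratio=2.0, se_ratio=0.25)
    print(local_only(torch.randn(2, 64, 56, 56)).shape)  # expected: torch.Size([2, 64, 56, 56])
    # Strided variant for down-sampling; the residual shortcut is dropped because stride != 1.
    down = iRMB(64, 128, stride=2, dw_ks=5, attn_s=False)
    print(down(torch.randn(2, 64, 56, 56)).shape)  # expected: torch.Size([2, 128, 28, 28])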