import math
import os

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

# Momentum passed to every BatchNorm2d layer created in this module.
bn_mom = 0.0003


class SeparableConv2d(nn.Module):
    """Depthwise conv -> BN -> pointwise (1x1) conv -> BN, with optional ReLUs.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the pointwise convolution.
        kernel_size, stride, padding, dilation, bias: forwarded to the
            depthwise ``nn.Conv2d`` (``bias`` is also used by the pointwise
            convolution).
        activate_first: if True, a single ReLU is applied to the input before
            the depthwise convolution; if False, a ReLU follows each
            BatchNorm instead.
        inplace: whether the leading ReLU (``relu0``) operates in place. The
            two trailing ReLUs are always in place.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False, activate_first=True,
                 inplace=True):
        super(SeparableConv2d, self).__init__()
        self.relu0 = nn.ReLU(inplace=inplace)
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                                   stride, padding, dilation,
                                   groups=in_channels, bias=bias)
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=bn_mom)
        self.relu1 = nn.ReLU(inplace=True)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1,
                                   bias=bias)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=bn_mom)
        self.relu2 = nn.ReLU(inplace=True)
        self.activate_first = activate_first

    def forward(self, x):
        """Apply the separable convolution to ``x`` and return the result."""
        if self.activate_first:
            x = self.relu0(x)
        x = self.depthwise(x)
        x = self.bn1(x)
        if not self.activate_first:
            x = self.relu1(x)
        x = self.pointwise(x)
        x = self.bn2(x)
        if not self.activate_first:
            x = self.relu2(x)
        return x


class Block(nn.Module):
    """Residual block made of three ``SeparableConv2d`` layers.

    Args:
        in_filters: channels of the block input.
        out_filters: channels of the block output.
        strides: stride of the third separable conv (and of the 1x1
            projection on the skip path, when one is needed).
        atrous: dilation for the three separable convs. ``None`` means
            ``[1, 1, 1]``; an int is repeated three times; otherwise a
            3-element sequence is expected.
        grow_first: if True the channel count changes in the first separable
            conv, otherwise in the second.
        activate_first: forwarded to every ``SeparableConv2d``.
        inplace: ``inplace`` flag for the leading ReLU of the third separable
            conv. Pass False when ``hook_layer`` must survive the forward
            pass unmodified.

    Attributes:
        hook_layer: output of the second separable conv from the most recent
            ``forward`` call (``None`` before the first call).
    """

    def __init__(self, in_filters, out_filters, strides=1, atrous=None,
                 grow_first=True, activate_first=True, inplace=True):
        super(Block, self).__init__()
        if atrous is None:
            atrous = [1] * 3
        elif isinstance(atrous, int):
            atrous = [atrous] * 3

        # A 1x1 projection is needed whenever the residual shape changes.
        self.head_relu = True
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides,
                                  bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters, momentum=bn_mom)
            # The projection reads the raw input, so the first ReLU must not
            # overwrite it in place.
            self.head_relu = False
        else:
            self.skip = None

        self.hook_layer = None
        filters = out_filters if grow_first else in_filters

        self.sepconv1 = SeparableConv2d(
            in_filters, filters, 3, stride=1, padding=1 * atrous[0],
            dilation=atrous[0], bias=False, activate_first=activate_first,
            inplace=self.head_relu)
        self.sepconv2 = SeparableConv2d(
            filters, out_filters, 3, stride=1, padding=1 * atrous[1],
            dilation=atrous[1], bias=False, activate_first=activate_first)
        self.sepconv3 = SeparableConv2d(
            out_filters, out_filters, 3, stride=strides,
            padding=1 * atrous[2], dilation=atrous[2], bias=False,
            activate_first=activate_first, inplace=inplace)

    def forward(self, inp):
        """Return ``sepconv3(sepconv2(sepconv1(inp))) + skip(inp)``.

        Also stores the output of ``sepconv2`` in ``self.hook_layer``.
        """
        if self.skip is not None:
            skip = self.skipbn(self.skip(inp))
        else:
            # NOTE: with an identity skip and activate_first=True, sepconv1's
            # leading ReLU is in place and therefore also modifies ``inp``
            # (and thus ``skip``). This is the original behaviour and is kept
            # so that existing checkpoints produce the same outputs.
            skip = inp

        x = self.sepconv1(inp)
        x = self.sepconv2(x)
        self.hook_layer = x
        x = self.sepconv3(x)

        x += skip
        return x


class Xception(nn.Module):
    """Xception-style backbone returning a low-level and a final feature map.

    Based on the architecture described in
    https://arxiv.org/pdf/1610.02357.pdf, built here with 20 residual blocks
    and dilated convolutions controlled by ``downsample_factor``.

    ``forward`` returns ``(low_feature, x)`` where ``low_feature`` is
    ``block2.hook_layer`` (256 channels) and ``x`` is the output of ``conv5``
    (2048 channels).
    """

    # Strides of block2, block3 and block20 for each supported output stride.
    _STRIDES = {8: [2, 1, 1], 16: [2, 2, 1]}
    # block4 .. block19 are identical 728 -> 728 blocks.
    _MIDDLE_BLOCK_NAMES = tuple('block%d' % i for i in range(4, 20))

    def __init__(self, downsample_factor):
        """Build the network.

        Args:
            downsample_factor: overall output stride of the backbone; must
                be 8 or 16.

        Raises:
            ValueError: if ``downsample_factor`` is not 8 or 16.
        """
        super(Xception, self).__init__()

        stride_list = self._STRIDES.get(downsample_factor)
        if stride_list is None:
            # The original formatted the ``os`` module here, which raised
            # TypeError instead of the intended ValueError.
            raise ValueError(
                'xception.py: output stride=%r is not supported.'
                % (downsample_factor,))

        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(32, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=bn_mom)
        # ReLU is applied after bn2 in forward().

        self.block1 = Block(64, 128, 2)
        # inplace=False keeps block2.hook_layer from being overwritten by the
        # leading ReLU of block2.sepconv3.
        self.block2 = Block(128, 256, stride_list[0], inplace=False)
        self.block3 = Block(256, 728, stride_list[1])

        rate = 16 // downsample_factor
        # Registered in the same order (and under the same names) as the
        # original explicit assignments, so state_dict keys are unchanged.
        for name in self._MIDDLE_BLOCK_NAMES:
            setattr(self, name, Block(728, 728, 1, atrous=rate))

        self.block20 = Block(728, 1024, stride_list[2], atrous=rate,
                             grow_first=False)
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1 * rate,
                                     dilation=rate, activate_first=False)
        self.conv4 = SeparableConv2d(1536, 1536, 3, 1, 1 * rate,
                                     dilation=rate, activate_first=False)
        self.conv5 = SeparableConv2d(1536, 2048, 3, 1, 1 * rate,
                                     dilation=rate, activate_first=False)
        self.layers = []

        self._init_weights()

    def _init_weights(self):
        """Initialise conv weights ~ N(0, sqrt(2/n)); BN to weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, input):
        """Run the backbone.

        Returns:
            Tuple ``(low_feature_layer, x)``: the hooked intermediate output
            of ``block2`` and the final output of ``conv5``.
        """
        self.layers = []
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        low_feature_layer = self.block2.hook_layer
        x = self.block3(x)

        for name in self._MIDDLE_BLOCK_NAMES:
            x = getattr(self, name)(x)

        x = self.block20(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return low_feature_layer, x


def load_url(url, model_dir='./model_data', map_location=None):
    """Load a checkpoint from ``url``, using ``model_dir`` as a local cache.

    Args:
        url: URL of the checkpoint; its last path component is the cache
            file name.
        model_dir: directory holding cached checkpoints (created if missing).
        map_location: forwarded to the loader on both the cached and the
            download path.

    Returns:
        The object stored in the checkpoint file.
    """
    os.makedirs(model_dir, exist_ok=True)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if os.path.exists(cached_file):
        # NOTE(security): torch.load unpickles the file; only point
        # model_dir at checkpoints from a trusted source.
        return torch.load(cached_file, map_location=map_location)
    # The original dropped map_location on this path.
    return model_zoo.load_url(url, model_dir=model_dir,
                              map_location=map_location)


def xception(pretrained=True, downsample_factor=16):
    """Build an ``Xception`` backbone, optionally loading pretrained weights.

    Args:
        pretrained: if True, load the ImageNet checkpoint (non-strict, so
            missing/unexpected keys are ignored).
        downsample_factor: output stride passed to ``Xception`` (8 or 16).
    """
    model = Xception(downsample_factor=downsample_factor)
    if pretrained:
        model.load_state_dict(load_url(
            'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth'
        ), strict=False)
    return model