299 lines
10 KiB
Python
299 lines
10 KiB
Python
import math
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.model_zoo as model_zoo
|
|
|
|
# Momentum value passed to every nn.BatchNorm2d layer created in this file.
bn_mom = 0.0003
|
|
|
|
|
|
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Each convolution is followed by its own BatchNorm2d. The position of the
    ReLU activations is selected by ``activate_first``:

    * ``True``  -- a single ReLU is applied to the input, before the
      depthwise convolution; nothing is applied after the batch norms.
    * ``False`` -- no ReLU on the input; a ReLU follows each batch norm.

    Args:
        in_channels: channels of the input (and of the depthwise output).
        out_channels: channels produced by the pointwise convolution.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
        bias: whether both convolutions carry a bias term.
        activate_first: ReLU placement, see above.
        inplace: whether the leading ReLU (``relu0``) works in place. The
            two trailing ReLUs are always in place.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 activate_first=True,
                 inplace=True):
        super(SeparableConv2d, self).__init__()
        self.activate_first = activate_first

        # Sub-modules are registered in the same order as they are applied.
        self.relu0 = nn.ReLU(inplace=inplace)
        # groups == in_channels: every input channel gets its own filter.
        self.depthwise = nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   dilation=dilation,
                                   groups=in_channels,
                                   bias=bias)
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=bn_mom)
        self.relu1 = nn.ReLU(inplace=True)
        # 1x1 convolution that mixes channels.
        self.pointwise = nn.Conv2d(in_channels,
                                   out_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0,
                                   dilation=1,
                                   groups=1,
                                   bias=bias)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=bn_mom)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise -> BN -> pointwise -> BN with the configured ReLUs."""
        activate_after = not self.activate_first

        out = x if activate_after else self.relu0(x)

        out = self.bn1(self.depthwise(out))
        if activate_after:
            out = self.relu1(out)

        out = self.bn2(self.pointwise(out))
        if activate_after:
            out = self.relu2(out)
        return out
|
|
|
|
|
|
class Block(nn.Module):
    """Residual unit made of three 3x3 separable convolutions.

    The shortcut is the identity when the block keeps both the channel count
    and the resolution; otherwise it is a 1x1 strided convolution followed by
    batch norm. Only the third separable convolution applies ``strides``.

    Args:
        in_filters: channels of the input.
        out_filters: channels of the output.
        strides: stride of the third separable conv and of the shortcut conv.
        atrous: dilation for the three separable convs. ``None`` means 1 for
            all of them, an ``int`` is repeated three times, anything else is
            indexed at positions 0, 1 and 2.
        grow_first: if true the first separable conv already outputs
            ``out_filters`` channels, otherwise it keeps ``in_filters``.
        activate_first: forwarded to every ``SeparableConv2d``.
        inplace: forwarded to the third ``SeparableConv2d`` only.

    Attributes:
        hook_layer: output of the second separable convolution from the most
            recent ``forward`` call (``None`` before the first call).
    """

    def __init__(self,
                 in_filters,
                 out_filters,
                 strides=1,
                 atrous=None,
                 grow_first=True,
                 activate_first=True,
                 inplace=True):
        super(Block, self).__init__()

        # Normalise ``atrous`` to something indexable at 0, 1 and 2.
        if atrous is None:
            atrous = [1, 1, 1]
        elif isinstance(atrous, int):
            atrous = [atrous, atrous, atrous]

        # A projection shortcut is needed whenever the shape changes.
        needs_projection = out_filters != in_filters or strides != 1
        # head_relu decides whether the first sepconv's leading ReLU is
        # in place; it is only in place for identity shortcuts.
        self.head_relu = not needs_projection
        if needs_projection:
            self.skip = nn.Conv2d(in_filters,
                                  out_filters,
                                  1,
                                  stride=strides,
                                  bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters, momentum=bn_mom)
        else:
            self.skip = None

        self.hook_layer = None

        mid_filters = out_filters if grow_first else in_filters

        self.sepconv1 = SeparableConv2d(in_filters,
                                        mid_filters,
                                        3,
                                        stride=1,
                                        padding=1 * atrous[0],
                                        dilation=atrous[0],
                                        bias=False,
                                        activate_first=activate_first,
                                        inplace=self.head_relu)
        self.sepconv2 = SeparableConv2d(mid_filters,
                                        out_filters,
                                        3,
                                        stride=1,
                                        padding=1 * atrous[1],
                                        dilation=atrous[1],
                                        bias=False,
                                        activate_first=activate_first)
        self.sepconv3 = SeparableConv2d(out_filters,
                                        out_filters,
                                        3,
                                        stride=strides,
                                        padding=1 * atrous[2],
                                        dilation=atrous[2],
                                        bias=False,
                                        activate_first=activate_first,
                                        inplace=inplace)

    def forward(self, inp):
        """Return ``sepconv3(sepconv2(sepconv1(inp))) + shortcut(inp)``."""
        # The shortcut is taken before the main branch runs. With an identity
        # shortcut it aliases ``inp``, and sepconv1's leading ReLU may be
        # in place (head_relu), so the statement order here matters.
        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        out = self.sepconv1(inp)
        out = self.sepconv2(out)
        # Keep the intermediate activation so callers can read it afterwards.
        self.hook_layer = out
        out = self.sepconv3(out)

        out += shortcut
        return out
|
|
|
|
|
|
class Xception(nn.Module):
    """Xception-style backbone returning a low-level and a deep feature map.

    Based on the architecture described in
    https://arxiv.org/pdf/1610.02357.pdf

    The network is a two-convolution stem, twenty residual ``Block`` modules
    (``block1`` .. ``block20``) and three trailing ``SeparableConv2d`` layers
    (``conv3`` .. ``conv5``). Strides and dilations are chosen from
    ``downsample_factor``.
    """

    def __init__(self, downsample_factor):
        """Build the backbone.

        Args:
            downsample_factor: overall output stride of the deep feature map.
                Only 8 and 16 are supported.

        Raises:
            ValueError: if ``downsample_factor`` is neither 8 nor 16.
        """
        super(Xception, self).__init__()

        # Strides of block2, block3 and block20. Together with the stride-2
        # stem conv and the stride-2 block1 they give a total stride of 8 or 16.
        if downsample_factor == 8:
            stride_list = [2, 1, 1]
        elif downsample_factor == 16:
            stride_list = [2, 2, 1]
        else:
            # %s with a 1-tuple so any offending value can be reported
            # without raising a formatting error of its own.
            raise ValueError(
                'xception.py: output stride=%s is not supported.' %
                (downsample_factor,))

        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(32, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=bn_mom)
        # The ReLU following bn2 is applied in forward() through self.relu.

        self.block1 = Block(64, 128, 2)
        # inplace=False: block2.hook_layer is read after block2 has run, so
        # the leading ReLU of its third separable conv must not overwrite it.
        self.block2 = Block(128, 256, stride_list[0], inplace=False)
        self.block3 = Block(256, 728, stride_list[1])

        # Dilation that compensates for the stride removed when
        # downsample_factor is 8 (rate 2); rate is 1 for factor 16.
        rate = 16 // downsample_factor

        # block4 .. block19: sixteen identical 728-channel blocks. They are
        # registered under the same attribute names and in the same order as
        # an explicit listing, so state_dict keys are unchanged.
        for index in range(4, 20):
            setattr(self, 'block%d' % index,
                    Block(728, 728, 1, atrous=rate))

        self.block20 = Block(728,
                             1024,
                             stride_list[2],
                             atrous=rate,
                             grow_first=False)

        self.conv3 = SeparableConv2d(1024,
                                     1536,
                                     3,
                                     1,
                                     1 * rate,
                                     dilation=rate,
                                     activate_first=False)
        self.conv4 = SeparableConv2d(1536,
                                     1536,
                                     3,
                                     1,
                                     1 * rate,
                                     dilation=rate,
                                     activate_first=False)
        self.conv5 = SeparableConv2d(1536,
                                     2048,
                                     3,
                                     1,
                                     1 * rate,
                                     dilation=rate,
                                     activate_first=False)
        self.layers = []

        # ------- init weights --------
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kernel_h * kernel_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        # -----------------------------

    def forward(self, input):
        """Run the backbone.

        Args:
            input: image batch with 3 channels (conv1 expects 3 input channels).

        Returns:
            A tuple ``(low_feature_layer, x)``:
            ``low_feature_layer`` is ``block2.hook_layer``, i.e. the
            256-channel output of block2's second separable convolution;
            ``x`` is the 2048-channel output of ``conv5``.
        """
        self.layers = []

        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))

        x = self.block1(x)
        x = self.block2(x)
        # Must be read right after block2 has run on this input.
        low_feature_layer = self.block2.hook_layer

        # block3 .. block20 are applied strictly in numeric order.
        for index in range(3, 21):
            x = getattr(self, 'block%d' % index)(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return low_feature_layer, x
|
|
|
|
|
|
def load_url(url, model_dir='./model_data', map_location=None):
    """Load a checkpoint from a local cache, downloading it if necessary.

    The cache file name is the last ``/``-separated component of ``url``.

    Args:
        url: address of the checkpoint file.
        model_dir: directory used as download cache; created if missing.
        map_location: forwarded to the loader on both the cached and the
            download path, so tensors are remapped the same way either way.

    Returns:
        Whatever object the checkpoint file contains.
    """
    # exist_ok avoids the check-then-create race between parallel workers.
    os.makedirs(model_dir, exist_ok=True)

    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)

    # NOTE(review): torch.load unpickles the file; only point this function
    # at checkpoints from a trusted source.
    if os.path.exists(cached_file):
        return torch.load(cached_file, map_location=map_location)
    return model_zoo.load_url(url,
                              model_dir=model_dir,
                              map_location=map_location)
|
|
|
|
|
|
def xception(pretrained=True, downsample_factor=16):
    """Create an ``Xception`` backbone, optionally with pretrained weights.

    Args:
        pretrained: when true, weights are fetched through ``load_url`` and
            loaded non-strictly (``strict=False``).
        downsample_factor: forwarded to the ``Xception`` constructor.

    Returns:
        The constructed ``Xception`` module.
    """
    net = Xception(downsample_factor=downsample_factor)
    if not pretrained:
        return net

    weights_url = 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth'
    state_dict = load_url(weights_url)
    net.load_state_dict(state_dict, strict=False)
    return net
|