renewable_eva/nets/xception.py

299 lines
10 KiB
Python

import math
import os
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
# Value passed as `momentum` to the nn.BatchNorm2d layers created below.
bn_mom = 0.0003
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Each of the two convolutions is followed by its own BatchNorm2d.
    ``activate_first`` selects where the ReLUs are applied:

    * true  -- one ReLU on the input, nothing after the batch norms;
    * false -- no ReLU on the input, one ReLU after each batch norm.

    Args:
        in_channels: channels of the input and of the depthwise stage.
        out_channels: channels produced by the pointwise stage.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
        bias: whether both convolutions carry a bias term.
        activate_first: ReLU placement, see above.
        inplace: whether the leading ReLU (``relu0``) operates in place.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False, activate_first=True,
                 inplace=True):
        super().__init__()
        self.activate_first = activate_first

        # Sub-modules are created in a fixed order so that parameter
        # registration (and therefore state_dict layout) stays stable.
        self.relu0 = nn.ReLU(inplace=inplace)

        # Depthwise stage: one filter per input channel (groups == channels).
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=bn_mom)
        self.relu1 = nn.ReLU(inplace=True)

        # Pointwise stage: plain 1x1 convolution mixing the channels.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=bn_mom)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise -> BN -> pointwise -> BN with the configured ReLUs."""
        if self.activate_first:
            # Pre-activation variant: ReLU once, up front.
            x = self.relu0(x)
            x = self.bn1(self.depthwise(x))
            return self.bn2(self.pointwise(x))

        # Post-activation variant: ReLU after each batch norm.
        x = self.relu1(self.bn1(self.depthwise(x)))
        return self.relu2(self.bn2(self.pointwise(x)))
class Block(nn.Module):
    """Residual unit made of three 3x3 separable convolutions.

    The skip path is a 1x1 convolution + BatchNorm2d when the channel count
    or the stride changes, and the identity otherwise. The output of the
    second separable convolution of the most recent forward pass is kept in
    ``hook_layer``.

    Args:
        in_filters: channels of the input.
        out_filters: channels of the output.
        strides: stride of the third separable convolution and of the
            1x1 skip convolution.
        atrous: dilation per separable convolution; ``None`` means 1 for
            all three, an int is repeated three times, otherwise it is
            indexed at positions 0, 1 and 2.
        grow_first: if true the first separable convolution already
            outputs ``out_filters`` channels, otherwise it keeps
            ``in_filters``.
        activate_first: forwarded to every separable convolution.
        inplace: ``inplace`` flag forwarded to the third separable
            convolution.
    """

    def __init__(self, in_filters, out_filters, strides=1, atrous=None,
                 grow_first=True, activate_first=True, inplace=True):
        super().__init__()

        # Normalise the dilation spec to something indexable at 0..2.
        if atrous is None:
            atrous = [1, 1, 1]
        elif isinstance(atrous, int):
            atrous = [atrous, atrous, atrous]

        needs_projection = out_filters != in_filters or strides != 1
        # The first separable conv's leading ReLU runs in place only when
        # the skip path is the identity (no projection).
        self.head_relu = not needs_projection
        if needs_projection:
            self.skip = nn.Conv2d(in_filters,
                                  out_filters,
                                  1,
                                  stride=strides,
                                  bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters, momentum=bn_mom)
        else:
            self.skip = None

        self.hook_layer = None

        mid_filters = out_filters if grow_first else in_filters
        first_rate, second_rate, third_rate = atrous[0], atrous[1], atrous[2]

        # Padding equals the dilation for each 3x3 kernel.
        self.sepconv1 = SeparableConv2d(in_filters,
                                        mid_filters,
                                        3,
                                        stride=1,
                                        padding=first_rate,
                                        dilation=first_rate,
                                        bias=False,
                                        activate_first=activate_first,
                                        inplace=self.head_relu)
        self.sepconv2 = SeparableConv2d(mid_filters,
                                        out_filters,
                                        3,
                                        stride=1,
                                        padding=second_rate,
                                        dilation=second_rate,
                                        bias=False,
                                        activate_first=activate_first)
        # Only the last separable conv carries the block's stride.
        self.sepconv3 = SeparableConv2d(out_filters,
                                        out_filters,
                                        3,
                                        stride=strides,
                                        padding=third_rate,
                                        dilation=third_rate,
                                        bias=False,
                                        activate_first=activate_first,
                                        inplace=inplace)

    def forward(self, inp):
        """Return ``sepconv3(sepconv2(sepconv1(inp))) + skip(inp)``."""
        if self.skip is None:
            residual = inp
        else:
            residual = self.skipbn(self.skip(inp))

        x = self.sepconv2(self.sepconv1(inp))
        # Expose the pre-stride features to callers of this block.
        self.hook_layer = x
        x = self.sepconv3(x)
        x += residual
        return x
class Xception(nn.Module):
    """Xception feature extractor returning low-level and final features.

    The architecture follows https://arxiv.org/pdf/1610.02357.pdf, with the
    strides/dilations of the later blocks chosen from ``downsample_factor``.
    There is no classification head: ``forward`` returns two feature maps.
    """

    def __init__(self, downsample_factor):
        """Build the network.

        Args:
            downsample_factor: overall stride of the final feature map.
                ``8`` uses strides [2, 1, 1] for block2/block3/block20 and
                dilation 2 from block4 on; ``16`` uses strides [2, 2, 1]
                and dilation 1.

        Raises:
            ValueError: if ``downsample_factor`` is neither 8 nor 16.
        """
        super(Xception, self).__init__()
        if downsample_factor == 8:
            stride_list = [2, 1, 1]
        elif downsample_factor == 16:
            stride_list = [2, 2, 1]
        else:
            # %r (with a 1-tuple) formats any offending value, so the
            # intended ValueError is always what the caller sees.
            raise ValueError(
                'xception.py: output stride=%r is not supported.' %
                (downsample_factor,))

        # Entry flow stem.
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(32, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=bn_mom)
        # The ReLU after bn2 is applied in forward() using self.relu.

        self.block1 = Block(64, 128, 2)
        # inplace=False here: block2.hook_layer (the input of its third
        # separable conv) is returned from forward() as the low-level feature.
        self.block2 = Block(128, 256, stride_list[0], inplace=False)
        self.block3 = Block(256, 728, stride_list[1])

        # Dilation used from block4 onwards.
        rate = 16 // downsample_factor

        # Middle flow: sixteen identical 728-channel blocks. The attribute
        # names are kept explicit because they are state_dict keys.
        self.block4 = Block(728, 728, 1, atrous=rate)
        self.block5 = Block(728, 728, 1, atrous=rate)
        self.block6 = Block(728, 728, 1, atrous=rate)
        self.block7 = Block(728, 728, 1, atrous=rate)
        self.block8 = Block(728, 728, 1, atrous=rate)
        self.block9 = Block(728, 728, 1, atrous=rate)
        self.block10 = Block(728, 728, 1, atrous=rate)
        self.block11 = Block(728, 728, 1, atrous=rate)
        self.block12 = Block(728, 728, 1, atrous=rate)
        self.block13 = Block(728, 728, 1, atrous=rate)
        self.block14 = Block(728, 728, 1, atrous=rate)
        self.block15 = Block(728, 728, 1, atrous=rate)
        self.block16 = Block(728,
                             728,
                             1,
                             atrous=[1 * rate, 1 * rate, 1 * rate])
        self.block17 = Block(728,
                             728,
                             1,
                             atrous=[1 * rate, 1 * rate, 1 * rate])
        self.block18 = Block(728,
                             728,
                             1,
                             atrous=[1 * rate, 1 * rate, 1 * rate])
        self.block19 = Block(728,
                             728,
                             1,
                             atrous=[1 * rate, 1 * rate, 1 * rate])

        # Exit flow.
        self.block20 = Block(728,
                             1024,
                             stride_list[2],
                             atrous=rate,
                             grow_first=False)
        self.conv3 = SeparableConv2d(1024,
                                     1536,
                                     3,
                                     1,
                                     1 * rate,
                                     dilation=rate,
                                     activate_first=False)
        self.conv4 = SeparableConv2d(1536,
                                     1536,
                                     3,
                                     1,
                                     1 * rate,
                                     dilation=rate,
                                     activate_first=False)
        self.conv5 = SeparableConv2d(1536,
                                     2048,
                                     3,
                                     1,
                                     1 * rate,
                                     dilation=rate,
                                     activate_first=False)
        # Not used inside this class; kept as a public attribute.
        self.layers = []

        # ------- init weights --------
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kh * kw * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        # -----------------------------

    def forward(self, input):
        """Run the backbone.

        Args:
            input: image batch with 3 channels (``conv1`` expects 3 input
                channels). The name shadows the builtin but is kept so
                keyword callers keep working.

        Returns:
            A tuple ``(low_level_features, features)``:
            ``low_level_features`` is ``block2.hook_layer`` (256 channels,
            taken before block2's strided convolution) and ``features`` is
            the 2048-channel output of ``conv5``.
        """
        self.layers = []
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        low_feature_layer = self.block2.hook_layer
        x = self.block3(x)

        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)
        x = self.block13(x)
        x = self.block14(x)
        x = self.block15(x)
        x = self.block16(x)
        x = self.block17(x)
        x = self.block18(x)
        x = self.block19(x)

        x = self.block20(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return low_feature_layer, x
def load_url(url, model_dir='./model_data', map_location=None):
    """Load a checkpoint from ``model_dir``, downloading it first if needed.

    The file name is the last ``/``-separated component of ``url``. If that
    file already exists in ``model_dir`` it is loaded with ``torch.load``;
    otherwise the download and load are delegated to ``model_zoo.load_url``.

    Args:
        url: location of the checkpoint.
        model_dir: directory used as the local cache; created if missing.
        map_location: forwarded to the loader on both the cached and the
            download path.

    Returns:
        Whatever object the checkpoint file contains.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_dir, exist_ok=True)

    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if os.path.exists(cached_file):
        # NOTE: torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        return torch.load(cached_file, map_location=map_location)
    return model_zoo.load_url(url,
                              model_dir=model_dir,
                              map_location=map_location)
def xception(pretrained=True, downsample_factor=16):
    """Create an Xception backbone, optionally with downloaded weights.

    Args:
        pretrained: when true, fetch the checkpoint from the hard-coded
            release URL (via ``load_url``) and load it non-strictly.
        downsample_factor: forwarded to ``Xception``.

    Returns:
        The constructed ``Xception`` module.
    """
    net = Xception(downsample_factor=downsample_factor)
    if not pretrained:
        return net

    weights_url = 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth'
    state_dict = load_url(weights_url)
    # strict=False: keys that do not match the model are not an error.
    net.load_state_dict(state_dict, strict=False)
    return net