31 lines
1.2 KiB
Python
31 lines
1.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from nets.xception import xception
|
|
from nets.mobilenetv2 import mobilenetv2
|
|
|
|
class ASPP(nn.Module):
|
|
def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
|
|
super(ASPP, self).__init__()
|
|
self.branch1 = nn.Sequential(
|
|
nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate,bias=True),
|
|
nn.BatchNorm2d(dim_out, momentum=bn_mom),
|
|
nn.ReLU(inplace=True),
|
|
)
|
|
self.branch2 = nn.Sequential(
|
|
nn.Conv2d(dim_in, dim_out, 3, 1, padding=6*rate, dilation=6*rate, bias=True),
|
|
nn.BatchNorm2d(dim_out, momentum=bn_mom),
|
|
nn.ReLU(inplace=True),
|
|
)
|
|
self.branch3 = nn.Sequential(
|
|
nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
|
|
nn.BatchNorm2d(dim_out, momentum=bn_mom),
|
|
nn.ReLU(inplace=True),
|
|
)
|
|
self.branch4 = nn.Sequential(
|
|
nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
|
|
nn.BatchNorm2d(dim_out, momentum=bn_mom),
|
|
nn.ReLU(inplace=True),
|
|
)
|
|
self.branch5_conv = nn.Conv2d(dim_in, dim_out,1, 1, 0, bias=True),
|