# DeepLab segmentation network: MobileNetV2 backbone wrapper, ASPP module and DeepLab head.
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from nets.xception import xception
|
||
from nets.mobilenetv2 import mobilenetv2
|
||
|
||
|
||
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor with optional dilation.

    Wraps the project's ``mobilenetv2`` model, drops the last entry of its
    ``features`` container and, depending on ``downsample_factor``, rewrites
    the late stride-2 convolutions into stride-1 dilated convolutions.

    Args:
        downsample_factor: ``8`` dilates the stages in
            ``[down_idx[-2], down_idx[-1])`` with ``dilate=2`` and every
            later stage with ``dilate=4``; ``16`` dilates only the stages
            from ``down_idx[-1]`` on with ``dilate=2``.  Any other value
            leaves the backbone untouched.
        pretrained: Forwarded unchanged to ``mobilenetv2``.
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model = mobilenetv2(pretrained)
        # Keep every feature stage except the last one.
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        # Stage indices used as dilation boundaries below.
        # NOTE(review): presumably the indices of the stride-2 stages of the
        # wrapped model -- confirm against nets/mobilenetv2.py.
        self.down_idx = [2, 4, 7, 14]

        second_last, last = self.down_idx[-2], self.down_idx[-1]
        if downsample_factor == 8:
            dilation_plan = [
                (second_last, last, 2),
                (last, self.total_idx, 4),
            ]
        elif downsample_factor == 16:
            dilation_plan = [(last, self.total_idx, 2)]
        else:
            # Unrecognised factors keep the original strides (no dilation).
            dilation_plan = []

        for start, stop, dilate in dilation_plan:
            for idx in range(start, stop):
                # ``apply`` visits every sub-module of the stage recursively.
                self.features[idx].apply(
                    partial(self._nostride_dilate, dilate=dilate))

    def _nostride_dilate(self, m, dilate):
        """Turn a strided / plain 3x3 convolution into a dilated one, in place.

        Only ``nn.Conv2d`` instances are modified; every other module is
        ignored.  (Matching on the class *name* would also hit containers
        such as a ``ConvBNReLU`` wrapper, which have no ``stride`` attribute.)

        * stride (2, 2): stride becomes (1, 1); a 3x3 kernel additionally
          gets dilation and padding ``dilate // 2``.
        * any other stride: a 3x3 kernel gets dilation and padding
          ``dilate``.

        Args:
            m: Module visited by ``nn.Module.apply``.
            dilate: Target dilation rate.
        """
        if not isinstance(m, nn.Conv2d):
            return

        is_3x3 = m.kernel_size == (3, 3)
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if is_3x3:
                half = dilate // 2
                m.dilation = (half, half)
                m.padding = (half, half)
        elif is_3x3:
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x):
        """Return ``(low_level_features, x)``.

        ``low_level_features`` is the output of the first four feature
        stages; ``x`` is the output of the remaining stages applied on top
        of it.
        """
        low_level_features = self.features[:4](x)
        x = self.features[4:](low_level_features)
        return low_level_features, x
|
||
|
||
|
||
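#-----------------------------------------#
#   ASPP feature-extraction module.
#   Extracts features with dilated convolutions
#   at several different dilation rates.
#-----------------------------------------#
"""
Each 3x3 branch uses dilation = k * rate (k = 6, 12 or 18), i.e. the sampling
points of the kernel are k * rate pixels apart.
To keep the output feature map the same size as the input, the padding is set
to the same value, padding = k * rate: the input is zero-padded by that many
pixels on every side, so the feature map does not shrink.

bn_mom is the momentum of the batch-normalization layers. In PyTorch's
BatchNorm, the momentum controls how quickly the running mean and variance
are updated.

dilation is the dilation-rate argument of the convolution.
"""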
|
||
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling block with five parallel branches.

    Branches (all map ``dim_in`` -> ``dim_out`` channels):

    1. 1x1 convolution,
    2. 3x3 convolution, dilation ``6 * rate``,
    3. 3x3 convolution, dilation ``12 * rate``,
    4. 3x3 convolution, dilation ``18 * rate``,
    5. global average pooling -> 1x1 convolution -> bilinear upsampling back
       to the input's spatial size.

    The five outputs are concatenated along the channel axis and fused by a
    final 1x1 convolution.  Padding equals dilation in the 3x3 branches, so
    every branch keeps the input's height and width.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of channels produced by each branch and by the block.
        rate: Multiplier applied to the base dilation rates (6, 12, 18).
        bn_mom: ``momentum`` passed to every ``nn.BatchNorm2d``.

    NOTE(review): in training mode ``branch5_bn`` receives a 1x1 feature map,
    so a batch of size 1 presumably cannot be normalised -- confirm before
    training with batch size 1.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        # Sub-module names and construction order are part of the checkpoint
        # format / seeded initialisation, so they are kept exactly:
        # branch1..branch4, branch5_*, conv_cat.
        self.branch1 = self._conv_bn_relu(
            dim_in, dim_out, 1, padding=0, dilation=rate, bn_mom=bn_mom)
        self.branch2 = self._conv_bn_relu(
            dim_in, dim_out, 3, padding=6 * rate, dilation=6 * rate,
            bn_mom=bn_mom)
        self.branch3 = self._conv_bn_relu(
            dim_in, dim_out, 3, padding=12 * rate, dilation=12 * rate,
            bn_mom=bn_mom)
        self.branch4 = self._conv_bn_relu(
            dim_in, dim_out, 3, padding=18 * rate, dilation=18 * rate,
            bn_mom=bn_mom)

        # Branch 5 is the pooling branch; its layers are kept as separate
        # attributes because the pooling itself happens in ``forward``.
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        # 1x1 fusion of the five concatenated branch outputs.
        self.conv_cat = self._conv_bn_relu(
            dim_out * 5, dim_out, 1, padding=0, dilation=1, bn_mom=bn_mom)

    @staticmethod
    def _conv_bn_relu(channels_in, channels_out, kernel_size, padding,
                      dilation, bn_mom):
        """Build a stride-1 Conv2d (with bias) -> BatchNorm2d -> ReLU stack."""
        return nn.Sequential(
            nn.Conv2d(channels_in,
                      channels_out,
                      kernel_size,
                      1,
                      padding=padding,
                      dilation=dilation,
                      bias=True),
            nn.BatchNorm2d(channels_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the five branches to ``x`` and fuse them.

        Args:
            x: 4-D tensor; dimensions 2 and 3 are treated as the spatial size.

        Returns:
            Tensor with ``dim_out`` channels and the same spatial size as
            ``x``.
        """
        # Unpacking enforces a 4-D input; only the spatial size is needed.
        _, _, rows, cols = x.size()

        #-----------------------------------------#
        #   The four convolutional branches.
        #-----------------------------------------#
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)

        #-----------------------------------------#
        #   Fifth branch: global average pooling + 1x1 convolution.
        #   Averaging over dim 2 and then dim 3 (keeping both dims) yields
        #   one value per channel, i.e. a [batch, channels, 1, 1] tensor.
        #-----------------------------------------#
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        # Resize the pooled feature back to the input's spatial size so it
        # can be concatenated with the other branches.
        global_feature = F.interpolate(global_feature,
                                       size=(rows, cols),
                                       mode='bilinear',
                                       align_corners=True)

        #-----------------------------------------#
        #   Stack the five branches along the channel axis,
        #   then fuse them with a 1x1 convolution.
        #-----------------------------------------#
        feature_cat = torch.cat(
            [conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        return self.conv_cat(feature_cat)
|
||
# In short, ASPP extracts features with dilated convolutions at several different dilation rates.
|
||
|
||
class DeepLab(nn.Module):
    """Encoder-decoder segmentation network built on a backbone and ASPP.

    The backbone returns a low-level feature map and a deep feature map.
    The deep map goes through ``ASPP``, is upsampled to the low-level map's
    size, concatenated with a 1x1-projected copy of the low-level map,
    refined by two 3x3 convolutions and classified per pixel.  The result is
    upsampled to the input resolution.

    Args:
        num_classes: Number of output channels of the classifier.
        backbone: ``"mobilenet"`` or ``"xception"``.
        pretrained: Forwarded to the backbone constructor.
        downsample_factor: Forwarded to the backbone; the ASPP ``rate`` is
            ``16 // downsample_factor``.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self,
                 num_classes,
                 backbone="mobilenet",
                 pretrained=True,
                 downsample_factor=16):
        super(DeepLab, self).__init__()

        #----------------------------------#
        #   Each backbone yields two feature maps.  Per the original
        #   author's notes (not verifiable from this file):
        #     mobilenet: low-level [128,128,24],  deep [30,30,320]
        #     xception : low-level [128,128,256], deep [30,30,2048]
        #----------------------------------#
        if backbone == "mobilenet":
            self.backbone = MobileNetV2(downsample_factor=downsample_factor,
                                        pretrained=pretrained)
            in_channels, low_level_channels = 320, 24
        elif backbone == "xception":
            self.backbone = xception(downsample_factor=downsample_factor,
                                     pretrained=pretrained)
            in_channels, low_level_channels = 2048, 256
        else:
            raise ValueError(
                'Unsupported backbone - `{}`, Use mobilenet, xception.'.format(
                    backbone))

        #-----------------------------------------#
        #   ASPP module: dilated convolutions at several rates
        #   applied to the deep feature map.
        #-----------------------------------------#
        aspp_rate = 16 // downsample_factor
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=aspp_rate)

        #----------------------------------#
        #   Low-level branch: 1x1 projection to 48 channels.
        #----------------------------------#
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        # Decoder: two 3x3 Conv-BN-ReLU stages, each followed by dropout.
        # Dropout randomly zeroes activations during training so the network
        # does not rely too heavily on any particular unit.
        decoder_layers = [
            nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ]
        self.cat_conv = nn.Sequential(*decoder_layers)

        # Per-pixel classifier.
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

    def forward(self, x):
        """Return class scores with the same height and width as ``x``."""
        in_h, in_w = x.size(2), x.size(3)

        #-----------------------------------------#
        #   Two feature maps from the backbone:
        #   shallow - projected by a 1x1 convolution,
        #   deep    - strengthened by ASPP.
        #-----------------------------------------#
        shallow, deep = self.backbone(x)
        deep = self.aspp(deep)
        shallow = self.shortcut_conv(shallow)

        #-----------------------------------------#
        #   Upsample the ASPP output (bilinear) to the shallow map's size,
        #   stack both maps and refine them with the decoder convolutions.
        #-----------------------------------------#
        target_size = (shallow.size(2), shallow.size(3))
        deep = F.interpolate(deep,
                             size=target_size,
                             mode='bilinear',
                             align_corners=True)
        fused = torch.cat((deep, shallow), dim=1)
        scores = self.cls_conv(self.cat_conv(fused))

        # Back to the input resolution.
        return F.interpolate(scores,
                             size=(in_h, in_w),
                             mode='bilinear',
                             align_corners=True)
|