renewable_eva/nets/deeplabv3_plus.py


import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2


class MobileNetV2(nn.Module):
    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        model = mobilenetv2(pretrained)
        # Drop the final 1x1 conv block; keep only the inverted-residual stages.
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        # Indices of the stages whose first convolution has stride 2.
        self.down_idx = [2, 4, 7, 14]

        # Replace the last one or two stride-2 convolutions with dilated
        # convolutions so the backbone keeps a 1/8 or 1/16 output stride.
        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))

    def _nostride_dilate(self, m, dilate):
        # Turn a stride-2 conv into a stride-1 dilated conv (and dilate the
        # following stride-1 convs) so spatial resolution is preserved while
        # the receptive field stays comparable.
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        # Shallow features (stride 4) feed the decoder shortcut; deep features
        # (stride = downsample_factor) feed the ASPP module.
        low_level_features = self.features[:4](x)
        x = self.features[4:](low_level_features)
        return low_level_features, x
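
# A minimal sketch (added for illustration, not part of the original file):
# verifies the effect of _nostride_dilate on the output stride. Shapes assume
# a 512x512 input; pretrained=False is used here only to avoid downloading
# weights, assuming nets.mobilenetv2.mobilenetv2 accepts that flag.
def _demo_output_stride():
    model = MobileNetV2(downsample_factor=8, pretrained=False)
    low, deep = model(torch.randn(1, 3, 512, 512))
    print(low.shape)   # torch.Size([1, 24, 128, 128])  -> stride 4
    print(deep.shape)  # torch.Size([1, 320, 64, 64])   -> stride 8, not 32
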
#-----------------------------------------#
#   ASPP feature-extraction module:
#   extracts multi-scale context with
#   dilated (atrous) convolutions at
#   several dilation rates.
#-----------------------------------------#
"""
The 3x3 branches use dilation=6 * rate (and 12 * rate, 18 * rate), i.e. the
kernel's sampling points are spaced that many pixels apart. To keep the
output feature map the same size as the input, the padding is set equal to
the dilation (e.g. padding=6 * rate), so the borders are zero-padded and the
spatial size does not shrink.

bn_mom is the momentum of the BatchNorm2d layers; in PyTorch it controls how
quickly the running estimates of the mean and variance are updated.
dilation is the dilation-rate parameter of the convolution.
"""
class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branch 5 is the pooling branch (global average pooling + 1x1 conv).
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        [b, c, row, col] = x.size()
        #-----------------------------------------#
        #   Five parallel branches in total.
        #-----------------------------------------#
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        #-----------------------------------------#
        #   Fifth branch: global average pooling
        #   followed by a 1x1 convolution.
        #-----------------------------------------#
        # Two chained means implement global average pooling: the first
        # averages over dim 2 (height), giving [b, c, 1, col]; the second
        # averages over dim 3 (width), giving [b, c, 1, 1]. The True argument
        # (keepdim) preserves the reduced dimensions, so each channel ends up
        # holding its mean over the whole feature map.
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)   # 1x1 conv mixes channels
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        # Resize the pooled [b, c, 1, 1] feature back to the target size (row, col).
        global_feature = F.interpolate(global_feature, (row, col), None,
                                       'bilinear', True)
        #-----------------------------------------#
        #   Stack the five branches along the
        #   channel axis, then fuse them with a
        #   1x1 convolution.
        #-----------------------------------------#
        feature_cat = torch.cat(
            [conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)  # 1x1 conv; the green fusion block in the diagram
        return result

# In short, ASPP extracts features with dilated convolutions at several
# dilation rates.
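
# A minimal usage sketch (added, not part of the original file): shows the
# shape contract of ASPP. dim_in=320 matches the MobileNetV2 deep features;
# rate=2 corresponds to downsample_factor=8 (i.e. 16 // 8 below).
def _demo_aspp():
    aspp = ASPP(dim_in=320, dim_out=256, rate=2)
    y = aspp(torch.randn(1, 320, 64, 64))
    print(y.shape)  # torch.Size([1, 256, 64, 64]) - spatial size preserved
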
class DeepLab(nn.Module):
    def __init__(self,
                 num_classes,
                 backbone="mobilenet",
                 pretrained=True,
                 downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone == "xception":
            #----------------------------------#
            #   Two feature maps are produced:
            #   shallow features  [128,128,256]
            #   deep trunk        [30,30,2048]
            #----------------------------------#
            self.backbone = xception(downsample_factor=downsample_factor,
                                     pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        elif backbone == "mobilenet":
            #----------------------------------#
            #   Two feature maps are produced:
            #   shallow features  [128,128,24]
            #   deep trunk        [30,30,320]
            #----------------------------------#
            self.backbone = MobileNetV2(downsample_factor=downsample_factor,
                                        pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError(
                'Unsupported backbone - `{}`, Use mobilenet, xception.'.format(
                    backbone))

        #-----------------------------------------#
        #   ASPP feature-extraction module:
        #   dilated convolutions at several rates.
        #-----------------------------------------#
        self.aspp = ASPP(dim_in=in_channels,
                         dim_out=256,
                         rate=16 // downsample_factor)

        #----------------------------------#
        #   Shortcut branch for the
        #   shallow (low-level) features.
        #----------------------------------#
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True))

        self.cat_conv = nn.Sequential(
            nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)
        """
        Dropout: in each training batch, activations are zeroed at random
        with a given probability (typically between 0.2 and 0.5; here 0.5
        and 0.1), so only part of the units pass their output to the next
        layer. During backpropagation only the surviving units update their
        weights, so every unit gets trained and the network is kept from
        over-relying on particular units.
        """
    def forward(self, x):
        H, W = x.size(2), x.size(3)
        #-----------------------------------------#
        #   The backbone returns two feature maps:
        #   low_level_features: shallow features, refined by a 1x1 conv
        #   x: deep trunk features, enhanced by the ASPP module
        #-----------------------------------------#
        low_level_features, x = self.backbone(x)
        x = self.aspp(x)
        low_level_features = self.shortcut_conv(low_level_features)

        #-----------------------------------------#
        #   Upsample the ASPP output (bilinear
        #   interpolation), concatenate it with the
        #   shallow features, then refine the stack
        #   with further convolutions.
        #-----------------------------------------#
        x = F.interpolate(x,
                          size=(low_level_features.size(2),
                                low_level_features.size(3)),
                          mode='bilinear',
                          align_corners=True)
        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
        x = self.cls_conv(x)
        # Upsample the per-class logits back to the input resolution.
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x
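

# A minimal smoke test (added sketch, not part of the original file).
# pretrained=False is used only to avoid downloading backbone weights,
# assuming the mobilenetv2/xception constructors accept that flag;
# num_classes=2 is an arbitrary example value.
if __name__ == "__main__":
    net = DeepLab(num_classes=2, backbone="mobilenet",
                  pretrained=False, downsample_factor=16)
    out = net(torch.randn(1, 3, 512, 512))
    print(out.shape)  # torch.Size([1, 2, 512, 512]) - logits at input resolution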