# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import math
import warnings
import numpy as np
from functools import partial
import torch
import torch.nn as nn
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""
    Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
#--------------------------------------#
#   Implementation of the GELU activation
#   function, using the approximate (tanh)
#   formula.
#--------------------------------------#
class GELU(nn.Module):
    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
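# Illustrative sanity check (assumption: PyTorch >= 1.12, where F.gelu accepts approximate="tanh");
# the tanh approximation above should closely match the built-in implementation:
#   >>> x = torch.randn(8)
#   >>> torch.allclose(GELU()(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6)
#   True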
class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
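# Shape sketch (illustrative, assuming a 512x512 RGB input and the stage-1 settings patch_size=7, stride=4):
#   >>> embed = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=32)
#   >>> tokens, H, W = embed(torch.randn(1, 3, 512, 512))
#   >>> tokens.shape, H, W
#   (torch.Size([1, 16384, 32]), 128, 128)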
#--------------------------------------------------------------------------------------------------------------------#
#   Attention mechanism
#   The input features are first projected into query, key and value: query is the lookup vector, key is the vector
#   it is matched against, and value carries the content.
#   The query is then matrix-multiplied with the transposed key; intuitively, the query interrogates the sequence
#   and yields a score for how important each part of the sequence is.
#   That score is then matrix-multiplied with value; intuitively, the per-part importance is applied back onto the
#   sequence values.
#
#   To reduce computation, SegFormer first condenses the feature map: every stage is reduced to 1/32 of the input
#   image. With a 512x512 input, the Block1 feature map is 128x128, which is first reduced to 16x16.
#   In Block1's Attention module this amounts to condensing every 8x8 group of feature points into a single point.
#   The 128x128 query vectors then attend over the 16x16 key and value vectors. Although there are far fewer keys
#   and values, the differing queries still produce distinct outputs.
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        # bs, 16384, 32 => bs, 16384, 32 => bs, 16384, 8, 4 => bs, 8, 16384, 4
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # bs, 16384, 32 => bs, 32, 128, 128
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            # bs, 32, 128, 128 => bs, 32, 16, 16 => bs, 256, 32
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            # bs, 256, 32 => bs, 256, 64 => bs, 256, 2, 8, 4 => 2, bs, 8, 256, 4
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # bs, 8, 16384, 4 @ bs, 8, 4, 256 => bs, 8, 16384, 256
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # bs, 8, 16384, 256 @ bs, 8, 256, 4 => bs, 8, 16384, 4 => bs, 16384, 32
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # bs, 16384, 32 => bs, 16384, 32
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
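# Illustrative shapes for the sr_ratio > 1 path (assumption: dim=32, num_heads=8, sr_ratio=8, a 128x128 token grid):
#   >>> attn = Attention(dim=32, num_heads=8, sr_ratio=8)
#   >>> out = attn(torch.randn(1, 128 * 128, 32), H=128, W=128)
#   >>> out.shape
#   torch.Size([1, 16384, 32])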
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
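# Illustrative behaviour (assumption: a (4, 3) input in training mode with drop_prob=0.5): roughly half of
# the samples along the batch dimension are zeroed, and survivors are rescaled by 1 / keep_prob so the
# expected value is unchanged:
#   >>> x = torch.ones(4, 3)
#   >>> drop_path(x, drop_prob=0.5, training=True)   # each row is either all 0. or all 2.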
class DropPath(nn.Module):
    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio
        )
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x
class MixVisionTransformer(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000, embed_dims=[32, 64, 160, 256],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        #----------------------------------#
        #   The Transformer backbone is
        #   made up of four stages.
        #----------------------------------#
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        #----------------------------------#
        #   block1
        #----------------------------------#
        #-----------------------------------------------#
        #   Partition the input image into patches
        #   and downsample.
        #   512, 512, 3 => 128, 128, 32 => 16384, 32
        #-----------------------------------------------#
        self.patch_embed1 = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0])
        #-----------------------------------------------#
        #   Extract features with the transformer blocks.
        #   16384, 32 => 16384, 32
        #-----------------------------------------------#
        cur = 0
        self.block1 = nn.ModuleList(
            [
                Block(
                    dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[0]
                )
                for i in range(depths[0])
            ]
        )
        self.norm1 = norm_layer(embed_dims[0])
        #----------------------------------#
        #   block2
        #----------------------------------#
        #-----------------------------------------------#
        #   Partition the incoming feature map
        #   into patches and downsample.
        #   128, 128, 32 => 64, 64, 64 => 4096, 64
        #-----------------------------------------------#
        self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        #-----------------------------------------------#
        #   Extract features with the transformer blocks.
        #   4096, 64 => 4096, 64
        #-----------------------------------------------#
        cur += depths[0]
        self.block2 = nn.ModuleList(
            [
                Block(
                    dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[1]
                )
                for i in range(depths[1])
            ]
        )
        self.norm2 = norm_layer(embed_dims[1])
        #----------------------------------#
        #   block3
        #----------------------------------#
        #-----------------------------------------------#
        #   Partition the incoming feature map
        #   into patches and downsample.
        #   64, 64, 64 => 32, 32, 160 => 1024, 160
        #-----------------------------------------------#
        self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
        #-----------------------------------------------#
        #   Extract features with the transformer blocks.
        #   1024, 160 => 1024, 160
        #-----------------------------------------------#
        cur += depths[1]
        self.block3 = nn.ModuleList(
            [
                Block(
                    dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[2]
                )
                for i in range(depths[2])
            ]
        )
        self.norm3 = norm_layer(embed_dims[2])
        #----------------------------------#
        #   block4
        #----------------------------------#
        #-----------------------------------------------#
        #   Partition the incoming feature map
        #   into patches and downsample.
        #   32, 32, 160 => 16, 16, 256 => 256, 256
        #-----------------------------------------------#
        self.patch_embed4 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
        #-----------------------------------------------#
        #   Extract features with the transformer blocks.
        #   256, 256 => 256, 256
        #-----------------------------------------------#
        cur += depths[2]
        self.block4 = nn.ModuleList(
            [
                Block(
                    dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[3]
                )
                for i in range(depths[3])
            ]
        )
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    def forward(self, x):
        B = x.shape[0]
        outs = []

        #----------------------------------#
        #   block1
        #----------------------------------#
        x, H, W = self.patch_embed1.forward(x)
        for i, blk in enumerate(self.block1):
            x = blk.forward(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        #----------------------------------#
        #   block2
        #----------------------------------#
        x, H, W = self.patch_embed2.forward(x)
        for i, blk in enumerate(self.block2):
            x = blk.forward(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        #----------------------------------#
        #   block3
        #----------------------------------#
        x, H, W = self.patch_embed3.forward(x)
        for i, blk in enumerate(self.block3):
            x = blk.forward(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        #----------------------------------#
        #   block4
        #----------------------------------#
        x, H, W = self.patch_embed4.forward(x)
        for i, blk in enumerate(self.block4):
            x = blk.forward(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs
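# MixVisionTransformer.forward output sketch (illustrative): for a (B, 3, 512, 512) input with the default
# embed_dims=[32, 64, 160, 256], outs holds four feature maps of shapes (B, 32, 128, 128), (B, 64, 64, 64),
# (B, 160, 32, 32) and (B, 256, 16, 16), i.e. strides 4, 8, 16 and 32 of the input.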
class mit_b0(MixVisionTransformer):
    def __init__(self, pretrained=False):
        super(mit_b0, self).__init__(
            embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
        if pretrained:
            print("Load backbone weights")
            self.load_state_dict(torch.load("model_data/segformer_b0_backbone_weights.pth"), strict=False)


class mit_b1(MixVisionTransformer):
    def __init__(self, pretrained=False):
        super(mit_b1, self).__init__(
            embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
        if pretrained:
            print("Load backbone weights")
            self.load_state_dict(torch.load("model_data/segformer_b1_backbone_weights.pth"), strict=False)


class mit_b2(MixVisionTransformer):
    def __init__(self, pretrained=False):
        super(mit_b2, self).__init__(
            embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
        if pretrained:
            print("Load backbone weights")
            self.load_state_dict(torch.load("model_data/segformer_b2_backbone_weights.pth"), strict=False)


class mit_b3(MixVisionTransformer):
    def __init__(self, pretrained=False):
        super(mit_b3, self).__init__(
            embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
        if pretrained:
            print("Load backbone weights")
            self.load_state_dict(torch.load("model_data/segformer_b3_backbone_weights.pth"), strict=False)


class mit_b4(MixVisionTransformer):
    def __init__(self, pretrained=False):
        super(mit_b4, self).__init__(
            embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
        if pretrained:
            print("Load backbone weights")
            self.load_state_dict(torch.load("model_data/segformer_b4_backbone_weights.pth"), strict=False)


class mit_b5(MixVisionTransformer):
    def __init__(self, pretrained=False):
        super(mit_b5, self).__init__(
            embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
        if pretrained:
            print("Load backbone weights")
            self.load_state_dict(torch.load("model_data/segformer_b5_backbone_weights.pth"), strict=False)