# --------------------------------------------------------------- # Copyright (c) 2021, NVIDIA Corporation. All rights reserved. # # This work is licensed under the NVIDIA Source Code License # --------------------------------------------------------------- '''配合下采样不是32倍数的''' import numpy as np import torch.nn as nn import torch import torch.nn.functional as F '' #深度可分离卷积基本模块 class conv_dw(nn.Module): def __init__(self,inp, oup, stride = 1): super(conv_dw, self).__init__() self.basedw=nn.Sequential( nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), nn.BatchNorm2d(inp), nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU(inplace=True), ) def forward(self,x): return self.basedw(x) class DepwithDoubleConv(nn.Module): def __init__(self, in_channels, out_channels, mid_channels=None): super().__init__() if not mid_channels: mid_channels = out_channels self.double_conv_dw = nn.Sequential( conv_dw(in_channels,mid_channels), conv_dw(mid_channels,out_channels) ) def forward(self,x): return self.double_conv_dw(x) class up_fusionblock(nn.Module): def __init__(self, in_channels, out_channels): super(up_fusionblock, self).__init__() self.in_channels= in_channels self.out_channels=out_channels self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) self.cbr=nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x1, x2): x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) return self.cbr(x) class fusionblock(nn.Module): def __init__(self, in_channels, out_channels): super(fusionblock, self).__init__() self.in_channels= in_channels self.out_channels=out_channels self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) self.cbr=nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x1, x2): x = torch.cat([x2, x1], dim=1) return self.cbr(x) '''原版cab''' class CABy(nn.Module): def __init__(self, in_channels, out_channels): super(CABy, self).__init__() self.global_pooling = nn.AdaptiveAvgPool2d(1) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0) self.sigmod = nn.Sigmoid() def forward(self, x): x1, x2 = x # high, low x1=F.interpolate(x1,size=x2.shape[2:],mode='bilinear', align_corners=False) x = torch.cat([x1,x2],dim=1) x = self.global_pooling(x) x = self.conv1(x) x = self.relu(x) x = self.conv2(x) x = self.sigmod(x) x2 = x * x2 res = x2 + x1 return res class CAB(nn.Module): def __init__(self, in_channels,out_channels,ratio=8): super(CAB, self).__init__() self.global_pooling = nn.AdaptiveAvgPool2d(1) self.conv1 = nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, stride=1, padding=0) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, stride=1, padding=0) self.sigmod = nn.Sigmoid() self.cbr = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, high,low): high=F.interpolate(high,size=low.shape[2:],mode='bilinear', align_corners=False) x0 = torch.cat([high,low],dim=1) x0=self.cbr(x0) x = self.global_pooling(x0) x = self.conv1(x) x = self.relu(x) x = self.conv2(x) x = self.sigmod(x) x = x * x0 return x class CAM(nn.Module): def __init__(self, in_channels,ratio=8): super(CAM, self).__init__() self.global_pooling = nn.AdaptiveAvgPool2d(1) self.conv1 = nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, stride=1, padding=0) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, stride=1, padding=0) self.sigmod = nn.Sigmoid() def forward(self, x0): x = self.global_pooling(x0) x = self.conv1(x) x = self.relu(x) x = self.conv2(x) x = self.sigmod(x) x = x * x0 return x class PAM_Module(nn.Module): """ Position attention module""" #Ref from SAGAN def __init__(self, in_dim): super(PAM_Module, self).__init__() self.chanel_in = in_dim self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) self.gamma = nn.Parameter(torch.zeros(1)) self.softmax = nn.Softmax(dim=-1) def forward(self, x): """ inputs : x : input feature maps( B X C X H X W) returns : out : attention value + input feature attention: B X (HxW) X (HxW) """ m_batchsize, C, height, width = x.size() proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) energy = torch.bmm(proj_query, proj_key) attention = self.softmax(energy) proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) out = torch.bmm(proj_value, attention.permute(0, 2, 1)) out = out.view(m_batchsize, C, height, width) out = self.gamma*out + x return out class PAM(nn.Module): """ Position attention module""" #Ref from SAGAN def __init__(self, in_channels): super(PAM, self).__init__() inter_channels = in_channels // 4 self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), nn.BatchNorm2d(inter_channels), nn.ReLU()) self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), nn.BatchNorm2d(inter_channels), nn.ReLU()) self.sa = PAM_Module(inter_channels) def forward(self,x): feat1 = self.conv5a(x) # 降维4倍维度变成512 sa_feat = self.sa(feat1) # q和k的维度再降8倍 value不变 #512*64*64 return sa_feat class CARB(nn.Module): def __init__(self, in_channels,out_channels,ratio=8): super(CARB, self).__init__() self.cbr = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) self.global_pooling = nn.AdaptiveAvgPool2d(1) self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0) self.sigmod = nn.Sigmoid() def forward(self, x): x0=self.cbr(x) x = self.global_pooling(x0) x = self.conv1(x) x = self.relu(x) x = self.conv2(x) x = self.sigmod(x) x = x * x0 return x class unetCAB0Decoder(nn.Module): def __init__(self,nclass=2,in_filters = [192, 448, 832],out_filters = [64, 128, 320]): super(unetCAB0Decoder, self).__init__() '''b1的中间层输出: 64 128 320 512 ''' self.nclas=nclass self.finnal_channel=512 self.in_filters = in_filters self.out_filters = out_filters self.up_concat3=CAB(in_channels=self.in_filters[2],out_channels= self.out_filters[2]) self.up_concat2=CAB(in_channels=self.in_filters[1],out_channels= self.out_filters[1]) self.up_concat1=CAB(in_channels=self.in_filters[0],out_channels= self.out_filters[0]) self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1) self._init_weight() def _init_weight(self): print("decoder从初始化执行") for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self,inputlist): #512为例 '''x:32被下采样的 16*16*512 组成:cbr(conv1*1 降维)+ 上采样+concate+ ''' x4,x8,x16,x32=inputlist x16= self.up_concat3(x32,x16) x8= self.up_concat2(x16,x8) x4= self.up_concat1(x8,x4) x4=self.classifer(x4) out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) return out class unetCAM0Decoder(nn.Module): def __init__(self,nclass=2,in_filters = [192, 448, 832],out_filters = [64, 128, 320]): super(unetCAM0Decoder, self).__init__() '''b1的中间层输出: 64 128 320 512 1/4 1/8 1/16 1/32 ''' self.nclas=nclass self.finnal_channel=512 self.in_filters = in_filters self.out_filters = out_filters self.cam32 = CAM(512) self.cam16 = CAM(320) self.cam8 = CAM(128) self.cam4 = CAM(64) self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2]) self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1]) self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0]) self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1) self._init_weight() def _init_weight(self): print("decoder从初始化执行") for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self,inputlist): #512为例 '''x:32被下采样的 16*16*512 组成:cbr(conv1*1 降维)+ 上采样+concate+ ''' x4,x8,x16,x32=inputlist x4=self.cam4(x4) x8=self.cam4(x8) x16=self.cam4(x16) x32=self.cam4(x32) x16= self.up_concat3(x32,x16) x8= self.up_concat2(x16,x8) x4= self.up_concat1(x8,x4) x4=self.classifer(x4) out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) return out class unetDecoder(nn.Module): def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]): super(unetDecoder, self).__init__() '''mitb1的中间层输出: 64 128 320 512 renet50 256 512 1024 2048 ''' self.nclas = nclass self.finnal_channel = 512 self.in_filters = in_filters self.out_filters = out_filters self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2]) self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1]) self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0]) self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1) self._init_weight() def _init_weight(self): print("decoder从初始化执行") for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, inputlist): # 512为例 ''' mit x4:128 128 64 x8:64 64 128 x16:32 32 320 x32: 16 16 512 in_filters=[192, 448, 832],out_filters = [64, 128, 320]) resnet x4:256 128 128 x8:512 64 128 x16:1024 32 320 x32: 2048 16 512 ''' x4, x8, x16, x32 = inputlist x16 = self.up_concat3(x32, x16) x8 = self.up_concat2(x16, x8) x4 = self.up_concat1(x8, x4) x4 = self.classifer(x4) out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) return out class unetpamDecoder(nn.Module): def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]): super(unetpamDecoder, self).__init__() #in_filters=[ 512, 1024, 3072],out_filters=[128, 256, 512] ''' 256 512 1024 512 ''' self.nclas = nclass self.finnal_channel = 512 self.in_filters = in_filters self.out_filters = out_filters self.pam = PAM(in_channels=2048) self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2]) self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1]) self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0]) self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1) self._init_weight() def _init_weight(self): print("decoder从初始化执行") for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, inputlist): # 512为例 ''' mit x4:128 128 64 x8:64 64 128 x16:32 32 320 x32: 16 16 512 in_filters=[192, 448, 832],out_filters = [64, 128, 320]) resnet x4:256 128 128 x8:512 64 128 x16:1024 32 320 x32: 2048 16 512 ''' x4, x8, x16, x32 = inputlist x32 = self.pam(x32) x16 = self.up_concat3(x32, x16) x8 = self.up_concat2(x16, x8) x4 = self.up_concat1(x8, x4) x4 = self.classifer(x4) out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) return out class unetpamCARBDecoder(nn.Module): def __init__(self,nclass=2, in_filters=[ 192, 384, 768],out_filters=[64, 128, 256]): super(unetpamCARBDecoder, self).__init__() self.nclas=nclass self.finnal_channel=512 self.in_filters = in_filters self.out_filters = out_filters self.pam=PAM(in_channels=2048) self.carb32=CARB2(512,out_channels=512) self.carb16 = CARB2(in_channels=1024,out_channels=256) self.carb8 = CARB2(in_channels=512,out_channels=128) self.carb4 = CARB2(in_channels=256,out_channels=64) ''' resnet 256 512 1024 2048 pam后变成了256 512 1024 512 carb--1.>128 256 512 512 carb--2..>64 128 256 512 1.in_filters=[ 256, 512, 1024],out_filters=[64, 128, 256] 2.in_filters=[ 192, 384, 768],out_filters=[64, 128, 256] ''' self.up_concat3=fusionblock(in_channels=self.in_filters[2],out_channels= self.out_filters[2]) self.up_concat2=up_fusionblock(in_channels=self.in_filters[1],out_channels= self.out_filters[1]) self.up_concat1=up_fusionblock(in_channels=self.in_filters[0],out_channels= self.out_filters[0]) self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1) self._init_weight() def _init_weight(self): print("decoder从初始化执行") for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self,inputlist): #512为例 ''' mit x4:128 128 64 x8:64 64 128 x16:32 32 320 x32: 16 16 512 in_filters=[192, 448, 832],out_filters = [64, 128, 320]) resnet x4:256 128 128 x8:512 64 128 x16:1024 32 320 x32: 2048 16 512 ''' x4,x8,x16,x32=inputlist x32=self.carb32(self.pam(x32)) x4=self.carb4(x4) x8=self.carb8(x8) x16=self.carb16(x16) x16= self.up_concat3(x32,x16) x8= self.up_concat2(x16,x8) x4= self.up_concat1(x8,x4) x4=self.classifer(x4) out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) return out class unetpamDecoderzuhe(nn.Module): def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]): super(unetpamDecoderzuhe, self).__init__() #in_filters=[ 512, 1024, 3072],out_filters=[128, 256, 512] ''' 256 512 1024 512 ''' self.nclas = nclass self.finnal_channel = 512 self.in_filters = in_filters self.out_filters = out_filters self.pam = PAM(in_channels=2048) self.carb32 = CARB(512, out_channels=512) self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2]) self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1]) self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0]) self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1) self._init_weight() def _init_weight(self): print("decoder从初始化执行") for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, inputlist): # 512为例 ''' mit x4:128 128 64 x8:64 64 128 x16:32 32 320 x32: 16 16 512 in_filters=[192, 448, 832],out_filters = [64, 128, 320]) resnet x4:256 128 128 x8:512 64 128 x16:1024 32 320 x32: 2048 16 512 ''' x4, x8, x16, x32 = inputlist x32 = self.carb32(self.pam(x32)) x16 = self.up_concat3(x32, x16) x8 = self.up_concat2(x16, x8) x4 = self.up_concat1(x8, x4) x4 = self.classifer(x4) out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) return out class unetpamcamDecoder(nn.Module): def __init__(self, nclass=2,in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]): super(unetpamcamDecoder, self).__init__() self.nclas = nclass self.finnal_channel = 512 self.in_filters = in_filters self.out_filters = out_filters self.pam = PAM(in_channels=2048) self.cam32 = CAM(512) self.cam16 = CAM(in_channels=1024,) self.cam8 = CAM(in_channels=512,) self.cam4 = CAM(in_channels=256,) self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2]) self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1]) self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0]) self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1) self._init_weight() def _init_weight(self): print("decoder从初始化执行") for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, inputlist): # 512为例 ''' mit x4:128 128 64 x8:64 64 128 x16:32 32 320 x32: 16 16 512 in_filters=[192, 448, 832],out_filters = [64, 128, 320]) resnet x4:256 128 128 x8:512 64 128 x16:1024 32 320 x32: 2048 16 512 ''' x4, x8, x16, x32 = inputlist x32 = self.cam32(self.pam(x32)) x4 = self.cam4(x4) x8 = self.cam8(x8) x16 = self.cam16(x16) x16 = self.up_concat3(x32, x16) x8 = self.up_concat2(x16, x8) x4 = self.up_concat1(x8, x4) x4 = self.classifer(x4) out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) return out class CARB2(nn.Module): def __init__(self, in_channels,out_channels,ratio=8): super(CARB2, self).__init__() self.cbr = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) self.global_avgpooling = nn.AdaptiveAvgPool2d(1) self.global_maxpooling = nn.AdaptiveMaxPool2d(1) self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0) self.sigmod = nn.Sigmoid() def forward(self, x): x0=self.cbr(x) xavg = self.global_avgpooling(x0) xmax = self.global_maxpooling(x0) x=xavg+xmax x = self.conv1(x) x = self.relu(x) x = self.conv2(x) x = self.sigmod(x) x = x * x0 return x def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d): return nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, bias=False), norm_layer(out_channels), nn.ReLU(inplace=True) ) '''另一种融合 concate 然后扔给一个se 目的是对不同的特征加权自适应的选择不同的特征后 ,然后降维 '''