# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch
from torch.nn import Softmax
import torch.nn.functional as F


def _reduced_channels(channels, ratio):
    """Return ``channels // ratio`` clamped to at least 1.

    Keeps squeeze/excitation style bottlenecks valid for narrow inputs, where
    plain integer division would otherwise yield a zero-channel convolution.
    """
    return max(1, channels // ratio)


def _init_conv_bn_weights(module):
    """Kaiming-normal init for every Conv2d; weight=1 / bias=0 for every BN/GN layer."""
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


# Basic depthwise-separable convolution block.
class conv_dw(nn.Module):
    """Depthwise 3x3 conv-BN-ReLU followed by pointwise 1x1 conv-BN-ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution.
    """

    def __init__(self, inp, oup, stride=1):
        super(conv_dw, self).__init__()
        self.basedw = nn.Sequential(
            nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
            nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.basedw(x)


class DepwithDoubleConv(nn.Module):
    """Two stacked depthwise-separable blocks: in -> mid -> out channels.

    ``mid_channels`` defaults to ``out_channels`` when not given (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv_dw = nn.Sequential(
            conv_dw(in_channels, mid_channels),
            conv_dw(mid_channels, out_channels),
        )

    def forward(self, x):
        return self.double_conv_dw(x)


class up_fusionblock(nn.Module):
    """Upsample the deeper feature, concatenate it after the skip feature, then conv-BN-ReLU.

    ``in_channels`` must equal ``C(x2) + C(x1)`` of the tensors given to ``forward``.
    """

    def __init__(self, in_channels, out_channels):
        super(up_fusionblock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # Original version: a single 3x3 conv-BN-ReLU.
        # Variants tried earlier: a 1x1 conv-BN-ReLU, or DepwithDoubleConv.
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        """Fuse ``x1`` (deeper, lower resolution) into the skip feature ``x2``.

        When ``x2`` is exactly twice ``x1``'s spatial size, the fixed 2x upsampler
        is used. Otherwise (e.g. odd input sizes) ``x1`` is resized directly to
        ``x2``'s size, so the channel concatenation cannot fail on a shape mismatch.
        """
        h, w = x1.shape[-2:]
        if tuple(x2.shape[-2:]) == (2 * h, 2 * w):
            x1 = self.up(x1)
        else:
            x1 = F.interpolate(x1, size=x2.shape[-2:], mode='bilinear', align_corners=False)
        x = torch.cat([x2, x1], dim=1)
        return self.cbr(x)


# Original CAB.
class CABy(nn.Module):
    """Original channel-attention fusion block.

    ``forward`` takes a ``(high, low)`` pair. ``high`` is resized to ``low``'s
    size, and channel weights computed from their concatenation rescale ``low``.
    ``high`` is then added. ``in_channels`` must be ``C(high) + C(low)``, and
    ``out_channels`` must broadcast against ``low`` for the multiply.
    """

    def __init__(self, in_channels, out_channels):
        super(CABy, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x1, x2 = x  # high, low
        x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x1, x2], dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2
        res = x2 + x1
        return res


class CAB(nn.Module):
    """Concatenate high/low features, apply conv-BN-ReLU, then channel attention.

    ``in_channels`` must equal ``C(high) + C(low)``. The attention branch acts on
    the ``out_channels`` produced by ``cbr``.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CAB, self).__init__()
        hidden = _reduced_channels(out_channels, ratio)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The attention weights multiply cbr's output, so they must be sized on
        # out_channels. Sizing them on in_channels crashed whenever in != out.
        self.conv1 = nn.Conv2d(out_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, high, low):
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        x0 = torch.cat([high, low], dim=1)
        x0 = self.cbr(x0)
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x


class CAM(nn.Module):
    """Channel attention: global average pool -> 1x1 bottleneck -> sigmoid gate on the input."""

    def __init__(self, in_channels, ratio=8):
        super(CAM, self).__init__()
        hidden = _reduced_channels(in_channels, ratio)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, in_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x0):
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x


class PAM_Module(nn.Module):
    """ Position attention module"""
    # Ref from SAGAN

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim
        qk_dim = _reduced_channels(in_dim, 8)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # gamma starts at 0, so the module is initially the identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X (HxW) X (HxW)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out


class PAM(nn.Module):
    """ Position attention module"""
    # Ref from SAGAN

    def __init__(self, in_channels):
        super(PAM, self).__init__()
        inter_channels = in_channels // 4
        self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(inter_channels),
                                    nn.ReLU())
        # conv5c is built but not used in forward(). Removing it would change
        # the state_dict keys, so it is kept.
        self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(inter_channels),
                                    nn.ReLU())

        self.sa = PAM_Module(inter_channels)

    def forward(self, x):
        feat1 = self.conv5a(x)  # reduce channels 4x (e.g. 2048 -> 512)
        sa_feat = self.sa(feat1)  # q/k reduced a further 8x inside PAM_Module; v keeps its width
        return sa_feat


def INF(B, H, W, device=None, dtype=None):
    """Return a ``(B*W, H, H)`` mask with ``-inf`` on each diagonal and 0 elsewhere.

    It is added to the vertical criss-cross energies so that a pixel's own
    position is counted only once (it is already in the horizontal branch).

    Args:
        B, H, W: batch size, height and width of the feature map.
        device: device for the mask. Defaults to CUDA when available, which
            preserves the previous GPU-only behaviour, else CPU.
        dtype: floating dtype of the mask (default: torch's default dtype).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    diag = torch.full((H,), float("inf"), device=device, dtype=dtype)
    return -torch.diag(diag, 0).unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""

    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        qk_dim = _reduced_channels(in_dim, 8)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        # gamma starts at 0, so the module is initially the identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Self-attention split into a column (H) branch and a row (W) branch.

        q/k energies are computed along each column and each row. Both are
        normalized together with one softmax, and the values are aggregated per
        direction and summed.
        """
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        energy_H = torch.bmm(proj_query_H, proj_key_H)
        # Build the diagonal mask on the same device/dtype as the energies so the
        # module also runs on CPU and under reduced precision.
        mask = self.INF(m_batchsize, height, width, device=energy_H.device, dtype=energy_H.dtype)
        energy_H = (energy_H + mask).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        return self.gamma * (out_H + out_W) + x


class CARB(nn.Module):
    """Conv-BN-ReLU (in -> out channels) followed by average-pool channel attention."""

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB, self).__init__()
        hidden = _reduced_channels(out_channels, ratio)
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x0 = self.cbr(x)
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x


class unetCAB0Decoder(nn.Module):
    """U-Net style decoder whose fusion steps are CAB blocks.

    Encoder (b1) stage channels noted by the author: 64 128 320 512.
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAB0Decoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.up_concat3 = CAB(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = CAB(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = CAB(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Example for a 512 input: x32 is 16x16x512.
        """
        x4, x8, x16, x32 = inputlist
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetCAM0Decoder(nn.Module):
    """U-Net style decoder applying a CAM to every encoder scale before fusion.

    Encoder (b1) stage channels noted by the author: 64 128 320 512 at 1/4 1/8 1/16 1/32.
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAM0Decoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.cam32 = CAM(512)
        self.cam16 = CAM(320)
        self.cam8 = CAM(128)
        self.cam4 = CAM(64)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Example for a 512 input: x32 is 16x16x512.
        """
        x4, x8, x16, x32 = inputlist
        # Each scale gets its own CAM. Applying cam4 (64 channels) to every scale
        # crashed on the wider x8/x16/x32 features.
        x4 = self.cam4(x4)
        x8 = self.cam8(x8)
        x16 = self.cam16(x16)
        x32 = self.cam32(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetDecoder(nn.Module):
    """Plain U-Net style decoder built from three up_fusionblocks.

    res50 stage outputs are 256 512 1024 2048 (strides 1/4 ... 1/32). If a 1x1
    conv-BN-ReLU reduced x32 to 512, the matching configuration would be
    in_filters=[512, 1024, 1536], out_filters=[128, 256, 512].
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Shapes noted by the author for a 512 input (H x W x C):
            mit:    x4 128x128x64, x8 64x64x128, x16 32x32x320, x32 16x16x512
                    -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
            resnet: x4 256, x8 512, x16 1024, x32 2048 channels
        """
        x4, x8, x16, x32 = inputlist
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetpamDecoder(nn.Module):
    """U-Net decoder with PAM on the deepest (2048-channel) feature.

    Channels after PAM: 256 512 1024 512.
    """

    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.pam = PAM(in_channels=2048)  # outputs 1/4 of its input channels
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Expects resnet-style channels: x4 256, x8 512, x16 1024, x32 2048.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.pam(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetpamCARBDecoder(nn.Module):
    """U-Net decoder with PAM + CARB2 on x32 and CARB2 channel reduction on the skips.

    resnet stage channels: 256 512 1024 2048; after PAM: 256 512 1024 512.
      option 1, CARB -> 128 256 512 512: in_filters=[256, 512, 1024], out_filters=[64, 128, 256]
      option 2, CARB -> 64 128 256 512:  in_filters=[192, 384, 768],  out_filters=[64, 128, 256]
    The CARB2 widths below implement option 2.
    """

    def __init__(self, nclass=2, in_filters=[192, 384, 768], out_filters=[64, 128, 256]):
        super(unetpamCARBDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Expects resnet-style channels: x4 256, x8 512, x16 1024, x32 2048.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x4 = self.carb4(x4)
        x8 = self.carb8(x8)
        x16 = self.carb16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetpamDecoderzuhe(nn.Module):
    """U-Net decoder combining PAM and CARB on the deepest feature.

    Channels after PAM: 256 512 1024 512. An alternative config noted by the
    author: in_filters=[512, 1024, 3072], out_filters=[128, 256, 512].
    """

    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamDecoderzuhe, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB(512, out_channels=512)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Expects resnet-style channels: x4 256, x8 512, x16 1024, x32 2048.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetpamcamDecoder(nn.Module):
    """U-Net decoder with PAM + CAM on x32 and a CAM on every skip feature."""

    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamcamDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.pam = PAM(in_channels=2048)
        self.cam32 = CAM(512)
        self.cam16 = CAM(in_channels=1024)
        self.cam8 = CAM(in_channels=512)
        self.cam4 = CAM(in_channels=256)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Expects resnet-style channels: x4 256, x8 512, x16 1024, x32 2048.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.cam32(self.pam(x32))
        x4 = self.cam4(x4)
        x8 = self.cam8(x8)
        x16 = self.cam16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class CARB2(nn.Module):
    """Like CARB, but the attention input is the sum of global average and max pooling."""

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB2, self).__init__()
        hidden = _reduced_channels(out_channels, ratio)
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
        self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x0 = self.cbr(x)
        xavg = self.global_avgpooling(x0)
        xmax = self.global_maxpooling(x0)
        x = xavg + xmax
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x


def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
    """Conv (same padding for odd kernels, no bias) -> norm -> ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                  padding=kernel_size // 2, bias=False),
        norm_layer(out_channels),
        nn.ReLU(inplace=True),
    )


############________________________________________________
class AtrousSeparableConvolution(nn.Module):
    """ Atrous Separable Convolution
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super(AtrousSeparableConvolution, self).__init__()
        self.body = nn.Sequential(
            # Separable Conv
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=in_channels),
            # PointWise Conv
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
        )

        self._init_weight()

    def forward(self, x):
        return self.body(x)

    def _init_weight(self):
        _init_conv_bn_weights(self)


class ASPPConv(nn.Sequential):
    """Dilated 3x3 conv-BN-ReLU branch of ASPP (padding = dilation keeps the size)."""

    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch: global pool, 1x1 conv-BN-ReLU, upsample back to input size."""

    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    """Atrous spatial pyramid pooling with a 1x1 branch, three dilated branches and image pooling.

    Always outputs 256 channels.
    """

    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        res = [conv(x) for conv in self.convs]
        res = torch.cat(res, dim=1)
        return self.project(res)


class unetasppDecoder(nn.Module):
    """U-Net decoder with ASPP (256-channel output) on x32 and 1x1 projections on the skips.

    NOTE(review): with the hard-coded projections below the fused channel counts
    are 512+256=768, 256+512=768 and 128+256=384. The default in_filters
    [512, 1024, 1536] do not match these, so callers must pass matching values
    (e.g. [384, 768, 768] with the default out_filters).
    """

    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetasppDecoder, self).__init__()
        self.nclas = nclass
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        # Skip features x4, x8, x16 arrive with 256, 512, 1024 channels; halve each.
        self.projectx4 = nn.Sequential(
            nn.Conv2d(256, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.projectx8 = nn.Sequential(
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.projectx16 = nn.Sequential(
            nn.Conv2d(1024, 512, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        # Attribute name kept as "pam" for state_dict compatibility; this is ASPP (256 channels out).
        self.pam = ASPP(in_channels=2048, atrous_rates=[5, 7, 11])
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Expects resnet-style channels: x4 256, x8 512, x16 1024, x32 2048.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.pam(x32)
        x16 = self.projectx16(x16)
        x8 = self.projectx8(x8)
        x4 = self.projectx4(x4)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


# Another fusion idea: concatenate the features, pass them through an SE block
# so the network adaptively weights/selects features, then reduce channels.
#########################################################################
# One last attempt at this MANet variant.
class unetDecoder2(nn.Module):
    """U-Net decoder that first halves every res50 stage with a 1x1 conv-BN-ReLU.

    res50 stage outputs 256 512 1024 2048 (1/4 ... 1/32) become 128 256 512 1024,
    giving in_filters=[384, 768, 1536], out_filters=[128, 256, 512].
    """

    def __init__(self, nclass=2, in_filters=[384, 768, 1536], out_filters=[128, 256, 512]):
        super(unetDecoder2, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        # Copy so instances never share (or mutate) the default lists.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.transx32 = nn.Sequential(
            nn.Conv2d(2048, 1024, 1, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )
        self.transx16 = nn.Sequential(
            nn.Conv2d(1024, 512, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        self.transx8 = nn.Sequential(
            nn.Conv2d(512, 256, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.transx4 = nn.Sequential(
            nn.Conv2d(256, 128, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder从初始化执行")
        _init_conv_bn_weights(self)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into logits at 2x the resolution of ``x4``.

        Expects resnet-style channels: x4 256, x8 512, x16 1024, x32 2048.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.transx32(x32)
        x16 = self.transx16(x16)
        x8 = self.transx8(x8)
        x4 = self.transx4(x4)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out