716 lines
23 KiB
Python
716 lines
23 KiB
Python
|
# ---------------------------------------------------------------
|
|||
|
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
|
|||
|
#
|
|||
|
# This work is licensed under the NVIDIA Source Code License
|
|||
|
# ---------------------------------------------------------------
|
|||
|
|
|||
|
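'''Decoder blocks meant to pair with backbones whose downsampling is not a multiple of 32 (original note: 配合下采样不是32倍数的).'''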
|
|||
|
import numpy as np
|
|||
|
import torch.nn as nn
|
|||
|
import torch
|
|||
|
|
|||
|
import torch.nn.functional as F
|
|||
|
|
|||
|
# Basic depthwise-separable convolution module.
class conv_dw(nn.Module):
    """Depthwise-separable convolution block.

    A 3x3 depthwise conv-BN-ReLU (padding 1, configurable stride) followed by
    a 1x1 pointwise conv-BN-ReLU. Neither conv has a bias.

    Args:
        inp: input channel count (also the depthwise group count).
        oup: output channel count of the pointwise conv.
        stride: stride of the depthwise 3x3 conv.
    """

    def __init__(self, inp, oup, stride=1):
        super().__init__()
        depthwise = [
            nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
        ]
        pointwise = [
            nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        ]
        # One flat Sequential so state_dict keys stay basedw.0 ... basedw.5.
        self.basedw = nn.Sequential(*depthwise, *pointwise)

    def forward(self, x):
        out = self.basedw(x)
        return out
|
|||
|
|
|||
|
class DepwithDoubleConv(nn.Module):
    """Two stacked stride-1 depthwise-separable convs (``conv_dw``).

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        mid_channels: channels between the two convs. Any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        first = conv_dw(in_channels, mid_channels)
        second = conv_dw(mid_channels, out_channels)
        self.double_conv_dw = nn.Sequential(first, second)

    def forward(self, x):
        return self.double_conv_dw(x)
|
|||
|
|
|||
|
|
|||
|
class up_fusionblock(nn.Module):
    """Upsample-and-fuse decoder step.

    ``forward(x1, x2)`` upsamples the coarser map ``x1`` by 2x (bilinear),
    concatenates ``[x2, x1]`` along channels and applies a 3x3 conv-BN-ReLU.

    Args:
        in_channels: channels of ``x2`` plus channels of ``x1``.
        out_channels: output channel count.
    """

    def __init__(self, in_channels, out_channels):
        super(up_fusionblock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # With odd feature-map sizes, a plain 2x upsample can miss the skip
        # map's size by a pixel and make torch.cat fail. In that case, resize
        # to the skip's exact size. When sizes already match, this branch is
        # skipped and the result is unchanged.
        if x1.shape[2:] != x2.shape[2:]:
            x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x2, x1], dim=1)
        return self.cbr(x)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
class fusionblock(nn.Module):
    """Fuse two same-resolution maps: concat ``[x2, x1]`` then 3x3 conv-BN-ReLU.

    ``forward`` does not upsample. ``self.up`` is constructed but never used
    here, so both inputs must already share their spatial size.

    Args:
        in_channels: channels of ``x2`` plus channels of ``x1``.
        out_channels: output channel count.
    """

    def __init__(self, in_channels, out_channels):
        super(fusionblock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        self.cbr = nn.Sequential(conv, norm, nn.ReLU(inplace=True))

    def forward(self, x1, x2):
        fused = torch.cat((x2, x1), dim=1)
        return self.cbr(fused)
|
|||
|
|
|||
|
# Original CAB (channel attention block) variant.
class CABy(nn.Module):
    """Channel-attention fusion of a high-level and a low-level map.

    ``forward((high, low))`` resizes ``high`` to ``low``'s spatial size, then
    does the following:

    * computes a channel gate from the global average of ``concat[high, low]``
      (1x1 conv, ReLU, 1x1 conv, sigmoid);
    * returns ``gate * low + high``.

    Args:
        in_channels: channels of ``high`` plus channels of ``low``.
        out_channels: gate channels. Must equal ``low``'s channel count (it is
            multiplied with ``low``), and ``high`` must match too (it is added).
    """

    def __init__(self, in_channels, out_channels):
        super(CABy, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        high, low = x
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        gate = self.global_pooling(torch.cat([high, low], dim=1))
        gate = self.sigmod(self.conv2(self.relu(self.conv1(gate))))
        return gate * low + high
|
|||
|
|
|||
|
class CAB(nn.Module):
    """Channel-attention fusion block.

    Steps performed by ``forward(high, low)``:

    1. Bilinearly resize ``high`` to ``low``'s spatial size.
    2. Concatenate ``[high, low]`` along channels.
    3. Project to ``out_channels`` with a 3x3 conv-BN-ReLU.
    4. Reweight the projected channels with a squeeze-and-excitation style
       gate (global average pool, 1x1 conv, ReLU, 1x1 conv, sigmoid).

    Args:
        in_channels: channels of ``high`` plus channels of ``low``.
        out_channels: channel count of the block's output.
        ratio: channel reduction factor inside the attention bottleneck.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CAB, self).__init__()

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The gate is applied to the output of ``self.cbr``, which has
        # ``out_channels`` channels, so the bottleneck is sized from
        # ``out_channels``. Sizing it from ``in_channels`` raised a channel
        # mismatch whenever in_channels != out_channels. When the two are
        # equal, the parameter shapes are unchanged.
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, high, low):
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        x0 = self.cbr(torch.cat([high, low], dim=1))
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        return x * x0
|
|||
|
|
|||
|
class CAM(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Computes global average pool, then 1x1 conv (C -> C // ratio), ReLU,
    1x1 conv (C // ratio -> C) and sigmoid. The input is then scaled
    channel-wise by the resulting gate.

    Args:
        in_channels: channel count C of the input (and of the output).
        ratio: bottleneck reduction factor.
    """

    def __init__(self, in_channels, ratio=8):
        super(CAM, self).__init__()
        hidden = in_channels // ratio
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, in_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x0):
        weights = self.global_pooling(x0)
        for layer in (self.conv1, self.relu, self.conv2, self.sigmod):
            weights = layer(weights)
        return weights * x0
|
|||
|
|
|||
|
class PAM_Module(nn.Module):
    """Position attention module: self-attention over spatial positions.

    Adapted from SAGAN. The query and key 1x1 projections reduce channels to
    ``in_dim // 8``, and the value projection keeps ``in_dim`` channels. The
    attended features are scaled by the learnable scalar ``gamma`` and added
    to the input. ``gamma`` is initialised to zero, so the module starts as
    an identity.
    """

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim
        reduced = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply position attention.

        Args:
            x: feature map of shape (B, C, H, W).

        Returns:
            ``gamma * attended + x``, with the same shape as ``x``. The
            (B, H*W, H*W) attention map is computed internally but not
            returned.
        """
        batch, channels, height, width = x.size()
        positions = height * width
        query = self.query_conv(x).view(batch, -1, positions).permute(0, 2, 1)  # B x N x C'
        key = self.key_conv(x).view(batch, -1, positions)                       # B x C' x N
        attention = self.softmax(torch.bmm(query, key))                         # B x N x N
        value = self.value_conv(x).view(batch, -1, positions)                   # B x C x N
        attended = torch.bmm(value, attention.permute(0, 2, 1))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x
|
|||
|
|
|||
|
|
|||
|
class PAM(nn.Module):
    """Position attention head (ref: SAGAN).

    ``forward`` applies a 3x3 conv-BN-ReLU (``conv5a``) that reduces channels
    by 4x, then a ``PAM_Module`` on the reduced map.

    NOTE(review): ``conv5c`` is built, so it owns parameters and appears in
    the state dict, but ``forward`` never uses it. It is kept for checkpoint
    compatibility. Under DistributedDataParallel this may require
    ``find_unused_parameters=True`` — confirm.
    """

    def __init__(self, in_channels):
        super(PAM, self).__init__()
        inter_channels = in_channels // 4
        self.conv5a = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )
        self.conv5c = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )
        self.sa = PAM_Module(inter_channels)

    def forward(self, x):
        # Channels: in_channels -> in_channels // 4 (e.g. 2048 -> 512).
        # Inside PAM_Module, query/key are reduced a further 8x; value keeps its width.
        reduced = self.conv5a(x)
        return self.sa(reduced)
|
|||
|
class CARB(nn.Module):
    """Conv + channel-attention refinement block.

    A 3x3 conv-BN-ReLU maps ``in_channels`` to ``out_channels``. The result
    is then reweighted channel-wise by a squeeze-and-excitation gate: global
    average pool, 1x1 conv (out -> out // ratio), ReLU, 1x1 conv back to
    ``out``, sigmoid.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        ratio: bottleneck reduction factor of the gate.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB, self).__init__()
        hidden = out_channels // ratio
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        features = self.cbr(x)
        gate = self.sigmod(self.conv2(self.relu(self.conv1(self.global_pooling(features)))))
        return gate * features
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
class unetCAB0Decoder(nn.Module):
    """U-Net style decoder whose three fusion steps are ``CAB`` blocks.

    ``forward`` takes ``[x4, x8, x16, x32]`` and fuses from coarsest to
    finest. It then applies a 1x1 classifier and a 2x bilinear upsample.

    The defaults match the original note on MiT-b1 stage widths 64/128/320/512:
    832 = 512 + 320, 448 = 320 + 128, 192 = 128 + 64.

    Args:
        nclass: number of output classes.
        in_filters: input channels of the fusion steps, finest first.
        out_filters: output channels of the fusion steps, finest first.
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAB0Decoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not read anywhere in this class
        self.in_filters = in_filters
        self.out_filters = out_filters

        # Registered coarsest-first, so self.modules() (and hence init) order is fixed.
        self.up_concat3 = CAB(in_channels=in_filters[2], out_channels=out_filters[2])
        self.up_concat2 = CAB(in_channels=in_filters[1], out_channels=out_filters[1])
        self.up_concat1 = CAB(in_channels=in_filters[0], out_channels=out_filters[0])
        self.classifer = nn.Conv2d(out_filters[0], nclass, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Initialise weights of all submodules.

        Every Conv2d weight gets Kaiming-normal init, and BatchNorm/GroupNorm
        layers get weight 1 and bias 0. Conv biases keep PyTorch's default
        init.
        """
        # The message below means "decoder initialization executed".
        print("decoder从初始化执行")
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into class logits.

        Original shape note (512x512 input): x32 is 16x16 with 512 channels.

        NOTE(review): the logits are upsampled only 2x from x4's resolution.
        If x4 is at 1/4 of the input, the output is at 1/2 — confirm callers
        expect this.
        """
        x4, x8, x16, x32 = inputlist
        d16 = self.up_concat3(x32, x16)
        d8 = self.up_concat2(d16, x8)
        d4 = self.up_concat1(d8, x4)
        logits = self.classifer(d4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
|
|||
|
|
|||
|
class unetCAM0Decoder(nn.Module):
    """U-Net style decoder with per-scale ``CAM`` channel attention.

    Each input scale is first reweighted by its own ``CAM``. The scales are
    then fused coarsest-to-finest with ``up_fusionblock``. The output is
    produced by a 1x1 classifier and a 2x bilinear upsample.

    Original note on MiT-b1 intermediate outputs: 64/128/320/512 channels at
    1/4, 1/8, 1/16 and 1/32 resolution. The CAM widths and default filters
    follow it: 832 = 512 + 320, 448 = 320 + 128, 192 = 128 + 64.

    Args:
        nclass: number of output classes.
        in_filters: input channels of the fusion steps, finest first.
        out_filters: output channels of the fusion steps, finest first.
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAM0Decoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not read anywhere in this class
        self.in_filters = in_filters
        self.out_filters = out_filters

        self.cam32 = CAM(512)
        self.cam16 = CAM(320)
        self.cam8 = CAM(128)
        self.cam4 = CAM(64)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Initialise weights of all submodules.

        Every Conv2d weight gets Kaiming-normal init, and BatchNorm/GroupNorm
        layers get weight 1 and bias 0. Conv biases keep PyTorch's default
        init.
        """
        # The message below means "decoder initialization executed".
        print("decoder从初始化执行")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into class logits.

        Original shape note (512x512 input): x32 is 16x16 with 512 channels.

        NOTE(review): the logits are upsampled only 2x from x4's resolution.
        If x4 is at 1/4 of the input, the output is at 1/2 — confirm callers
        expect this.
        """
        x4, x8, x16, x32 = inputlist
        # Each scale goes through the CAM whose width matches its channels.
        # Previously every scale went through cam4 = CAM(64). That raises a
        # channel mismatch on the 128/320/512-channel maps and leaves
        # cam8/cam16/cam32 unused.
        x4 = self.cam4(x4)
        x8 = self.cam8(x8)
        x16 = self.cam16(x16)
        x32 = self.cam32(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out
|
|||
|
|
|||
|
|
|||
|
class unetDecoder(nn.Module):
    """Plain U-Net style decoder built from three ``up_fusionblock`` steps.

    ``forward`` takes ``[x4, x8, x16, x32]`` and fuses from coarsest to
    finest. It then applies a 1x1 classifier and a 2x bilinear upsample.

    Channel notes from the original author:

    * MiT-b1 stages have 64/128/320/512 channels. The defaults use them:
      832 = 512 + 320, 448 = 320 + 128, 192 = 128 + 64.
    * ResNet-50 stages have 256/512/1024/2048 channels.

    Args:
        nclass: number of output classes.
        in_filters: input channels of the fusion steps, finest first.
        out_filters: output channels of the fusion steps, finest first.
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not read anywhere in this class
        self.in_filters = in_filters
        self.out_filters = out_filters

        # Registered coarsest-first, so self.modules() (and hence init) order is fixed.
        self.up_concat3 = up_fusionblock(in_channels=in_filters[2], out_channels=out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=in_filters[1], out_channels=out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=in_filters[0], out_channels=out_filters[0])
        self.classifer = nn.Conv2d(out_filters[0], nclass, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Initialise weights of all submodules.

        Every Conv2d weight gets Kaiming-normal init, and BatchNorm/GroupNorm
        layers get weight 1 and bias 0. Conv biases keep PyTorch's default
        init.
        """
        # The message below means "decoder initialization executed".
        print("decoder从初始化执行")
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into class logits.

        Original shape notes for a 512x512 input (MiT), given as
        channels x H x W:

        * x4: 64 x 128 x 128
        * x8: 128 x 64 x 64
        * x16: 320 x 32 x 32
        * x32: 512 x 16 x 16

        With ResNet the channel counts are 256/512/1024/2048.

        NOTE(review): the logits are upsampled only 2x from x4's resolution.
        If x4 is at 1/4 of the input, the output is at 1/2 — confirm callers
        expect this.
        """
        x4, x8, x16, x32 = inputlist
        d16 = self.up_concat3(x32, x16)
        d8 = self.up_concat2(d16, x8)
        d4 = self.up_concat1(d8, x4)
        logits = self.classifer(d4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
|
|||
|
|
|||
|
class unetpamDecoder(nn.Module):
    """U-Net style decoder with position attention on the coarsest map.

    ``PAM(in_channels=2048)`` reduces x32 from 2048 to 512 channels. Three
    ``up_fusionblock`` steps then fuse coarsest-to-finest, followed by a 1x1
    classifier and a 2x bilinear upsample.

    The original note gives channels after PAM as 256/512/1024/512
    (ResNet-style inputs), so the defaults are:

    * 1536 = 512 + 1024
    * 1024 = 512 + 512
    * 512 = 256 + 256

    An alternative configuration noted by the author is
    in_filters=[512, 1024, 3072], out_filters=[128, 256, 512].

    Args:
        nclass: number of output classes.
        in_filters: input channels of the fusion steps, finest first.
        out_filters: output channels of the fusion steps, finest first.
    """

    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not read anywhere in this class
        self.in_filters = in_filters
        self.out_filters = out_filters

        self.pam = PAM(in_channels=2048)
        self.up_concat3 = up_fusionblock(in_channels=in_filters[2], out_channels=out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=in_filters[1], out_channels=out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=in_filters[0], out_channels=out_filters[0])
        self.classifer = nn.Conv2d(out_filters[0], nclass, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Initialise weights of all submodules.

        Every Conv2d weight gets Kaiming-normal init, and BatchNorm/GroupNorm
        layers get weight 1 and bias 0. Conv biases keep PyTorch's default
        init.
        """
        # The message below means "decoder initialization executed".
        print("decoder从初始化执行")
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into class logits.

        Original notes: with ResNet inputs the channels are 256/512/1024/2048
        and x32 is 16x16 for a 512x512 image.

        NOTE(review): the logits are upsampled only 2x from x4's resolution.
        If x4 is at 1/4 of the input, the output is at 1/2 — confirm callers
        expect this.
        """
        x4, x8, x16, x32 = inputlist
        attended = self.pam(x32)
        d16 = self.up_concat3(attended, x16)
        d8 = self.up_concat2(d16, x8)
        d4 = self.up_concat1(d8, x4)
        logits = self.classifer(d4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
|
|||
|
|
|||
|
|
|||
|
class unetpamCARBDecoder(nn.Module):
    """Decoder with PAM on x32 and ``CARB2`` channel refinement on every scale.

    Channel flow from the original notes (ResNet inputs):

    * backbone: 256 / 512 / 1024 / 2048
    * after PAM: 256 / 512 / 1024 / 512
    * after CARB2, configuration 1: 128 / 256 / 512 / 512
    * after CARB2, configuration 2 (the one built here): 64 / 128 / 256 / 512

    Matching filter settings noted by the author:

    1. in_filters=[256, 512, 1024], out_filters=[64, 128, 256]
    2. in_filters=[192, 384, 768], out_filters=[64, 128, 256] (the defaults)

    The coarsest fusion uses ``fusionblock``, which does NOT upsample. x32 and
    x16 must therefore share a spatial size. Presumably the backbone's last
    stage keeps 1/16 resolution (e.g. dilated) — confirm.

    Args:
        nclass: number of output classes.
        in_filters: input channels of the fusion steps, finest first.
        out_filters: output channels of the fusion steps, finest first.
    """

    def __init__(self, nclass=2, in_filters=[192, 384, 768], out_filters=[64, 128, 256]):
        super(unetpamCARBDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not read anywhere in this class
        self.in_filters = in_filters
        self.out_filters = out_filters

        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)

        self.up_concat3 = fusionblock(in_channels=in_filters[2], out_channels=out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=in_filters[1], out_channels=out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=in_filters[0], out_channels=out_filters[0])
        self.classifer = nn.Conv2d(out_filters[0], nclass, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Initialise weights of all submodules.

        Every Conv2d weight gets Kaiming-normal init, and BatchNorm/GroupNorm
        layers get weight 1 and bias 0. Conv biases keep PyTorch's default
        init.
        """
        # The message below means "decoder initialization executed".
        print("decoder从初始化执行")
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into class logits.

        NOTE(review): the logits are upsampled only 2x from x4's resolution.
        If x4 is at 1/4 of the input, the output is at 1/2 — confirm callers
        expect this.
        """
        x4, x8, x16, x32 = inputlist
        r32 = self.carb32(self.pam(x32))
        r4 = self.carb4(x4)
        r8 = self.carb8(x8)
        r16 = self.carb16(x16)
        d16 = self.up_concat3(r32, r16)
        d8 = self.up_concat2(d16, r8)
        d4 = self.up_concat1(d8, r4)
        logits = self.classifer(d4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
|
|||
|
|
|||
|
|
|||
|
class unetpamDecoderzuhe(nn.Module):
    """Combined decoder: PAM then ``CARB`` on x32, then U-Net style fusion.

    ``PAM(in_channels=2048)`` reduces x32 to 512 channels, and
    ``CARB(512, 512)`` refines it. Three ``up_fusionblock`` steps then fuse
    coarsest-to-finest, followed by a 1x1 classifier and a 2x bilinear
    upsample.

    The original note gives channels after PAM as 256/512/1024/512, so the
    defaults are:

    * 1536 = 512 + 1024
    * 1024 = 512 + 512
    * 512 = 256 + 256

    An alternative configuration noted by the author is
    in_filters=[512, 1024, 3072], out_filters=[128, 256, 512].

    Args:
        nclass: number of output classes.
        in_filters: input channels of the fusion steps, finest first.
        out_filters: output channels of the fusion steps, finest first.
    """

    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamDecoderzuhe, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not read anywhere in this class
        self.in_filters = in_filters
        self.out_filters = out_filters

        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB(512, out_channels=512)
        self.up_concat3 = up_fusionblock(in_channels=in_filters[2], out_channels=out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=in_filters[1], out_channels=out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=in_filters[0], out_channels=out_filters[0])
        self.classifer = nn.Conv2d(out_filters[0], nclass, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Initialise weights of all submodules.

        Every Conv2d weight gets Kaiming-normal init, and BatchNorm/GroupNorm
        layers get weight 1 and bias 0. Conv biases keep PyTorch's default
        init.
        """
        # The message below means "decoder initialization executed".
        print("decoder从初始化执行")
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into class logits.

        NOTE(review): the logits are upsampled only 2x from x4's resolution.
        If x4 is at 1/4 of the input, the output is at 1/2 — confirm callers
        expect this.
        """
        x4, x8, x16, x32 = inputlist
        top = self.carb32(self.pam(x32))
        d16 = self.up_concat3(top, x16)
        d8 = self.up_concat2(d16, x8)
        d4 = self.up_concat1(d8, x4)
        logits = self.classifer(d4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
|
|||
|
|
|||
|
|
|||
|
class unetpamcamDecoder(nn.Module):
    """Decoder with PAM on x32 and ``CAM`` channel attention on every scale.

    Attention applied to each scale:

    * x32: ``PAM(2048)`` (to 512 channels), then ``CAM(512)``
    * x16: ``CAM(1024)``
    * x8: ``CAM(512)``
    * x4: ``CAM(256)``

    Three ``up_fusionblock`` steps then fuse coarsest-to-finest, followed by
    a 1x1 classifier and a 2x bilinear upsample. The defaults follow:

    * 1536 = 512 + 1024
    * 1024 = 512 + 512
    * 512 = 256 + 256

    Args:
        nclass: number of output classes.
        in_filters: input channels of the fusion steps, finest first.
        out_filters: output channels of the fusion steps, finest first.
    """

    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamcamDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not read anywhere in this class
        self.in_filters = in_filters
        self.out_filters = out_filters

        self.pam = PAM(in_channels=2048)
        self.cam32 = CAM(512)
        self.cam16 = CAM(in_channels=1024)
        self.cam8 = CAM(in_channels=512)
        self.cam4 = CAM(in_channels=256)

        self.up_concat3 = up_fusionblock(in_channels=in_filters[2], out_channels=out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=in_filters[1], out_channels=out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=in_filters[0], out_channels=out_filters[0])
        self.classifer = nn.Conv2d(out_filters[0], nclass, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Initialise weights of all submodules.

        Every Conv2d weight gets Kaiming-normal init, and BatchNorm/GroupNorm
        layers get weight 1 and bias 0. Conv biases keep PyTorch's default
        init.
        """
        # The message below means "decoder initialization executed".
        print("decoder从初始化执行")
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, inputlist):
        """Decode ``[x4, x8, x16, x32]`` into class logits.

        Original notes: with ResNet inputs the channels are 256/512/1024/2048
        and x32 is 16x16 for a 512x512 image.

        NOTE(review): the logits are upsampled only 2x from x4's resolution.
        If x4 is at 1/4 of the input, the output is at 1/2 — confirm callers
        expect this.
        """
        x4, x8, x16, x32 = inputlist
        a32 = self.cam32(self.pam(x32))
        a4 = self.cam4(x4)
        a8 = self.cam8(x8)
        a16 = self.cam16(x16)
        d16 = self.up_concat3(a32, a16)
        d8 = self.up_concat2(d16, a8)
        d4 = self.up_concat1(d8, a4)
        logits = self.classifer(d4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
|
|||
|
|
|||
|
|
|||
|
class CARB2(nn.Module):
    """Conv + channel-attention refinement using both average and max pooling.

    A 3x3 conv-BN-ReLU maps ``in_channels`` to ``out_channels``. The gate is
    built from the sum of global average and global max pooling: 1x1 conv
    (out -> out // ratio), ReLU, 1x1 conv back to ``out``, sigmoid. The gate
    then rescales the conv output channel-wise.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        ratio: bottleneck reduction factor of the gate.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB2, self).__init__()
        hidden = out_channels // ratio
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
        self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        features = self.cbr(x)
        pooled = self.global_avgpooling(features) + self.global_maxpooling(features)
        gate = self.sigmod(self.conv2(self.relu(self.conv1(pooled))))
        return gate * features
|
|||
|
|
|||
|
|
|||
|
|
|||
|
def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
    """Build a ``Conv2d -> norm_layer -> ReLU(inplace)`` stack.

    The conv has no bias and uses ``padding = kernel_size // 2``. For odd
    kernels at stride 1, this preserves the spatial size.

    Args:
        in_channels: input channel count.
        out_channels: output channel count (also passed to ``norm_layer``).
        kernel_size: conv kernel size.
        stride: conv stride.
        norm_layer: callable taking a channel count and returning a norm module.

    Returns:
        An ``nn.Sequential`` of the three layers.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        bias=False,
    )
    layers = [conv, norm_layer(out_channels), nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
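'''Another fusion idea (original note: 另一种融合):
concatenate the features, feed them to an SE block so the different features
are adaptively weighted and selected, then reduce the channel dimension.
'''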
|
|||
|
|