# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch
from torch.nn import Softmax
import torch.nn.functional as F


# Basic depthwise-separable convolution block: a 3x3 depthwise conv followed by a
# 1x1 pointwise conv, each with BatchNorm and ReLU.
class conv_dw(nn.Module):
    def __init__(self, inp, oup, stride=1):
        super(conv_dw, self).__init__()
        self.basedw = nn.Sequential(
            nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),

            nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.basedw(x)


class DepwithDoubleConv(nn.Module):
    """Two stacked depthwise-separable conv blocks."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv_dw = nn.Sequential(
            conv_dw(in_channels, mid_channels),
            conv_dw(mid_channels, out_channels)
        )

    def forward(self, x):
        return self.double_conv_dw(x)


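# A minimal sanity check (not part of the model): a depthwise-separable 3x3 block
# replaces one dense 3x3 conv with a 3x3 depthwise + 1x1 pointwise pair, cutting
# parameters roughly 8x at these (illustrative) widths.
if __name__ == "__main__":
    _dw = conv_dw(64, 128)
    _dense = nn.Conv2d(64, 128, 3, padding=1, bias=False)
    print("conv_dw params:", sum(p.numel() for p in _dw.parameters()))       # ~9.2k incl. BN
    print("dense 3x3 params:", sum(p.numel() for p in _dense.parameters()))  # 73728

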
class up_fusionblock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(up_fusionblock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.cbr = nn.Sequential(  # original version
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # Alternative fusion heads that were tried:
        # self.cbr = nn.Sequential(
        #     nn.Conv2d(in_channels, out_channels, 1),
        #     nn.BatchNorm2d(out_channels),
        #     nn.ReLU(inplace=True)
        # )
        # self.cbr = DepwithDoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.cbr(x)


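# Shape sketch for up_fusionblock (illustrative sizes): the deeper map x1 is
# upsampled 2x, concatenated with the skip map x2 along channels, then fused by
# the 3x3 CBR, so in_channels must equal C(x1) + C(x2).
if __name__ == "__main__":
    _fuse = up_fusionblock(in_channels=512 + 320, out_channels=320)
    _x1 = torch.randn(1, 512, 16, 16)   # deep feature
    _x2 = torch.randn(1, 320, 32, 32)   # skip feature
    print(_fuse(_x1, _x2).shape)        # torch.Size([1, 320, 32, 32])

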
# Original CAB (channel attention block): fuses a high-level and a low-level feature map.
class CABy(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(CABy, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x1, x2 = x  # high-level, low-level
        x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x1, x2], dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2
        res = x2 + x1
        return res


class CAB(nn.Module):

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CAB, self).__init__()

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The attention branch runs after self.cbr, so it must operate on
        # out_channels (the original in_channels sizing crashed whenever
        # in_channels != out_channels).
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, high, low):
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        x0 = torch.cat([high, low], dim=1)
        x0 = self.cbr(x0)
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x


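# Usage sketch for CAB (illustrative sizes): the high-level map is resized to the
# low-level map, the pair is concatenated and fused by a 3x3 CBR, and an SE-style
# gate reweights the fused channels.
if __name__ == "__main__":
    _cab = CAB(in_channels=512 + 320, out_channels=320)
    _high = torch.randn(1, 512, 16, 16)
    _low = torch.randn(1, 320, 32, 32)
    print(_cab(_high, _low).shape)      # torch.Size([1, 320, 32, 32])

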
class CAM(nn.Module):
    """Channel attention module (SE-style squeeze-and-excitation)."""

    def __init__(self, in_channels, ratio=8):
        super(CAM, self).__init__()

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x0):
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x


class PAM_Module(nn.Module):
    """ Position attention module"""
    # Ref from SAGAN
    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps (B x C x H x W)
            returns :
                out : attention value + input feature
                attention : B x (HxW) x (HxW)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma*out + x
        return out


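# Note on cost (not part of the model): PAM_Module materializes a full
# (HxW) x (HxW) attention map, so memory grows with the fourth power of the
# spatial side. The sketch below (illustrative sizes) shows why it is only
# applied to the deepest, smallest feature map.
if __name__ == "__main__":
    _pam = PAM_Module(in_dim=512)
    _x = torch.randn(1, 512, 16, 16)
    print(_pam(_x).shape)  # torch.Size([1, 512, 16, 16]); the attention map was 256x256

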
class PAM(nn.Module):
    """ Position attention module"""
    # Ref from SAGAN
    def __init__(self, in_channels):
        super(PAM, self).__init__()
        inter_channels = in_channels // 4
        self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(inter_channels),
                                    nn.ReLU())

        # Defined but not used in forward; kept so existing checkpoints still load.
        self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(inter_channels),
                                    nn.ReLU())

        self.sa = PAM_Module(inter_channels)

    def forward(self, x):
        feat1 = self.conv5a(x)    # reduce channels 4x (e.g. 2048 -> 512)
        sa_feat = self.sa(feat1)  # query/key channels drop a further 8x; value keeps full width (e.g. 512x64x64)
        return sa_feat


def INF(B, H, W, device=None):
    # (B*W, H, H) tensor with -inf on the diagonal, used to mask each pixel's
    # attention to itself along the H axis. The original hard-coded .cuda();
    # taking a device argument keeps it working on CPU as well.
    return -torch.diag(torch.full((H,), float("inf"), device=device), 0).unsqueeze(0).repeat(B*W, 1, 1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        '''Self-attention split along the H and W axes: QKV attention is computed
        separately for each axis and the two value aggregations are summed.'''
        m_batchsize, _, height, width = x.size()  # e.g. 64, 32, 32
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width, x.device)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize*width, height, height)
        att_W = concate[:, :, :, height:height+width].contiguous().view(m_batchsize*height, width, width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        return self.gamma*(out_H + out_W) + x


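# Usage sketch for CrissCrossAttention (illustrative sizes): each position attends
# only to its own row and column, so the attention map per pixel is (H+W-1) wide
# rather than HxW, which is the CCNet trick for cheap long-range context.
if __name__ == "__main__":
    _cca = CrissCrossAttention(in_dim=64)
    _x = torch.randn(1, 64, 32, 32)
    print(_cca(_x).shape)  # torch.Size([1, 64, 32, 32])

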
class CARB(nn.Module):
    """3x3 CBR followed by SE-style channel attention."""
    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB, self).__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x0 = self.cbr(x)
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x


class unetCAB0Decoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAB0Decoder, self).__init__()

        '''Intermediate outputs of the MiT-b1 backbone (channels):
           64 128 320 512
        '''

        self.nclas = nclass
        self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters

        self.up_concat3 = CAB(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = CAB(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = CAB(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. 512x512 input
        '''x32: the 32x-downsampled feature map, e.g. 16x16x512.
           Each stage: CBR (channel reduction) + upsample + concat.
        '''
        x4, x8, x16, x32 = inputlist
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetCAM0Decoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAM0Decoder, self).__init__()

        '''Intermediate outputs of the MiT-b1 backbone:
           channels: 64  128  320  512
           strides:  1/4 1/8  1/16 1/32
        '''

        self.nclas = nclass
        self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.cam32 = CAM(512)
        self.cam16 = CAM(320)
        self.cam8 = CAM(128)
        self.cam4 = CAM(64)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. 512x512 input
        '''x32: the 32x-downsampled feature map, e.g. 16x16x512.
           Each stage: CBR (channel reduction) + upsample + concat.
        '''
        x4, x8, x16, x32 = inputlist
        # Each scale goes through the channel-attention module that matches its
        # width (the original reused self.cam4 for every scale, which crashes on
        # the 128/320/512-channel maps).
        x4 = self.cam4(x4)
        x8 = self.cam8(x8)
        x16 = self.cam16(x16)
        x32 = self.cam32(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetDecoder, self).__init__()
        # '''
        # Intermediate outputs of ResNet-50:
        # channels: 256          512  1024  2048
        # strides:  1/4 (stage1) 1/8  1/16  1/32
        #
        # With a plain decoder, add a 1x1 CBR on the last stage to map its channels to 512:
        # 256          512  1024  512
        # 1/4 (stage1) 1/8  1/16  1/32
        #
        # in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]
        # '''

        self.nclas = nclass
        self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters
        # self.trans = nn.Sequential(
        #     nn.Conv2d(2048, 512, 3, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )
        # self.trans = nn.Sequential(
        #     nn.Conv2d(2048, 512, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. 512x512 input
        '''
        MiT backbone (H x W x C):
            x4:  128 128 64
            x8:  64  64  128
            x16: 32  32  320
            x32: 16  16  512
            in_filters=[192, 448, 832], out_filters=[64, 128, 320]

        ResNet backbone (C, then H x W):
            x4:  256  128 128
            x8:  512  64  64
            x16: 1024 32  32
            x32: 2048 16  16
        '''
        x4, x8, x16, x32 = inputlist
        # x32 = self.trans(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


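# End-to-end sketch for unetDecoder with MiT-b1-shaped dummy features (the default
# in_filters [192, 448, 832] correspond to the 64/128/320/512 pyramid documented
# above); sizes are illustrative.
if __name__ == "__main__":
    _dec = unetDecoder(nclass=2)
    _feats = [torch.randn(1, 64, 128, 128),   # x4
              torch.randn(1, 128, 64, 64),    # x8
              torch.randn(1, 320, 32, 32),    # x16
              torch.randn(1, 512, 16, 16)]    # x32
    print(_dec(_feats).shape)  # torch.Size([1, 2, 256, 256])

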
class unetpamDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamDecoder, self).__init__()

        '''ResNet channels after PAM: 256 512 1024 512'''
        self.nclas = nclass
        self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters

        # self.trans = nn.Sequential(
        #     nn.Conv2d(2048, 512, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )

        self.pam = PAM(in_channels=2048)  # outputs 1/4 of the input channels
        # self.pam = PAM(in_channels=512)
        # self.pam = CrissCrossAttention(in_dim=512)

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        '''ResNet features (C, then H x W):
            x4: 256 128 128   x8: 512 64 64   x16: 1024 32 32   x32: 2048 16 16
        '''
        x4, x8, x16, x32 = inputlist
        # x32 = self.pam(self.trans(x32))
        x32 = self.pam(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetpamCARBDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[192, 384, 768], out_filters=[64, 128, 256]):
        super(unetpamCARBDecoder, self).__init__()

        self.nclas = nclass
        self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)

        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)

        '''
        ResNet channels: 256 512 1024 2048
        after PAM:       256 512 1024 512
        CARB option 1 -> 128 256 512  512
        CARB option 2 -> 64  128 256  512

        1. in_filters=[256, 512, 1024], out_filters=[64, 128, 256]
        2. in_filters=[192, 384, 768],  out_filters=[64, 128, 256]
        '''

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        '''ResNet features (C, then H x W):
            x4: 256 128 128   x8: 512 64 64   x16: 1024 32 32   x32: 2048 16 16
        '''
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x4 = self.carb4(x4)
        x8 = self.carb8(x8)
        x16 = self.carb16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetpamDecoderzuhe(nn.Module):
    # "zuhe" = combination: PAM followed by a CARB block on the deepest features.
    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamDecoderzuhe, self).__init__()

        # in_filters=[512, 1024, 3072], out_filters=[128, 256, 512]
        '''ResNet channels after PAM: 256 512 1024 512'''
        self.nclas = nclass
        self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB(512, out_channels=512)

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        '''ResNet features (C, then H x W):
            x4: 256 128 128   x8: 512 64 64   x16: 1024 32 32   x32: 2048 16 16
        '''
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class unetpamcamDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamcamDecoder, self).__init__()

        self.nclas = nclass
        self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)
        self.cam32 = CAM(512)
        self.cam16 = CAM(in_channels=1024)
        self.cam8 = CAM(in_channels=512)
        self.cam4 = CAM(in_channels=256)

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        '''ResNet features (C, then H x W):
            x4: 256 128 128   x8: 512 64 64   x16: 1024 32 32   x32: 2048 16 16
        '''
        x4, x8, x16, x32 = inputlist
        x32 = self.cam32(self.pam(x32))
        x4 = self.cam4(x4)
        x8 = self.cam8(x8)
        x16 = self.cam16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


class CARB2(nn.Module):
    """Like CARB, but the channel descriptor sums global average and max pooling (CBAM-style)."""
    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB2, self).__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
        self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x0 = self.cbr(x)
        xavg = self.global_avgpooling(x0)
        xmax = self.global_maxpooling(x0)
        x = xavg + xmax
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x


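# Usage sketch for CARB2 (illustrative sizes): compared with CARB, the channel
# descriptor is the sum of globally average- and max-pooled features before the
# bottleneck MLP, as in CBAM's channel attention.
if __name__ == "__main__":
    _carb2 = CARB2(in_channels=256, out_channels=64)
    print(_carb2(torch.randn(1, 256, 64, 64)).shape)  # torch.Size([1, 64, 64, 64])

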
def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, bias=False),
        norm_layer(out_channels),
        nn.ReLU(inplace=True)
    )


############________________________________________________

class AtrousSeparableConvolution(nn.Module):
    """ Atrous Separable Convolution
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super(AtrousSeparableConvolution, self).__init__()
        self.body = nn.Sequential(
            # Separable Conv
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=in_channels),
            # PointWise Conv
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
        )

        self._init_weight()

    def forward(self, x):
        return self.body(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1), )

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)


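# Usage sketch for ASPP (illustrative rates and sizes): five parallel branches
# (1x1 conv, three dilated 3x3 convs, global pooling) are concatenated and
# projected back to 256 channels, so the output width is always 256. eval() is
# needed here because the pooling branch's BatchNorm sees a 1x1 map.
if __name__ == "__main__":
    _aspp = ASPP(in_channels=2048, atrous_rates=[5, 7, 11]).eval()
    print(_aspp(torch.randn(1, 2048, 16, 16)).shape)  # torch.Size([1, 256, 16, 16])

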
class unetasppDecoder(nn.Module):
    # Note: the ASPP head outputs 256 channels and the 1x1 projections below halve
    # each skip connection, so the fusion inputs are fixed at [384, 768, 768]
    # regardless of the backbone widths (the old default [512, 1024, 1536] could
    # never match).
    def __init__(self, nclass=2, in_filters=[384, 768, 768], out_filters=[128, 256, 512]):
        super(unetasppDecoder, self).__init__()

        '''
        ASPP outputs 256 channels.
        ResNet channels: 256 512 1024 2048
        '''
        self.nclas = nclass
        # self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters

        # self.trans = nn.Sequential(
        #     nn.Conv2d(2048, 512, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )
        '''x4-x16 projections: 256 512 1024 -> 128 256 512'''
        self.projectx4 = nn.Sequential(
            nn.Conv2d(256, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.projectx8 = nn.Sequential(
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.projectx16 = nn.Sequential(
            nn.Conv2d(1024, 512, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        self.pam = ASPP(in_channels=2048, atrous_rates=[5, 7, 11])  # an ASPP head, named pam to match the other decoders; outputs 256 channels

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        '''ResNet features (C, then H x W):
            x4: 256 128 128   x8: 512 64 64   x16: 1024 32 32   x32: 2048 16 16
        '''
        x4, x8, x16, x32 = inputlist
        # x32 = self.pam(self.trans(x32))
        x32 = self.pam(x32)
        x16 = self.projectx16(x16)
        x8 = self.projectx8(x8)
        x4 = self.projectx4(x4)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out


# Alternative fusion idea: concatenate, feed the result to an SE block so the
# different features are adaptively weighted/selected, then reduce channels.


#########################################################################
# MAnet-style decoder: one last attempt at this design.

class unetDecoder2(nn.Module):
    def __init__(self, nclass=2, in_filters=[384, 768, 1536], out_filters=[128, 256, 512]):
        super(unetDecoder2, self).__init__()
        # '''
        # Intermediate outputs of ResNet-50:
        # channels: 256          512  1024  2048
        # strides:  1/4 (stage1) 1/8  1/16  1/32
        # A 1x1 conv on each of the four feature maps halves its channels:
        # 128 256 512 1024
        # in_filters=[384, 768, 1536], out_filters=[128, 256, 512]
        # '''

        self.nclas = nclass
        self.finnal_channel = 512

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.transx32 = nn.Sequential(
            nn.Conv2d(2048, 1024, 1, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )

        self.transx16 = nn.Sequential(
            nn.Conv2d(1024, 512, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

        self.transx8 = nn.Sequential(
            nn.Conv2d(512, 256, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        self.transx4 = nn.Sequential(
            nn.Conv2d(256, 128, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        print("decoder: initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        '''ResNet features (C, then H x W):
            x4: 256 128 128   x8: 512 64 64   x16: 1024 32 32   x32: 2048 16 16
        '''
        x4, x8, x16, x32 = inputlist
        x32 = self.transx32(x32)
        x16 = self.transx16(x16)
        x8 = self.transx8(x8)
        x4 = self.transx4(x4)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out