ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/decoder.py

# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch
from torch.nn import Softmax
import torch.nn.functional as F
# Depthwise-separable convolution: basic building block
class conv_dw(nn.Module):
    def __init__(self, inp, oup, stride=1):
        super(conv_dw, self).__init__()
        self.basedw = nn.Sequential(
            # depthwise 3x3 convolution
            nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
            # pointwise 1x1 convolution
            nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.basedw(x)

class DepwithDoubleConv(nn.Module):
    """Two stacked depthwise-separable convolutions."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv_dw = nn.Sequential(
            conv_dw(in_channels, mid_channels),
            conv_dw(mid_channels, out_channels)
        )

    def forward(self, x):
        return self.double_conv_dw(x)

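# Usage sketch (illustrative; tensor sizes are assumptions, not from the original
# training code):
#   dw = conv_dw(64, 128)                       # depthwise 3x3 + pointwise 1x1
#   y = dw(torch.randn(1, 64, 32, 32))          # -> (1, 128, 32, 32), spatial size kept
#   dc = DepwithDoubleConv(64, 128)             # two stacked conv_dw blocks
#   y = dc(torch.randn(1, 64, 32, 32))          # -> (1, 128, 32, 32)
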
class up_fusionblock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(up_fusionblock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # original version: 3x3 conv-BN-ReLU after concatenation
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # alternative fusion heads that were tried:
        # self.cbr = nn.Sequential(
        #     nn.Conv2d(in_channels, out_channels, 1),
        #     nn.BatchNorm2d(out_channels),
        #     nn.ReLU(inplace=True)
        # )
        # self.cbr = DepwithDoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.cbr(x)

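# Usage sketch (illustrative): in_channels must equal C(x1) + C(x2), since x1 is
# upsampled 2x and concatenated with the skip feature x2 before the 3x3 CBR:
#   fuse = up_fusionblock(in_channels=512 + 320, out_channels=320)
#   y = fuse(torch.randn(1, 512, 16, 16),        # deep feature x1
#            torch.randn(1, 320, 32, 32))        # skip feature x2
#   # y: (1, 320, 32, 32)
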
'''Original CAB: gate the low-level feature, then add the high-level feature.'''
class CABy(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CABy, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x1, x2 = x  # high-level, low-level
        x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x1, x2], dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2
        res = x2 + x1
        return res

class CAB(nn.Module):
    def __init__(self, in_channels, out_channels, ratio=8):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # NOTE: the squeeze-excitation gate operates on the fused feature, which
        # has out_channels channels after self.cbr. The original used in_channels
        # here, which fails whenever in_channels != out_channels.
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, high, low):
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        x0 = torch.cat([high, low], dim=1)
        x0 = self.cbr(x0)
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x

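# Usage sketch (illustrative): CAB fuses a high-level and a low-level feature and
# re-weights the fused channels with a squeeze-and-excitation style gate:
#   cab = CAB(in_channels=512 + 320, out_channels=320)
#   y = cab(torch.randn(1, 512, 16, 16),         # high (deep) feature
#           torch.randn(1, 320, 32, 32))         # low (shallow) feature
#   # y: (1, 320, 32, 32)
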
class CAM(nn.Module):
    """Channel attention (squeeze-and-excitation) on a single feature map."""

    def __init__(self, in_channels, ratio=8):
        super(CAM, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x0):
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x

class PAM_Module(nn.Module):
    """ Position attention module"""
    # Ref from SAGAN

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x H x W)
        returns :
            out : attention value + input feature
            attention : B x (HxW) x (HxW)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma*out + x
        return out

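# Usage sketch (illustrative): full self-attention over all H*W positions, so the
# attention map is (B, H*W, H*W) and memory grows quadratically with resolution:
#   pam = PAM_Module(in_dim=128)
#   y = pam(torch.randn(1, 128, 32, 32))         # -> (1, 128, 32, 32)
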
class PAM(nn.Module):
    """ Position attention module"""
    # Ref from SAGAN

    def __init__(self, in_channels):
        super(PAM, self).__init__()
        inter_channels = in_channels // 4
        self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(inter_channels),
                                    nn.ReLU())
        # conv5c is defined but never used in forward()
        self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(inter_channels),
                                    nn.ReLU())
        self.sa = PAM_Module(inter_channels)

    def forward(self, x):
        feat1 = self.conv5a(x)    # reduce channels 4x, e.g. 2048 -> 512
        sa_feat = self.sa(feat1)  # q/k reduced a further 8x inside; value unchanged, e.g. 512x64x64
        return sa_feat


def INF(B, H, W):
    # -inf on the diagonal so a position does not attend to itself twice in
    # CrissCrossAttention; note the hard-coded .cuda() (GPU-only as written)
    return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H), 0).unsqueeze(0).repeat(B*W, 1, 1)

class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""

    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Self-attention split along the H and W axes: q/k/v attention is computed
        # separately for each column and each row, applied to v separately, and the
        # two results are summed.
        m_batchsize, _, height, width = x.size()  # e.g. 64 x 32 x 32
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize*width, height, height)
        att_W = concate[:, :, :, height:height+width].contiguous().view(m_batchsize*height, width, width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        return self.gamma*(out_H + out_W) + x

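# Usage sketch (illustrative): attention is restricted to each position's row and
# column, costing O(H+W) per position instead of O(H*W). Because INF() calls
# .cuda(), this module as written only runs on a GPU:
#   cca = CrissCrossAttention(in_dim=64).cuda()
#   y = cca(torch.randn(1, 64, 32, 32).cuda())   # -> (1, 64, 32, 32)
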
class CARB(nn.Module):
    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB, self).__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x0 = self.cbr(x)
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x

class unetCAB0Decoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAB0Decoder, self).__init__()
        '''Intermediate feature channels of MiT-b1: 64 128 320 512'''
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.up_concat3 = CAB(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = CAB(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = CAB(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''x32 is the most downsampled feature, 16x16x512.
        Each stage: CBR (conv for channel reduction) + upsample + concat.
        '''
        x4, x8, x16, x32 = inputlist
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out

class unetCAM0Decoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAM0Decoder, self).__init__()
        '''Intermediate feature channels of MiT-b1: 64 128 320 512
        at strides 1/4 1/8 1/16 1/32
        '''
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.cam32 = CAM(512)
        self.cam16 = CAM(320)
        self.cam8 = CAM(128)
        self.cam4 = CAM(64)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''x32 is the most downsampled feature, 16x16x512.
        Each stage: CBR (conv for channel reduction) + upsample + concat.
        '''
        x4, x8, x16, x32 = inputlist
        # each scale gets its own channel-attention module (the original applied
        # self.cam4 to all four scales, which fails for any feature whose channel
        # count differs from 64)
        x4 = self.cam4(x4)
        x8 = self.cam8(x8)
        x16 = self.cam16(x16)
        x32 = self.cam32(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out

class unetDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetDecoder, self).__init__()
        # ResNet-50 intermediate outputs:
        #   256 512 1024 2048 at strides 1/4 (stage1), 1/8, 1/16, 1/32.
        # If the decoder adds nothing else, a 1x1 CBR can map the last output to 512:
        #   256 512 1024 512 at 1/4 (stage1), 1/8, 1/16, 1/32
        #   -> in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        # optional 2048 -> 512 transitions that were tried for ResNet backbones:
        # self.trans = nn.Sequential(
        #     nn.Conv2d(2048, 512, 3, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )
        # self.trans = nn.Sequential(
        #     nn.Conv2d(2048, 512, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''
        mit-b1 (H x W, channels):
            x4 : 128x128, 64
            x8 : 64x64, 128
            x16: 32x32, 320
            x32: 16x16, 512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet (channels first):
            x4 : 256 128 128
            x8 : 512 64 128
            x16: 1024 32 320
            x32: 2048 16 512
        '''
        x4, x8, x16, x32 = inputlist
        # x32 = self.trans(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out

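# Usage sketch (illustrative), matching the MiT-b1 shapes in the docstring above:
#   dec = unetDecoder(nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320])
#   feats = [torch.randn(1, 64, 128, 128),    # x4
#            torch.randn(1, 128, 64, 64),     # x8
#            torch.randn(1, 320, 32, 32),     # x16
#            torch.randn(1, 512, 16, 16)]     # x32
#   out = dec(feats)                           # -> (1, 2, 256, 256)
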
class unetpamDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamDecoder, self).__init__()
        '''ResNet feature channels after PAM: 256 512 1024 512'''
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        # optional 2048 -> 512 transition that was tried:
        # self.trans = nn.Sequential(
        #     nn.Conv2d(2048, 512, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )
        self.pam = PAM(in_channels=2048)  # its output has 1/4 of the input channels
        # self.pam = PAM(in_channels=512)
        # self.pam = CrissCrossAttention(in_dim=512)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''
        mit-b1 (H x W, channels):
            x4 : 128x128, 64
            x8 : 64x64, 128
            x16: 32x32, 320
            x32: 16x16, 512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet (channels first):
            x4 : 256 128 128
            x8 : 512 64 128
            x16: 1024 32 320
            x32: 2048 16 512
        '''
        x4, x8, x16, x32 = inputlist
        # x32 = self.pam(self.trans(x32))
        x32 = self.pam(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out

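# Usage sketch (illustrative), matching the ResNet-50 shapes above; PAM maps the
# 2048-channel x32 to 2048 // 4 = 512 channels before fusion:
#   dec = unetpamDecoder(nclass=2)
#   feats = [torch.randn(1, 256, 128, 128),   # x4
#            torch.randn(1, 512, 64, 64),     # x8
#            torch.randn(1, 1024, 32, 32),    # x16
#            torch.randn(1, 2048, 16, 16)]    # x32
#   out = dec(feats)                           # -> (1, 2, 256, 256)
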
class unetpamCARBDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[192, 384, 768], out_filters=[64, 128, 256]):
        super(unetpamCARBDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)
        '''
        resnet:          256 512 1024 2048
        after PAM:       256 512 1024 512
        carb option 1 -> 128 256 512  512
        carb option 2 -> 64  128 256  512
        1. in_filters=[256, 512, 1024], out_filters=[64, 128, 256]
        2. in_filters=[192, 384, 768],  out_filters=[64, 128, 256]
        '''
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''
        mit-b1 (H x W, channels):
            x4 : 128x128, 64
            x8 : 64x64, 128
            x16: 32x32, 320
            x32: 16x16, 512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet (channels first):
            x4 : 256 128 128
            x8 : 512 64 128
            x16: 1024 32 320
            x32: 2048 16 512
        '''
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x4 = self.carb4(x4)
        x8 = self.carb8(x8)
        x16 = self.carb16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out

class unetpamDecoderzuhe(nn.Module):
    # "zuhe" = combined variant: PAM followed by CARB on the deepest feature
    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamDecoderzuhe, self).__init__()
        # in_filters=[512, 1024, 3072], out_filters=[128, 256, 512]
        '''ResNet feature channels after PAM: 256 512 1024 512'''
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB(512, out_channels=512)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''
        mit-b1 (H x W, channels):
            x4 : 128x128, 64
            x8 : 64x64, 128
            x16: 32x32, 320
            x32: 16x16, 512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet (channels first):
            x4 : 256 128 128
            x8 : 512 64 128
            x16: 1024 32 320
            x32: 2048 16 512
        '''
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out

class unetpamcamDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]):
        super(unetpamcamDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)
        self.cam32 = CAM(512)
        self.cam16 = CAM(in_channels=1024)
        self.cam8 = CAM(in_channels=512)
        self.cam4 = CAM(in_channels=256)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''
        mit-b1 (H x W, channels):
            x4 : 128x128, 64
            x8 : 64x64, 128
            x16: 32x32, 320
            x32: 16x16, 512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet (channels first):
            x4 : 256 128 128
            x8 : 512 64 128
            x16: 1024 32 320
            x32: 2048 16 512
        '''
        x4, x8, x16, x32 = inputlist
        x32 = self.cam32(self.pam(x32))
        x4 = self.cam4(x4)
        x8 = self.cam8(x8)
        x16 = self.cam16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out

class CARB2(nn.Module):
    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB2, self).__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
        self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x0 = self.cbr(x)
        xavg = self.global_avgpooling(x0)
        xmax = self.global_maxpooling(x0)
        x = xavg + xmax
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x

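# Usage sketch (illustrative): CARB2 differs from CARB by summing global average
# and global max pooling (CBAM-style) before the channel gate:
#   carb = CARB2(in_channels=1024, out_channels=256)
#   y = carb(torch.randn(1, 1024, 32, 32))    # -> (1, 256, 32, 32)
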
def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, bias=False),
        norm_layer(out_channels),
        nn.ReLU(inplace=True)
    )

# ---------------- DeepLab-style components (ASPP and variants) ----------------
class AtrousSeparableConvolution(nn.Module):
    """ Atrous Separable Convolution
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super(AtrousSeparableConvolution, self).__init__()
        self.body = nn.Sequential(
            # Separable Conv
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=in_channels),
            # PointWise Conv
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
        )
        self._init_weight()

    def forward(self, x):
        return self.body(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))
        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(modules)
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)

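# Usage sketch (illustrative): five parallel branches (a 1x1 conv, three dilated
# 3x3 convs, and global pooling), concatenated and projected to 256 channels:
#   aspp = ASPP(in_channels=2048, atrous_rates=[5, 7, 11])
#   y = aspp(torch.randn(1, 2048, 16, 16))    # -> (1, 256, 16, 16)
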
class unetasppDecoder(nn.Module):
    def __init__(self, nclass=2, in_filters=[384, 768, 768], out_filters=[128, 256, 512]):
        super(unetasppDecoder, self).__init__()
        '''
        ASPP outputs 256 channels; the 1x1 projections map x4/x8/x16 from
        256/512/1024 down to 128/256/512 channels.
        NOTE: the original defaults were in_filters=[512, 1024, 1536], which do not
        match the concatenated channel counts (256+512, 512+256, 256+128) and would
        fail at runtime; they are corrected here.
        '''
        self.nclas = nclass
        # self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        # optional 2048 -> 512 transition that was tried:
        # self.trans = nn.Sequential(
        #     nn.Conv2d(2048, 512, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )
        '''x4-x16 channels: 256 512 1024'''
        self.projectx4 = nn.Sequential(
            nn.Conv2d(256, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.projectx8 = nn.Sequential(
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.projectx16 = nn.Sequential(
            nn.Conv2d(1024, 512, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        self.pam = ASPP(in_channels=2048, atrous_rates=[5, 7, 11])  # output is a fixed 256 channels
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''
        resnet (channels first):
            x4 : 256 128 128
            x8 : 512 64 128
            x16: 1024 32 320
            x32: 2048 16 512
        '''
        x4, x8, x16, x32 = inputlist
        # x32 = self.pam(self.trans(x32))
        x32 = self.pam(x32)
        x16 = self.projectx16(x16)
        x8 = self.projectx8(x8)
        x4 = self.projectx4(x4)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out

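# Usage sketch (illustrative), with the corrected default filters (ASPP output is
# 256 ch; the 1x1 projections map x4/x8/x16 to 128/256/512 ch):
#   dec = unetasppDecoder(nclass=2)
#   feats = [torch.randn(1, 256, 128, 128),   # x4
#            torch.randn(1, 512, 64, 64),     # x8
#            torch.randn(1, 1024, 32, 32),    # x16
#            torch.randn(1, 2048, 16, 16)]    # x32
#   out = dec(feats)                           # -> (1, 2, 256, 256)
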
'''Another fusion idea:
concatenate, then feed the result to an SE block so that the different features are
adaptively weighted and selected, then reduce the channel dimension.
'''
#########################################################################
'''One last chance for this MANet idea.'''
class unetDecoder2(nn.Module):
    def __init__(self, nclass=2, in_filters=[384, 768, 1536], out_filters=[128, 256, 512]):
        super(unetDecoder2, self).__init__()
        # ResNet-50 intermediate outputs:
        #   256 512 1024 2048 at strides 1/4 (stage1), 1/8, 1/16, 1/32.
        # Each of the four gets a 1x1 conv that halves its channels:
        #   128 256 512 1024
        #   -> in_filters=[384, 768, 1536], out_filters=[128, 256, 512]
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.transx32 = nn.Sequential(
            nn.Conv2d(2048, 1024, 1, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.transx16 = nn.Sequential(
            nn.Conv2d(1024, 512, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.transx8 = nn.Sequential(
            nn.Conv2d(512, 256, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.transx4 = nn.Sequential(
            nn.Conv2d(256, 128, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        print("decoder: running weight initialization")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):  # e.g. with a 512x512 input
        '''
        mit-b1 (H x W, channels):
            x4 : 128x128, 64
            x8 : 64x64, 128
            x16: 32x32, 320
            x32: 16x16, 512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet (channels first):
            x4 : 256 128 128
            x8 : 512 64 128
            x16: 1024 32 320
            x32: 2048 16 512
        '''
        x4, x8, x16, x32 = inputlist
        x32 = self.transx32(x32)
        x16 = self.transx16(x16)
        x8 = self.transx8(x8)
        x4 = self.transx4(x4)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out