# ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/decoder.py
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch
from torch.nn import Softmax
import torch.nn.functional as F
# Basic depthwise separable convolution block
class conv_dw(nn.Module):
def __init__(self,inp, oup, stride = 1):
super(conv_dw, self).__init__()
self.basedw=nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True),
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
)
def forward(self,x):
return self.basedw(x)
class DepwithDoubleConv(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv_dw = nn.Sequential(
conv_dw(in_channels,mid_channels),
conv_dw(mid_channels,out_channels)
)
def forward(self,x):
return self.double_conv_dw(x)
class up_fusionblock(nn.Module):
def __init__(self, in_channels, out_channels):
super(up_fusionblock, self).__init__()
self.in_channels= in_channels
self.out_channels=out_channels
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.cbr=nn.Sequential( # original version
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
# # # self.cbr = nn.Sequential(
# # # nn.Conv2d(in_channels, out_channels, 1),
# # # nn.BatchNorm2d(out_channels),
# # # nn.ReLU(inplace=True)
# # # )
# self.cbr=DepwithDoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.cbr(x)
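# Illustrative channel arithmetic for up_fusionblock (a sketch, not part of the original
# training code): with in_channels = 832 = 512 + 320 the block upsamples the 512-channel
# 1/32 feature and concatenates it with the 320-channel 1/16 skip before the 3x3 CBR.
#   fuse = up_fusionblock(in_channels=832, out_channels=320)
#   x1 = torch.randn(1, 512, 16, 16)   # deep feature
#   x2 = torch.randn(1, 320, 32, 32)   # skip feature
#   fuse(x1, x2).shape                 # -> torch.Size([1, 320, 32, 32])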
'''Original CAB'''
class CABy(nn.Module):
def __init__(self, in_channels, out_channels):
super(CABy, self).__init__()
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x):
x1, x2 = x # high, low
x1=F.interpolate(x1,size=x2.shape[2:],mode='bilinear', align_corners=False)
x = torch.cat([x1,x2],dim=1)
x = self.global_pooling(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x2 = x * x2
res = x2 + x1
return res
class CAB(nn.Module):
def __init__(self, in_channels,out_channels,ratio=8):
super(CAB, self).__init__()
self.global_pooling = nn.AdaptiveAvgPool2d(1)
# channel attention operates on the fused feature x0, which has out_channels channels
self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
self.cbr = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, high,low):
high=F.interpolate(high,size=low.shape[2:],mode='bilinear', align_corners=False)
x0 = torch.cat([high,low],dim=1)
x0=self.cbr(x0)
x = self.global_pooling(x0)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x = x * x0
return x
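# Shape sketch for CAB (illustrative only): fuse a high-level and a low-level feature,
# then re-weight the fused map with squeeze-and-excitation style channel attention.
#   cab = CAB(in_channels=832, out_channels=320)
#   high = torch.randn(1, 512, 16, 16)   # 1/32 feature
#   low  = torch.randn(1, 320, 32, 32)   # 1/16 feature
#   cab(high, low).shape                 # -> torch.Size([1, 320, 32, 32])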
class CAM(nn.Module):
def __init__(self, in_channels,ratio=8):
super(CAM, self).__init__()
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x0):
x = self.global_pooling(x0)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x = x * x0
return x
class PAM_Module(nn.Module):
""" Position attention module"""
#Ref from SAGAN
def __init__(self, in_dim):
super(PAM_Module, self).__init__()
self.chanel_in = in_dim
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X (HxW) X (HxW)
"""
m_batchsize, C, height, width = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
energy = torch.bmm(proj_query, proj_key)
attention = self.softmax(energy)
proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, height, width)
out = self.gamma*out + x
return out
class PAM(nn.Module):
""" Position attention module"""
#Ref from SAGAN
def __init__(self, in_channels):
super(PAM, self).__init__()
inter_channels = in_channels // 4
self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU())
self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU())
self.sa = PAM_Module(inter_channels)
def forward(self,x):
feat1 = self.conv5a(x) # reduce channels by 4x (e.g. 2048 -> 512)
sa_feat = self.sa(feat1) # inside PAM_Module, q/k are reduced by a further 8x while value keeps all channels # 512*64*64
return sa_feat
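# INF builds a (B*W, H, H) mask whose diagonal is -inf. CrissCrossAttention adds it to the
# H-direction energies so that, after the softmax, each pixel's self-position is counted
# only once (it already appears in the W-direction branch). Note the mask is created with
# .cuda(), so the module below assumes a CUDA device.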
def INF(B,H,W):
return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)
class CrissCrossAttention(nn.Module):
""" Criss-Cross Attention Module"""
def __init__(self, in_dim):
super(CrissCrossAttention,self).__init__()
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.softmax = Softmax(dim=3)
self.INF = INF
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
'''Criss-cross self-attention: q/k/v attention is computed separately along the H and W axes, the values are aggregated per axis, and the two results are summed.'''
m_batchsize, _, height, width = x.size() #64 32 32
proj_query = self.query_conv(x)
proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
proj_key = self.key_conv(x)
proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
proj_value = self.value_conv(x)
proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
concate = self.softmax(torch.cat([energy_H, energy_W], 3))
att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
#print(concate)
#print(att_H)
att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
#print(out_H.size(),out_W.size())
return self.gamma*(out_H + out_W) + x
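# Usage sketch (illustrative only; requires a CUDA device because INF() calls .cuda()):
#   cca = CrissCrossAttention(in_dim=64).cuda()
#   feat = torch.randn(2, 64, 32, 32, device='cuda')
#   cca(feat).shape   # -> torch.Size([2, 64, 32, 32]), same as the input
# In CCNet the module is typically applied twice (recurrent criss-cross) so every pixel
# can aggregate context from the full image.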
class CARB(nn.Module):
def __init__(self, in_channels,out_channels,ratio=8):
super(CARB, self).__init__()
self.cbr = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x):
x0=self.cbr(x)
x = self.global_pooling(x0)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x = x * x0
return x
class unetCAB0Decoder(nn.Module):
def __init__(self,nclass=2,in_filters = [192, 448, 832],out_filters = [64, 128, 320]):
super(unetCAB0Decoder, self).__init__()
'''Intermediate outputs of MiT-b1:
64 128 320 512 channels
'''
self.nclas=nclass
self.finnal_channel=512
self.in_filters = in_filters
self.out_filters = out_filters
self.up_concat3=CAB(in_channels=self.in_filters[2],out_channels= self.out_filters[2])
self.up_concat2=CAB(in_channels=self.in_filters[1],out_channels= self.out_filters[1])
self.up_concat1=CAB(in_channels=self.in_filters[0],out_channels= self.out_filters[0])
self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self,inputlist): # e.g. with a 512x512 input
'''x32 is the downsampled 1/32 feature, 16*16*512.
Each stage: 1x1 CBR for channel reduction + upsampling + concat.
'''
x4,x8,x16,x32=inputlist
x16= self.up_concat3(x32,x16)
x8= self.up_concat2(x16,x8)
x4= self.up_concat1(x8,x4)
x4=self.classifer(x4)
out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class unetCAM0Decoder(nn.Module):
def __init__(self,nclass=2,in_filters = [192, 448, 832],out_filters = [64, 128, 320]):
super(unetCAM0Decoder, self).__init__()
'''Intermediate outputs of MiT-b1:
64 128 320 512 channels
1/4 1/8 1/16 1/32 scales
'''
self.nclas=nclass
self.finnal_channel=512
self.in_filters = in_filters
self.out_filters = out_filters
self.cam32 = CAM(512)
self.cam16 = CAM(320)
self.cam8 = CAM(128)
self.cam4 = CAM(64)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self,inputlist): # e.g. with a 512x512 input
'''x32 is the downsampled 1/32 feature, 16*16*512.
Each stage: 1x1 CBR for channel reduction + upsampling + concat.
'''
x4,x8,x16,x32=inputlist
x4=self.cam4(x4)
x8=self.cam8(x8)
x16=self.cam16(x16)
x32=self.cam32(x32)
x16= self.up_concat3(x32,x16)
x8= self.up_concat2(x16,x8)
x4= self.up_concat1(x8,x4)
x4=self.classifer(x4)
out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class unetDecoder(nn.Module):
def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
super(unetDecoder, self).__init__()
# '''
# Intermediate outputs of ResNet-50:
# 256 512 1024 2048
# 1/4(stage1) 1/8 1/16 1/32
#
# If the decoder adds nothing else, put a 1x1 CBR on the last output to bring it to 512:
# 256 512 1024 512
# 1/4(stage1) 1/8 1/16 1/32
#
# in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]
# '''
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
# self.trans=nn.Sequential(
# nn.Conv2d(2048,512,3,1,1),
# nn.BatchNorm2d(512),
# nn.ReLU()
#
# )
# self.trans = nn.Sequential(
# nn.Conv2d(2048, 512, 1, 1),
# nn.BatchNorm2d(512),
# nn.ReLU()
#
# )
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, inputlist): # e.g. with a 512x512 input
'''
mit:
x4:  64 channels, 128x128
x8:  128 channels, 64x64
x16: 320 channels, 32x32
x32: 512 channels, 16x16
in_filters=[192, 448, 832], out_filters=[64, 128, 320]
resnet:
x4:  256 channels, 128x128
x8:  512 channels, 64x64
x16: 1024 channels, 32x32
x32: 2048 channels, 16x16
'''
x4, x8, x16, x32 = inputlist
# x32=self.trans(x32)
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class unetpamDecoder(nn.Module):
def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
super(unetpamDecoder, self).__init__()
'''
256 512 1024 512
'''
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
# self.trans = nn.Sequential(
# nn.Conv2d(2048, 512, 1, 1),
# nn.BatchNorm2d(512),
# nn.ReLU()
#
# )
self.pam = PAM(in_channels=2048) # its output has 1/4 of the input channels
# self.pam = PAM(in_channels=512) # its output has 1/4 of the input channels
# self.pam = CrissCrossAttention(in_dim=512)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, inputlist): # e.g. with a 512x512 input
'''
mit:
x4:  64 channels, 128x128
x8:  128 channels, 64x64
x16: 320 channels, 32x32
x32: 512 channels, 16x16
in_filters=[192, 448, 832], out_filters=[64, 128, 320]
resnet:
x4:  256 channels, 128x128
x8:  512 channels, 64x64
x16: 1024 channels, 32x32
x32: 2048 channels, 16x16
'''
x4, x8, x16, x32 = inputlist
# x32 = self.pam(self.trans(x32))
x32 = self.pam(x32)
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class unetpamCARBDecoder(nn.Module):
def __init__(self,nclass=2, in_filters=[ 192, 384, 768],out_filters=[64, 128, 256]):
super(unetpamCARBDecoder, self).__init__()
self.nclas=nclass
self.finnal_channel=512
self.in_filters = in_filters
self.out_filters = out_filters
self.pam=PAM(in_channels=2048)
self.carb32=CARB2(512,out_channels=512)
self.carb16 = CARB2(in_channels=1024,out_channels=256)
self.carb8 = CARB2(in_channels=512,out_channels=128)
self.carb4 = CARB2(in_channels=256,out_channels=64)
'''
resnet: 256 512 1024 2048
after PAM: 256 512 1024 512
CARB option 1 -> 128 256 512 512
CARB option 2 -> 64 128 256 512
1. in_filters=[ 256, 512, 1024], out_filters=[64, 128, 256]
2. in_filters=[ 192, 384, 768], out_filters=[64, 128, 256]
'''
self.up_concat3=up_fusionblock(in_channels=self.in_filters[2],out_channels= self.out_filters[2])
self.up_concat2=up_fusionblock(in_channels=self.in_filters[1],out_channels= self.out_filters[1])
self.up_concat1=up_fusionblock(in_channels=self.in_filters[0],out_channels= self.out_filters[0])
self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self,inputlist): # e.g. with a 512x512 input
'''
mit:
x4:  64 channels, 128x128
x8:  128 channels, 64x64
x16: 320 channels, 32x32
x32: 512 channels, 16x16
in_filters=[192, 448, 832], out_filters=[64, 128, 320]
resnet:
x4:  256 channels, 128x128
x8:  512 channels, 64x64
x16: 1024 channels, 32x32
x32: 2048 channels, 16x16
'''
x4,x8,x16,x32=inputlist
x32=self.carb32(self.pam(x32))
x4=self.carb4(x4)
x8=self.carb8(x8)
x16=self.carb16(x16)
x16= self.up_concat3(x32,x16)
x8= self.up_concat2(x16,x8)
x4= self.up_concat1(x8,x4)
x4=self.classifer(x4)
out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class unetpamDecoderzuhe(nn.Module):
def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
super(unetpamDecoderzuhe, self).__init__()
#in_filters=[ 512, 1024, 3072],out_filters=[128, 256, 512]
'''
256 512 1024 512
'''
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
self.pam = PAM(in_channels=2048)
self.carb32 = CARB(512, out_channels=512)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, inputlist): # e.g. with a 512x512 input
'''
mit:
x4:  64 channels, 128x128
x8:  128 channels, 64x64
x16: 320 channels, 32x32
x32: 512 channels, 16x16
in_filters=[192, 448, 832], out_filters=[64, 128, 320]
resnet:
x4:  256 channels, 128x128
x8:  512 channels, 64x64
x16: 1024 channels, 32x32
x32: 2048 channels, 16x16
'''
x4, x8, x16, x32 = inputlist
x32 = self.carb32(self.pam(x32))
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class unetpamcamDecoder(nn.Module):
def __init__(self, nclass=2,in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
super(unetpamcamDecoder, self).__init__()
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
self.pam = PAM(in_channels=2048)
self.cam32 = CAM(512)
self.cam16 = CAM(in_channels=1024,)
self.cam8 = CAM(in_channels=512,)
self.cam4 = CAM(in_channels=256,)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, inputlist): # e.g. with a 512x512 input
'''
mit:
x4:  64 channels, 128x128
x8:  128 channels, 64x64
x16: 320 channels, 32x32
x32: 512 channels, 16x16
in_filters=[192, 448, 832], out_filters=[64, 128, 320]
resnet:
x4:  256 channels, 128x128
x8:  512 channels, 64x64
x16: 1024 channels, 32x32
x32: 2048 channels, 16x16
'''
x4, x8, x16, x32 = inputlist
x32 = self.cam32(self.pam(x32))
x4 = self.cam4(x4)
x8 = self.cam8(x8)
x16 = self.cam16(x16)
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class CARB2(nn.Module):
def __init__(self, in_channels,out_channels,ratio=8):
super(CARB2, self).__init__()
self.cbr = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x):
x0=self.cbr(x)
xavg = self.global_avgpooling(x0)
xmax = self.global_maxpooling(x0)
x=xavg+xmax
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x = x * x0
return x
def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, bias=False),
norm_layer(out_channels),
nn.ReLU(inplace=True)
)
############________________________________________________
class AtrousSeparableConvolution(nn.Module):
""" Atrous Separable Convolution
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, bias=True):
super(AtrousSeparableConvolution, self).__init__()
self.body = nn.Sequential(
# Separable Conv
nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=in_channels),
# PointWise Conv
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
)
self._init_weight()
def forward(self, x):
return self.body(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
]
super(ASPPConv, self).__init__(*modules)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True))
def forward(self, x):
size = x.shape[-2:]
x = super(ASPPPooling, self).forward(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates):
super(ASPP, self).__init__()
out_channels = 256
modules = []
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)))
rate1, rate2, rate3 = tuple(atrous_rates)
modules.append(ASPPConv(in_channels, out_channels, rate1))
modules.append(ASPPConv(in_channels, out_channels, rate2))
modules.append(ASPPConv(in_channels, out_channels, rate3))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Dropout(0.1), )
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
return self.project(res)
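# Shape sketch for ASPP (illustrative only): with in_channels=2048 each of the five
# branches emits 256 channels, so the projection sees 5*256=1280 channels and returns 256.
#   aspp = ASPP(in_channels=2048, atrous_rates=[5, 7, 11])
#   aspp(torch.randn(1, 2048, 16, 16)).shape   # -> torch.Size([1, 256, 16, 16])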
class unetasppDecoder(nn.Module):
def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
super(unetasppDecoder, self).__init__()
'''
ASPP outputs 256 channels
256 512 1024 512
'''
self.nclas = nclass
# self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
# self.trans = nn.Sequential(
# nn.Conv2d(2048, 512, 1, 1),
# nn.BatchNorm2d(512),
# nn.ReLU()
#
# )
'''x4-x16:256 512 1024 '''
self.projectx4 = nn.Sequential(
nn.Conv2d(256, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
)
self.projectx8 = nn.Sequential(
nn.Conv2d(512, 256, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
)
self.projectx16 = nn.Sequential(
nn.Conv2d(1024, 512, 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
)
self.pam = ASPP(in_channels=2048,atrous_rates=[5, 7, 11]) # ASPP outputs 256 channels
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, inputlist): # e.g. with a 512x512 input
'''
mit:
x4:  64 channels, 128x128
x8:  128 channels, 64x64
x16: 320 channels, 32x32
x32: 512 channels, 16x16
in_filters=[192, 448, 832], out_filters=[64, 128, 320]
resnet:
x4:  256 channels, 128x128
x8:  512 channels, 64x64
x16: 1024 channels, 32x32
x32: 2048 channels, 16x16
'''
x4, x8, x16, x32 = inputlist
# x32 = self.pam(self.trans(x32))
x32 = self.pam(x32)
x16=self.projectx16(x16)
x8=self.projectx8(x8)
x4=self.projectx4(x4)
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
'''Alternative fusion idea:
concatenate, then feed an SE block so the different features are re-weighted and selected adaptively, then reduce the channels.
'''
#########################################################################
'''One last attempt at this manet-style decoder.'''
class unetDecoder2(nn.Module):
def __init__(self, nclass=2, in_filters = [384, 768, 1536], out_filters = [128, 256, 512]):
super(unetDecoder2, self).__init__()
# '''
# Intermediate outputs of ResNet-50:
# 256 512 1024 2048
# 1/4(stage1) 1/8 1/16 1/32
# Each of the four gets a 1x1 conv that halves its channels:
# 128 256 512 1024
# in_filters = [384, 768, 1536], out_filters = [128, 256, 512]
# '''
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
self.transx32=nn.Sequential(
nn.Conv2d(2048,1024,1,1),
nn.BatchNorm2d(1024),
nn.ReLU()
)
self.transx16 = nn.Sequential(
nn.Conv2d(1024, 512, 1, 1),
nn.BatchNorm2d(512),
nn.ReLU()
)
self.transx8 = nn.Sequential(
nn.Conv2d(512, 256, 1, 1),
nn.BatchNorm2d(256),
nn.ReLU()
)
self.transx4 = nn.Sequential(
nn.Conv2d(256, 128, 1, 1),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, inputlist): # e.g. with a 512x512 input
'''
mit:
x4:  64 channels, 128x128
x8:  128 channels, 64x64
x16: 320 channels, 32x32
x32: 512 channels, 16x16
in_filters=[192, 448, 832], out_filters=[64, 128, 320]
resnet:
x4:  256 channels, 128x128
x8:  512 channels, 64x64
x16: 1024 channels, 32x32
x32: 2048 channels, 16x16
'''
x4, x8, x16, x32 = inputlist
x32=self.transx32(x32)
x16=self.transx16(x16)
x8=self.transx8(x8)
x4=self.transx4(x4)
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
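# Minimal smoke test (a sketch, not part of the original training pipeline). It assumes
# MiT-b1-like encoder features for a 512x512 image: 1/4, 1/8, 1/16 and 1/32 scale maps with
# 64, 128, 320 and 512 channels, matching in_filters=[192, 448, 832].
if __name__ == "__main__":
    decoder = unetDecoder(nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320])
    feats = [
        torch.randn(1, 64, 128, 128),   # x4
        torch.randn(1, 128, 64, 64),    # x8
        torch.randn(1, 320, 32, 32),    # x16
        torch.randn(1, 512, 16, 16),    # x32
    ]
    out = decoder(feats)
    print(out.shape)  # expected: torch.Size([1, 2, 256, 256])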