ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/decoder2.py

716 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
'''For use when the downsampling is not a multiple of 32.'''
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
''
# Basic depthwise-separable convolution module
class conv_dw(nn.Module):
    """Depthwise-separable convolution unit.

    A 3x3 depthwise convolution (one group per input channel) is followed by
    a 1x1 pointwise convolution. Each convolution is bias-free and followed
    by batch normalization and an in-place ReLU.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise 3x3 convolution.
    """

    def __init__(self, inp, oup, stride=1):
        super().__init__()
        # Spatial filtering, one filter per channel.
        depthwise_stage = [
            nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
        ]
        # Channel mixing.
        pointwise_stage = [
            nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        ]
        self.basedw = nn.Sequential(*depthwise_stage, *pointwise_stage)

    def forward(self, x):
        """Run ``x`` through the depthwise and then the pointwise stage."""
        return self.basedw(x)
class DepwithDoubleConv(nn.Module):
    """Two ``conv_dw`` units applied back to back.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second unit.
        mid_channels: Channels between the two units. A falsy value
            (``None`` or ``0``) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden_channels = mid_channels or out_channels
        first = conv_dw(in_channels, hidden_channels)
        second = conv_dw(hidden_channels, out_channels)
        self.double_conv_dw = nn.Sequential(first, second)

    def forward(self, x):
        """Apply both depthwise-separable units to ``x``."""
        return self.double_conv_dw(x)
class up_fusionblock(nn.Module):
    """Upsample a deeper feature map, concatenate it with a skip map, and fuse.

    The fusion is a bias-free 3x3 convolution followed by batch normalization
    and ReLU.

    Args:
        in_channels: Channels of the concatenated tensor, i.e. the channels
            of ``x2`` plus the channels of ``x1``.
        out_channels: Channels produced by the fusion convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(up_fusionblock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        """Fuse the deeper map ``x1`` into the skip map ``x2``.

        Args:
            x1: Deeper (lower-resolution) feature map.
            x2: Skip feature map; its spatial size is the output size.

        Returns:
            The fused map with ``out_channels`` channels at ``x2``'s size.
        """
        target_size = x2.shape[2:]
        upsampled = self.up(x1)
        if upsampled.shape[2:] != target_size:
            # A fixed 2x upsample only matches the skip map when its size is
            # exactly twice x1's. For odd-sized skip maps, resample straight
            # from x1 to the skip size so the concatenation is valid (and the
            # map is not interpolated twice).
            upsampled = F.interpolate(x1, size=target_size, mode='bilinear', align_corners=False)
        merged = torch.cat([x2, upsampled], dim=1)
        return self.cbr(merged)
class fusionblock(nn.Module):
    """Concatenate two same-sized feature maps and fuse them.

    Unlike ``up_fusionblock``, ``forward`` does not resize either input. The
    fusion is a bias-free 3x3 convolution, batch normalization and ReLU.

    Args:
        in_channels: Channels of the concatenated tensor.
        out_channels: Channels produced by the fusion convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Registered but never called in forward(); kept as an attribute.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        fuse_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        fuse_norm = nn.BatchNorm2d(out_channels)
        self.cbr = nn.Sequential(fuse_conv, fuse_norm, nn.ReLU(inplace=True))

    def forward(self, x1, x2):
        """Return ``cbr(cat([x2, x1]))``; both inputs must share a spatial size."""
        stacked = torch.cat([x2, x1], dim=1)
        return self.cbr(stacked)
'''Original version of CAB.'''
class CABy(nn.Module):
    """Channel attention block that gates a low-level map with global context.

    The high-level map is resized to the low-level map's size, both are
    concatenated and globally average-pooled, and two 1x1 convolutions
    (ReLU in between, sigmoid at the end) turn the pooled vector into a
    per-channel gate. The gate scales the low-level map, and the resized
    high-level map is added to the result.

    Args:
        in_channels: Channels of the concatenated (high + low) tensor.
        out_channels: Channels of the gate. It is multiplied with the
            low-level map, so it must broadcast against that map's channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Fuse a ``(high, low)`` pair of feature maps.

        Args:
            x: Two-element sequence ``(high, low)``.

        Returns:
            ``gate * low + resized_high``.
        """
        high, low = x
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        context = self.global_pooling(torch.cat([high, low], dim=1))
        gate = self.sigmod(self.conv2(self.relu(self.conv1(context))))
        gated_low = gate * low
        return gated_low + high
class CAB(nn.Module):
    """Fuse a high-level and a low-level map, then apply channel attention.

    The high-level map is resized to the low-level map's size, the two are
    concatenated and fused by a bias-free 3x3 conv + BN + ReLU (``cbr``).
    The fused map is then re-weighted channel-wise by a squeeze-and-excite
    style gate (global average pool -> 1x1 conv -> ReLU -> 1x1 conv ->
    sigmoid).

    Args:
        in_channels: Channels of the concatenated (high + low) tensor.
        out_channels: Channels of the fused output.
        ratio: Channel reduction factor inside the attention bottleneck.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The gate runs on the output of ``cbr``, which has ``out_channels``
        # channels, so its convolutions must be sized on ``out_channels``.
        # Sizing them on ``in_channels`` made forward() fail whenever
        # in_channels != out_channels.
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, high, low):
        """Fuse ``high`` into ``low`` and return the attention-weighted map.

        Args:
            high: Deeper feature map; resized to ``low``'s spatial size.
            low: Skip feature map; defines the output spatial size.

        Returns:
            Tensor with ``out_channels`` channels at ``low``'s spatial size.
        """
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        x0 = torch.cat([high, low], dim=1)
        x0 = self.cbr(x0)
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x
class CAM(nn.Module):
    """Channel attention gate (squeeze-and-excite style).

    Global average pooling summarizes each channel; a 1x1 bottleneck
    (reduce by ``ratio``, ReLU, expand back, sigmoid) produces one weight per
    channel, which scales the input.

    Args:
        in_channels: Channels of the input (and output) tensor.
        ratio: Channel reduction factor of the bottleneck.
    """

    def __init__(self, in_channels, ratio=8):
        super().__init__()
        squeezed_channels = in_channels // ratio
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, squeezed_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(squeezed_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x0):
        """Return ``x0`` scaled by its per-channel attention weights."""
        squeezed = self.global_pooling(x0)
        excited = self.conv2(self.relu(self.conv1(squeezed)))
        channel_weights = self.sigmod(excited)
        return channel_weights * x0
class PAM_Module(nn.Module):
    """Position (spatial self-) attention module.

    The original note credits SAGAN as the reference. Query and key are 1x1
    projections to ``in_dim // 8`` channels; value is a 1x1 projection that
    keeps ``in_dim`` channels. The attended features are scaled by the
    learnable scalar ``gamma`` (initialized to zero) and added to the input.

    Args:
        in_dim: Channels of the input feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.chanel_in = in_dim
        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply spatial self-attention with a residual connection.

        Args:
            x: Input feature map of shape (B, C, H, W).

        Returns:
            ``gamma * attended + x`` with the same shape as ``x``. The
            intermediate attention matrix has shape (B, H*W, H*W).
        """
        batch, channels, height, width = x.size()
        num_positions = width * height

        queries = self.query_conv(x).view(batch, -1, num_positions).permute(0, 2, 1)
        keys = self.key_conv(x).view(batch, -1, num_positions)
        # Softmax over the last axis: each query position's weights sum to 1.
        attention = self.softmax(torch.bmm(queries, keys))

        values = self.value_conv(x).view(batch, -1, num_positions)
        attended = torch.bmm(values, attention.permute(0, 2, 1))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x
class PAM(nn.Module):
    """Position attention head: channel reduction followed by ``PAM_Module``.

    A bias-free 3x3 conv + BN + ReLU reduces the input to
    ``in_channels // 4`` channels, and spatial self-attention is applied to
    the reduced map.

    Args:
        in_channels: Channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        inter_channels = in_channels // 4

        def reduction_branch():
            # 3x3 conv + BN + ReLU taking in_channels down to inter_channels.
            return nn.Sequential(
                nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(inter_channels),
                nn.ReLU(),
            )

        self.conv5a = reduction_branch()
        # conv5c is registered (and so holds parameters) but forward() never
        # calls it.
        self.conv5c = reduction_branch()
        self.sa = PAM_Module(inter_channels)

    def forward(self, x):
        """Return the position-attended map with ``in_channels // 4`` channels."""
        # Channels reduced 4x (the original note mentions the result being 512).
        reduced = self.conv5a(x)
        # Inside PAM_Module, query/key are reduced a further 8x; value keeps
        # its channel count.
        return self.sa(reduced)
class CARB(nn.Module):
    """3x3 conv + BN + ReLU followed by channel attention.

    The input is first mapped to ``out_channels`` by ``cbr``. The result is
    then scaled channel-wise by a gate computed from its global average
    (1x1 reduce by ``ratio`` -> ReLU -> 1x1 expand -> sigmoid).

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        ratio: Channel reduction factor of the attention bottleneck.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super().__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        bottleneck = out_channels // ratio
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, bottleneck, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(bottleneck, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Return the attention-weighted ``cbr(x)``."""
        features = self.cbr(x)
        pooled = self.global_pooling(features)
        gate = self.sigmod(self.conv2(self.relu(self.conv1(pooled))))
        return gate * features
class unetCAB0Decoder(nn.Module):
    """U-Net style decoder whose fusion stages are ``CAB`` blocks.

    Four encoder maps are fused top-down (x32 -> x16 -> x8 -> x4), a 1x1
    convolution produces class logits, and the logits are upsampled by 2x.

    Args:
        nclass: Number of output classes.
        in_filters: Input channels of the three fusion stages, ordered from
            the shallowest stage to the deepest. Each entry must equal the
            channels of the two maps that stage concatenates.
        out_filters: Output channels of the three fusion stages, same order.
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super().__init__()
        # Original note: intermediate outputs of "b1" have 64 / 128 / 320 / 512
        # channels (presumably the MiT-b1 backbone -- confirm against caller).
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters

        deep_in, deep_out = self.in_filters[2], self.out_filters[2]
        self.up_concat3 = CAB(in_channels=deep_in, out_channels=deep_out)
        mid_in, mid_out = self.in_filters[1], self.out_filters[1]
        self.up_concat2 = CAB(in_channels=mid_in, out_channels=mid_out)
        top_in, top_out = self.in_filters[0], self.out_filters[0]
        self.up_concat1 = CAB(in_channels=top_in, out_channels=top_out)

        self.classifer = nn.Conv2d(top_out, self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        print("decoder从初始化执行")
        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, norm_types):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, inputlist):
        """Decode ``(x4, x8, x16, x32)`` encoder maps into class logits.

        Original note (512 input as the example): x32 is the most
        downsampled map, 16*16*512; each stage is CBR (the note mentions a
        1x1 conv for channel reduction) + upsampling + concatenation.

        Args:
            inputlist: Four feature maps, shallowest first.

        Returns:
            Logits with ``nclass`` channels, upsampled 2x relative to the
            fused x4 map.
        """
        x4, x8, x16, x32 = inputlist
        fused16 = self.up_concat3(x32, x16)
        fused8 = self.up_concat2(fused16, x8)
        fused4 = self.up_concat1(fused8, x4)
        logits = self.classifer(fused4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
class unetCAM0Decoder(nn.Module):
    """U-Net style decoder with a channel-attention gate on every encoder map.

    Each of the four encoder maps is first re-weighted by its own ``CAM``
    gate. The maps are then fused top-down (x32 -> x16 -> x8 -> x4) with
    ``up_fusionblock``, a 1x1 convolution produces class logits, and the
    logits are upsampled by 2x.

    Args:
        nclass: Number of output classes.
        in_filters: Input channels of the three fusion stages, ordered from
            the shallowest stage to the deepest.
        out_filters: Output channels of the three fusion stages, same order.
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super(unetCAM0Decoder, self).__init__()
        # Original note on the "b1" encoder outputs (presumably MiT-b1):
        #   channels: 64    128   320    512
        #   scale:    1/4   1/8   1/16   1/32
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        # One gate per encoder stage; the widths are hard-coded to the
        # channel counts listed above.
        self.cam32 = CAM(512)
        self.cam16 = CAM(320)
        self.cam8 = CAM(128)
        self.cam4 = CAM(64)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        print("decoder从初始化执行")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode ``(x4, x8, x16, x32)`` encoder maps into class logits.

        Original note (512 input as the example): x32 is the most
        downsampled map, 16*16*512; each stage is CBR + upsampling +
        concatenation.

        Args:
            inputlist: Four feature maps, shallowest first.

        Returns:
            Logits with ``nclass`` channels, upsampled 2x relative to the
            fused x4 map.
        """
        x4, x8, x16, x32 = inputlist
        # Each stage must go through the gate built for its channel count.
        # Routing every stage through cam4 (64 channels) raised a channel
        # mismatch on x8/x16/x32 and left cam8/cam16/cam32 unused.
        x4 = self.cam4(x4)
        x8 = self.cam8(x8)
        x16 = self.cam16(x16)
        x32 = self.cam32(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out
class unetDecoder(nn.Module):
    """Plain U-Net style decoder built from ``up_fusionblock`` stages.

    Four encoder maps are fused top-down (x32 -> x16 -> x8 -> x4), a 1x1
    convolution produces class logits, and the logits are upsampled by 2x.

    Args:
        nclass: Number of output classes.
        in_filters: Input channels of the three fusion stages, ordered from
            the shallowest stage to the deepest.
        out_filters: Output channels of the three fusion stages, same order.
    """

    def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
        super().__init__()
        # Original note on encoder output channels:
        #   mit-b1:   64   128   320    512
        #   resnet50: 256  512   1024   2048
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters

        deep_in, deep_out = self.in_filters[2], self.out_filters[2]
        self.up_concat3 = up_fusionblock(in_channels=deep_in, out_channels=deep_out)
        mid_in, mid_out = self.in_filters[1], self.out_filters[1]
        self.up_concat2 = up_fusionblock(in_channels=mid_in, out_channels=mid_out)
        top_in, top_out = self.in_filters[0], self.out_filters[0]
        self.up_concat1 = up_fusionblock(in_channels=top_in, out_channels=top_out)

        self.classifer = nn.Conv2d(top_out, self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        print("decoder从初始化执行")
        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, norm_types):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, inputlist):
        """Decode ``(x4, x8, x16, x32)`` encoder maps into class logits.

        Shape notes kept from the original author (512 input as the
        example; partly garbled in the source, so treat them as approximate):

            mit:    x4: 128x128x64, x8: 64x64x128, x16: 32x32x320,
                    x32: 16x16x512
                    in_filters=[192, 448, 832], out_filters=[64, 128, 320]
            resnet: x4: 256, x8: 512, x16: 1024, x32: 2048 channels

        Args:
            inputlist: Four feature maps, shallowest first.

        Returns:
            Logits with ``nclass`` channels, upsampled 2x relative to the
            fused x4 map.
        """
        x4, x8, x16, x32 = inputlist
        fused16 = self.up_concat3(x32, x16)
        fused8 = self.up_concat2(fused16, x8)
        fused4 = self.up_concat1(fused8, x4)
        logits = self.classifer(fused4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
class unetpamDecoder(nn.Module):
    """U-Net style decoder with position attention on the deepest map.

    The deepest encoder map goes through ``PAM`` (hard-coded to 2048 input
    channels, so it outputs 2048 // 4 = 512). The maps are then fused
    top-down with ``up_fusionblock``, classified by a 1x1 convolution, and
    the logits are upsampled by 2x.

    Args:
        nclass: Number of output classes.
        in_filters: Input channels of the three fusion stages, ordered from
            the shallowest stage to the deepest.
        out_filters: Output channels of the three fusion stages, same order.
    """

    def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
        super().__init__()
        # Alternative configuration noted by the original author:
        #   in_filters=[512, 1024, 3072], out_filters=[128, 256, 512]
        # Channels entering the decoder, per the original note:
        #   256  512  1024  512
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)

        deep_in, deep_out = self.in_filters[2], self.out_filters[2]
        self.up_concat3 = up_fusionblock(in_channels=deep_in, out_channels=deep_out)
        mid_in, mid_out = self.in_filters[1], self.out_filters[1]
        self.up_concat2 = up_fusionblock(in_channels=mid_in, out_channels=mid_out)
        top_in, top_out = self.in_filters[0], self.out_filters[0]
        self.up_concat1 = up_fusionblock(in_channels=top_in, out_channels=top_out)

        self.classifer = nn.Conv2d(top_out, self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        print("decoder从初始化执行")
        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, norm_types):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, inputlist):
        """Decode ``(x4, x8, x16, x32)`` encoder maps into class logits.

        Shape notes kept from the original author (512 input as the
        example; partly garbled in the source, so treat them as approximate):

            mit:    x4: 128x128x64, x8: 64x64x128, x16: 32x32x320,
                    x32: 16x16x512
                    in_filters=[192, 448, 832], out_filters=[64, 128, 320]
            resnet: x4: 256, x8: 512, x16: 1024, x32: 2048 channels

        Args:
            inputlist: Four feature maps, shallowest first.

        Returns:
            Logits with ``nclass`` channels, upsampled 2x relative to the
            fused x4 map.
        """
        x4, x8, x16, x32 = inputlist
        attended32 = self.pam(x32)
        fused16 = self.up_concat3(attended32, x16)
        fused8 = self.up_concat2(fused16, x8)
        fused4 = self.up_concat1(fused8, x4)
        logits = self.classifer(fused4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
class unetpamCARBDecoder(nn.Module):
    """U-Net style decoder with PAM on the deepest map and CARB2 on every map.

    The deepest map goes through ``PAM`` and then ``CARB2``; the other three
    maps each go through their own ``CARB2``. The deepest fusion stage is a
    ``fusionblock`` (no resizing), the other two are ``up_fusionblock``. A
    1x1 convolution produces logits, which are upsampled by 2x.

    Args:
        nclass: Number of output classes.
        in_filters: Input channels of the three fusion stages, ordered from
            the shallowest stage to the deepest.
        out_filters: Output channels of the three fusion stages, same order.
    """

    def __init__(self,nclass=2, in_filters=[ 192, 384, 768],out_filters=[64, 128, 256]):
        super().__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)
        # Per-stage refinement; channel widths are hard-coded.
        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)
        # Channel bookkeeping from the original author:
        #   resnet outputs:     256  512  1024  2048
        #   after PAM:          256  512  1024  512
        #   carb, variant 1 ->  128  256  512   512
        #   carb, variant 2 ->  64   128  256   512
        #   1. in_filters=[256, 512, 1024], out_filters=[64, 128, 256]
        #   2. in_filters=[192, 384, 768],  out_filters=[64, 128, 256]
        deep_in, deep_out = self.in_filters[2], self.out_filters[2]
        # The deepest stage concatenates without upsampling, so x32 and x16
        # must already share a spatial size when they reach it.
        self.up_concat3 = fusionblock(in_channels=deep_in, out_channels=deep_out)
        mid_in, mid_out = self.in_filters[1], self.out_filters[1]
        self.up_concat2 = up_fusionblock(in_channels=mid_in, out_channels=mid_out)
        top_in, top_out = self.in_filters[0], self.out_filters[0]
        self.up_concat1 = up_fusionblock(in_channels=top_in, out_channels=top_out)

        self.classifer = nn.Conv2d(top_out, self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        print("decoder从初始化执行")
        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, norm_types):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, inputlist):
        """Decode ``(x4, x8, x16, x32)`` encoder maps into class logits.

        Shape notes kept from the original author (512 input as the
        example; partly garbled in the source, so treat them as approximate):

            mit:    x4: 128x128x64, x8: 64x64x128, x16: 32x32x320,
                    x32: 16x16x512
                    in_filters=[192, 448, 832], out_filters=[64, 128, 320]
            resnet: x4: 256, x8: 512, x16: 1024, x32: 2048 channels

        Args:
            inputlist: Four feature maps, shallowest first.

        Returns:
            Logits with ``nclass`` channels, upsampled 2x relative to the
            fused x4 map.
        """
        x4, x8, x16, x32 = inputlist
        refined32 = self.carb32(self.pam(x32))
        refined4 = self.carb4(x4)
        refined8 = self.carb8(x8)
        refined16 = self.carb16(x16)
        fused16 = self.up_concat3(refined32, refined16)
        fused8 = self.up_concat2(fused16, refined8)
        fused4 = self.up_concat1(fused8, refined4)
        logits = self.classifer(fused4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
class unetpamDecoderzuhe(nn.Module):
    """U-Net style decoder with PAM followed by CARB on the deepest map.

    The deepest encoder map goes through ``PAM`` (hard-coded to 2048 input
    channels) and then ``CARB(512, 512)``. The maps are then fused top-down
    with ``up_fusionblock``, classified by a 1x1 convolution, and the logits
    are upsampled by 2x.

    Args:
        nclass: Number of output classes.
        in_filters: Input channels of the three fusion stages, ordered from
            the shallowest stage to the deepest.
        out_filters: Output channels of the three fusion stages, same order.
    """

    def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
        super().__init__()
        # Alternative configuration noted by the original author:
        #   in_filters=[512, 1024, 3072], out_filters=[128, 256, 512]
        # Channels entering the decoder, per the original note:
        #   256  512  1024  512
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB(512, out_channels=512)

        deep_in, deep_out = self.in_filters[2], self.out_filters[2]
        self.up_concat3 = up_fusionblock(in_channels=deep_in, out_channels=deep_out)
        mid_in, mid_out = self.in_filters[1], self.out_filters[1]
        self.up_concat2 = up_fusionblock(in_channels=mid_in, out_channels=mid_out)
        top_in, top_out = self.in_filters[0], self.out_filters[0]
        self.up_concat1 = up_fusionblock(in_channels=top_in, out_channels=top_out)

        self.classifer = nn.Conv2d(top_out, self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        print("decoder从初始化执行")
        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, norm_types):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, inputlist):
        """Decode ``(x4, x8, x16, x32)`` encoder maps into class logits.

        Shape notes kept from the original author (512 input as the
        example; partly garbled in the source, so treat them as approximate):

            mit:    x4: 128x128x64, x8: 64x64x128, x16: 32x32x320,
                    x32: 16x16x512
                    in_filters=[192, 448, 832], out_filters=[64, 128, 320]
            resnet: x4: 256, x8: 512, x16: 1024, x32: 2048 channels

        Args:
            inputlist: Four feature maps, shallowest first.

        Returns:
            Logits with ``nclass`` channels, upsampled 2x relative to the
            fused x4 map.
        """
        x4, x8, x16, x32 = inputlist
        refined32 = self.carb32(self.pam(x32))
        fused16 = self.up_concat3(refined32, x16)
        fused8 = self.up_concat2(fused16, x8)
        fused4 = self.up_concat1(fused8, x4)
        logits = self.classifer(fused4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
class unetpamcamDecoder(nn.Module):
    """U-Net style decoder with PAM on the deepest map and CAM on every map.

    The deepest map goes through ``PAM`` and then ``CAM(512)``; the other
    three maps each go through their own ``CAM`` gate. The maps are then
    fused top-down with ``up_fusionblock``, classified by a 1x1 convolution,
    and the logits are upsampled by 2x.

    Args:
        nclass: Number of output classes.
        in_filters: Input channels of the three fusion stages, ordered from
            the shallowest stage to the deepest.
        out_filters: Output channels of the three fusion stages, same order.
    """

    def __init__(self, nclass=2,in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
        super().__init__()
        self.nclas = nclass
        self.finnal_channel = 512
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)
        # Per-stage channel gates; widths are hard-coded.
        self.cam32 = CAM(512)
        self.cam16 = CAM(in_channels=1024)
        self.cam8 = CAM(in_channels=512)
        self.cam4 = CAM(in_channels=256)

        deep_in, deep_out = self.in_filters[2], self.out_filters[2]
        self.up_concat3 = up_fusionblock(in_channels=deep_in, out_channels=deep_out)
        mid_in, mid_out = self.in_filters[1], self.out_filters[1]
        self.up_concat2 = up_fusionblock(in_channels=mid_in, out_channels=mid_out)
        top_in, top_out = self.in_filters[0], self.out_filters[0]
        self.up_concat1 = up_fusionblock(in_channels=top_in, out_channels=top_out)

        self.classifer = nn.Conv2d(top_out, self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        print("decoder从初始化执行")
        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, norm_types):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, inputlist):
        """Decode ``(x4, x8, x16, x32)`` encoder maps into class logits.

        Shape notes kept from the original author (512 input as the
        example; partly garbled in the source, so treat them as approximate):

            mit:    x4: 128x128x64, x8: 64x64x128, x16: 32x32x320,
                    x32: 16x16x512
                    in_filters=[192, 448, 832], out_filters=[64, 128, 320]
            resnet: x4: 256, x8: 512, x16: 1024, x32: 2048 channels

        Args:
            inputlist: Four feature maps, shallowest first.

        Returns:
            Logits with ``nclass`` channels, upsampled 2x relative to the
            fused x4 map.
        """
        x4, x8, x16, x32 = inputlist
        gated32 = self.cam32(self.pam(x32))
        gated4 = self.cam4(x4)
        gated8 = self.cam8(x8)
        gated16 = self.cam16(x16)
        fused16 = self.up_concat3(gated32, gated16)
        fused8 = self.up_concat2(fused16, gated8)
        fused4 = self.up_concat1(fused8, gated4)
        logits = self.classifer(fused4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
class CARB2(nn.Module):
    """3x3 conv + BN + ReLU followed by avg+max pooled channel attention.

    Same layout as ``CARB``, except the channel descriptor is the sum of the
    global average-pooled and global max-pooled features.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        ratio: Channel reduction factor of the attention bottleneck.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super().__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        bottleneck = out_channels // ratio
        self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
        self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, bottleneck, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(bottleneck, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Return the attention-weighted ``cbr(x)``."""
        features = self.cbr(x)
        avg_descriptor = self.global_avgpooling(features)
        max_descriptor = self.global_maxpooling(features)
        descriptor = avg_descriptor + max_descriptor
        gate = self.sigmod(self.conv2(self.relu(self.conv1(descriptor))))
        return gate * features
def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
    """Build a bias-free Conv2d -> normalization -> in-place ReLU stack.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        stride: Convolution stride.
        norm_layer: Callable that builds the normalization layer from the
            number of channels.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        bias=False,
    )
    norm = norm_layer(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)
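'''Another fusion scheme:
concatenate the features and pass them to an SE block, so that the different features are adaptively weighted and selected, and then reduce the channel dimension.
'''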