ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/decoder2.py

# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
'''Decoder variants for backbones whose overall downsampling is not a full multiple of 32 (e.g. a dilated final stage).'''
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
# Basic depthwise separable convolution block (3x3 depthwise + 1x1 pointwise)
class conv_dw(nn.Module):
def __init__(self,inp, oup, stride = 1):
super(conv_dw, self).__init__()
self.basedw=nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True),
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
)
def forward(self,x):
return self.basedw(x)
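# A hedged usage sketch: the depthwise 3x3 + pointwise 1x1 pair stands in for a
# standard 3x3 conv at a fraction of the parameter cost. Channel/spatial sizes
# below are illustrative, not fixed by the module.
#   >>> block = conv_dw(64, 128)
#   >>> block(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 128, 32, 32])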
class DepwithDoubleConv(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv_dw = nn.Sequential(
conv_dw(in_channels,mid_channels),
conv_dw(mid_channels,out_channels)
)
def forward(self,x):
return self.double_conv_dw(x)
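# Sketch: two stacked conv_dw blocks, a depthwise-separable take on U-Net's
# DoubleConv; mid_channels defaults to out_channels.
#   >>> DepwithDoubleConv(64, 128)(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 128, 32, 32])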
class up_fusionblock(nn.Module):
def __init__(self, in_channels, out_channels):
super(up_fusionblock, self).__init__()
self.in_channels= in_channels
self.out_channels=out_channels
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.cbr=nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x1, x2):
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.cbr(x)
class fusionblock(nn.Module):
def __init__(self, in_channels, out_channels):
super(fusionblock, self).__init__()
self.in_channels= in_channels
self.out_channels=out_channels
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.cbr=nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x1, x2):
x = torch.cat([x2, x1], dim=1)
return self.cbr(x)
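# Sketch of the two fusion variants (channel counts here match the MiT-B1
# x32/x16 pair): up_fusionblock first doubles x1's spatial size, so x1 must be
# exactly half the resolution of the skip feature x2; fusionblock assumes both
# inputs already share a spatial size.
#   >>> up = up_fusionblock(in_channels=512 + 320, out_channels=320)
#   >>> up(torch.randn(1, 512, 16, 16), torch.randn(1, 320, 32, 32)).shape
#   torch.Size([1, 320, 32, 32])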
'''Original CAB: channel-attention fusion of a high-level and a low-level feature'''
class CABy(nn.Module):
def __init__(self, in_channels, out_channels):
super(CABy, self).__init__()
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x):
x1, x2 = x # high, low
x1=F.interpolate(x1,size=x2.shape[2:],mode='bilinear', align_corners=False)
x = torch.cat([x1,x2],dim=1)
x = self.global_pooling(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x2 = x * x2
res = x2 + x1
return res
class CAB(nn.Module):
def __init__(self, in_channels,out_channels,ratio=8):
super(CAB, self).__init__()
self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The SE gate operates on the cbr output, which has out_channels (not in_channels)
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
self.cbr = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, high,low):
high=F.interpolate(high,size=low.shape[2:],mode='bilinear', align_corners=False)
x0 = torch.cat([high,low],dim=1)
x0=self.cbr(x0)
x = self.global_pooling(x0)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x = x * x0
return x
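# Sketch: CAB resizes the high-level feature to the low-level one, fuses them
# with a 3x3 conv-BN-ReLU, then applies an SE-style channel gate to the fused
# map. in_channels is the concatenated channel count.
#   >>> cab = CAB(in_channels=512 + 320, out_channels=320)
#   >>> cab(torch.randn(1, 512, 16, 16), torch.randn(1, 320, 32, 32)).shape
#   torch.Size([1, 320, 32, 32])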
class CAM(nn.Module):
def __init__(self, in_channels,ratio=8):
super(CAM, self).__init__()
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(in_channels, in_channels // ratio, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels // ratio, in_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x0):
x = self.global_pooling(x0)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x = x * x0
return x
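# Sketch: CAM is a plain squeeze-and-excitation gate (global average pool ->
# 1x1 bottleneck -> sigmoid -> channel-wise rescale); output shape equals
# input shape.
#   >>> CAM(64)(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 64, 32, 32])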
class PAM_Module(nn.Module):
""" Position attention module"""
#Ref from SAGAN
def __init__(self, in_dim):
super(PAM_Module, self).__init__()
self.chanel_in = in_dim
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X (HxW) X (HxW)
"""
m_batchsize, C, height, width = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
energy = torch.bmm(proj_query, proj_key)
attention = self.softmax(energy)
proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, height, width)
out = self.gamma*out + x
return out
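# Note: the attention matrix is (H*W) x (H*W), so memory grows quadratically
# with spatial size; in the decoders below PAM is only applied to the smallest
# (x32) feature map.
#   >>> PAM_Module(128)(torch.randn(1, 128, 16, 16)).shape
#   torch.Size([1, 128, 16, 16])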
class PAM(nn.Module):
""" Position attention module"""
#Ref from SAGAN
def __init__(self, in_channels):
super(PAM, self).__init__()
inter_channels = in_channels // 4
self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU())
self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU())
self.sa = PAM_Module(inter_channels)
def forward(self,x):
        feat1 = self.conv5a(x) # 4x channel reduction, e.g. 2048 -> 512
        sa_feat = self.sa(feat1) # PAM_Module reduces q/k channels a further 8x; value keeps full channels, e.g. 512x64x64
return sa_feat
class CARB(nn.Module):
def __init__(self, in_channels,out_channels,ratio=8):
super(CARB, self).__init__()
self.cbr = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x):
x0=self.cbr(x)
x = self.global_pooling(x0)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x = x * x0
return x
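# Sketch: CARB = a 3x3 conv-BN-ReLU projection followed by the same SE-style
# gate as CAM, so it both remaps channels and reweights them.
#   >>> CARB(2048, out_channels=512)(torch.randn(1, 2048, 16, 16)).shape
#   torch.Size([1, 512, 16, 16])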
class unetCAB0Decoder(nn.Module):
def __init__(self,nclass=2,in_filters = [192, 448, 832],out_filters = [64, 128, 320]):
super(unetCAB0Decoder, self).__init__()
        '''MiT-B1 intermediate feature channels: 64 128 320 512'''
self.nclas=nclass
self.finnal_channel=512
self.in_filters = in_filters
self.out_filters = out_filters
self.up_concat3=CAB(in_channels=self.in_filters[2],out_channels= self.out_filters[2])
self.up_concat2=CAB(in_channels=self.in_filters[1],out_channels= self.out_filters[1])
self.up_concat1=CAB(in_channels=self.in_filters[0],out_channels= self.out_filters[0])
self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
    def forward(self,inputlist): # e.g. a 512x512 input
        '''x32 is the most downsampled feature (16x16x512).
        Each stage: conv-BN-ReLU channel reduction + upsample + concat.
        '''
x4,x8,x16,x32=inputlist
x16= self.up_concat3(x32,x16)
x8= self.up_concat2(x16,x8)
x4= self.up_concat1(x8,x4)
x4=self.classifer(x4)
out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
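# Minimal smoke test (shapes assume the MiT-B1 pyramid on a 512x512 input; the
# decoder output is at 1/2 resolution):
#   >>> dec = unetCAB0Decoder(nclass=2)
#   >>> feats = [torch.randn(1, 64, 128, 128), torch.randn(1, 128, 64, 64),
#   ...          torch.randn(1, 320, 32, 32), torch.randn(1, 512, 16, 16)]
#   >>> dec(feats).shape
#   torch.Size([1, 2, 256, 256])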
class unetCAM0Decoder(nn.Module):
def __init__(self,nclass=2,in_filters = [192, 448, 832],out_filters = [64, 128, 320]):
super(unetCAM0Decoder, self).__init__()
        '''MiT-B1 intermediate feature channels: 64 128 320 512
        at strides 1/4 1/8 1/16 1/32.
        '''
self.nclas=nclass
self.finnal_channel=512
self.in_filters = in_filters
self.out_filters = out_filters
self.cam32 = CAM(512)
self.cam16 = CAM(320)
self.cam8 = CAM(128)
self.cam4 = CAM(64)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
    def forward(self,inputlist): # e.g. a 512x512 input
        '''x32 is the most downsampled feature (16x16x512).
        Each scale is gated by its own CAM before fusion.
        '''
        x4,x8,x16,x32=inputlist
        x4=self.cam4(x4)
        x8=self.cam8(x8)
        x16=self.cam16(x16)
        x32=self.cam32(x32)
x16= self.up_concat3(x32,x16)
x8= self.up_concat2(x16,x8)
x4= self.up_concat1(x8,x4)
x4=self.classifer(x4)
out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class unetDecoder(nn.Module):
def __init__(self, nclass=2, in_filters=[192, 448, 832], out_filters=[64, 128, 320]):
super(unetDecoder, self).__init__()
        '''Intermediate feature channels:
        mit-b1:   64 128 320 512
        resnet50: 256 512 1024 2048
        '''
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
    def forward(self, inputlist): # e.g. a 512x512 input
        '''
        Per-scale feature shapes (H, W, C) for a 512x512 input:
        mit-b1:
            x4:  128 128 64
            x8:  64  64  128
            x16: 32  32  320
            x32: 16  16  512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet50:
            x4:  128 128 256
            x8:  64  64  512
            x16: 32  32  1024
            x32: 16  16  2048
        '''
x4, x8, x16, x32 = inputlist
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
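# Minimal smoke test; the default in_filters/out_filters match the MiT-B1
# channel counts listed above:
#   >>> dec = unetDecoder(nclass=2)
#   >>> feats = [torch.randn(1, 64, 128, 128), torch.randn(1, 128, 64, 64),
#   ...          torch.randn(1, 320, 32, 32), torch.randn(1, 512, 16, 16)]
#   >>> dec(feats).shape
#   torch.Size([1, 2, 256, 256])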
class unetpamDecoder(nn.Module):
def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
super(unetpamDecoder, self).__init__()
#in_filters=[ 512, 1024, 3072],out_filters=[128, 256, 512]
        '''Channels after PAM: 256 512 1024 512 (resnet50 pyramid, x32 reduced 2048 -> 512)'''
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
self.pam = PAM(in_channels=2048)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
    def forward(self, inputlist): # e.g. a 512x512 input
        '''
        Per-scale feature shapes (H, W, C) for a 512x512 input:
        mit-b1:
            x4:  128 128 64
            x8:  64  64  128
            x16: 32  32  320
            x32: 16  16  512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet50:
            x4:  128 128 256
            x8:  64  64  512
            x16: 32  32  1024
            x32: 16  16  2048
        '''
x4, x8, x16, x32 = inputlist
x32 = self.pam(x32)
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
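# Minimal smoke test; PAM(in_channels=2048) hard-codes a ResNet-50-style x32
# feature, so the defaults here expect the resnet50 pyramid:
#   >>> dec = unetpamDecoder(nclass=2)
#   >>> feats = [torch.randn(1, 256, 128, 128), torch.randn(1, 512, 64, 64),
#   ...          torch.randn(1, 1024, 32, 32), torch.randn(1, 2048, 16, 16)]
#   >>> dec(feats).shape
#   torch.Size([1, 2, 256, 256])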
class unetpamCARBDecoder(nn.Module):
def __init__(self,nclass=2, in_filters=[ 192, 384, 768],out_filters=[64, 128, 256]):
super(unetpamCARBDecoder, self).__init__()
self.nclas=nclass
self.finnal_channel=512
self.in_filters = in_filters
self.out_filters = out_filters
self.pam=PAM(in_channels=2048)
self.carb32=CARB2(512,out_channels=512)
self.carb16 = CARB2(in_channels=1024,out_channels=256)
self.carb8 = CARB2(in_channels=512,out_channels=128)
self.carb4 = CARB2(in_channels=256,out_channels=64)
        '''
        resnet50 channels:   256 512 1024 2048
        after PAM:           256 512 1024 512
        after CARB option 1: 128 256 512  512
        after CARB option 2: 64  128 256  512
        1. in_filters=[256, 512, 1024], out_filters=[64, 128, 256]
        2. in_filters=[192, 384, 768],  out_filters=[64, 128, 256]
        '''
self.up_concat3=fusionblock(in_channels=self.in_filters[2],out_channels= self.out_filters[2])
self.up_concat2=up_fusionblock(in_channels=self.in_filters[1],out_channels= self.out_filters[1])
self.up_concat1=up_fusionblock(in_channels=self.in_filters[0],out_channels= self.out_filters[0])
self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
    def forward(self, inputlist): # e.g. a 512x512 input
        '''
        Per-scale feature shapes (H, W, C) for a 512x512 input:
        mit-b1:
            x4:  128 128 64
            x8:  64  64  128
            x16: 32  32  320
            x32: 16  16  512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet50:
            x4:  128 128 256
            x8:  64  64  512
            x16: 32  32  1024
            x32: 16  16  2048
        '''
x4,x8,x16,x32=inputlist
x32=self.carb32(self.pam(x32))
x4=self.carb4(x4)
x8=self.carb8(x8)
x16=self.carb16(x16)
x16= self.up_concat3(x32,x16)
x8= self.up_concat2(x16,x8)
x4= self.up_concat1(x8,x4)
x4=self.classifer(x4)
out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
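# Smoke test: up_concat3 is a plain fusionblock (no upsampling), so this
# variant expects the x32 stage at the same 1/16 resolution as x16 (e.g. a
# dilated ResNet-50, matching the "not a multiple of 32" note at the top):
#   >>> dec = unetpamCARBDecoder(nclass=2)
#   >>> feats = [torch.randn(1, 256, 128, 128), torch.randn(1, 512, 64, 64),
#   ...          torch.randn(1, 1024, 32, 32), torch.randn(1, 2048, 32, 32)]
#   >>> dec(feats).shape
#   torch.Size([1, 2, 256, 256])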
class unetpamDecoderzuhe(nn.Module):
def __init__(self, nclass=2, in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
super(unetpamDecoderzuhe, self).__init__()
#in_filters=[ 512, 1024, 3072],out_filters=[128, 256, 512]
        '''Channels after PAM: 256 512 1024 512 (resnet50 pyramid, x32 reduced 2048 -> 512)'''
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
self.pam = PAM(in_channels=2048)
self.carb32 = CARB(512, out_channels=512)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
    def forward(self, inputlist): # e.g. a 512x512 input
        '''
        Per-scale feature shapes (H, W, C) for a 512x512 input:
        mit-b1:
            x4:  128 128 64
            x8:  64  64  128
            x16: 32  32  320
            x32: 16  16  512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet50:
            x4:  128 128 256
            x8:  64  64  512
            x16: 32  32  1024
            x32: 16  16  2048
        '''
x4, x8, x16, x32 = inputlist
x32 = self.carb32(self.pam(x32))
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class unetpamcamDecoder(nn.Module):
def __init__(self, nclass=2,in_filters=[ 512, 1024, 1536],out_filters=[128, 256, 512]):
super(unetpamcamDecoder, self).__init__()
self.nclas = nclass
self.finnal_channel = 512
self.in_filters = in_filters
self.out_filters = out_filters
self.pam = PAM(in_channels=2048)
        self.cam32 = CAM(in_channels=512)
        self.cam16 = CAM(in_channels=1024)
        self.cam8 = CAM(in_channels=512)
        self.cam4 = CAM(in_channels=256)
self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
self._init_weight()
def _init_weight(self):
print("decoder从初始化执行")
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
    def forward(self, inputlist): # e.g. a 512x512 input
        '''
        Per-scale feature shapes (H, W, C) for a 512x512 input:
        mit-b1:
            x4:  128 128 64
            x8:  64  64  128
            x16: 32  32  320
            x32: 16  16  512
            -> in_filters=[192, 448, 832], out_filters=[64, 128, 320]
        resnet50:
            x4:  128 128 256
            x8:  64  64  512
            x16: 32  32  1024
            x32: 16  16  2048
        '''
x4, x8, x16, x32 = inputlist
x32 = self.cam32(self.pam(x32))
x4 = self.cam4(x4)
x8 = self.cam8(x8)
x16 = self.cam16(x16)
x16 = self.up_concat3(x32, x16)
x8 = self.up_concat2(x16, x8)
x4 = self.up_concat1(x8, x4)
x4 = self.classifer(x4)
out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
return out
class CARB2(nn.Module):
def __init__(self, in_channels,out_channels,ratio=8):
super(CARB2, self).__init__()
self.cbr = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
self.sigmod = nn.Sigmoid()
def forward(self, x):
x0=self.cbr(x)
xavg = self.global_avgpooling(x0)
xmax = self.global_maxpooling(x0)
x=xavg+xmax
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.sigmod(x)
x = x * x0
return x
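# Sketch: CARB2 differs from CARB by summing global average- and max-pooled
# descriptors before the shared bottleneck, as in CBAM's channel attention.
#   >>> CARB2(2048, out_channels=512)(torch.randn(1, 2048, 16, 16)).shape
#   torch.Size([1, 512, 16, 16])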
def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, bias=False),
norm_layer(out_channels),
nn.ReLU(inplace=True)
)
'''Another fusion idea: concatenate, then feed the result to an SE block so the
network adaptively weights and selects among the different features, then
reduce the channel dimension.
'''
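# A minimal runnable check (a sketch: feature shapes assume a MiT-B1 backbone
# on a 512x512 input; adjust channels for other encoders).
if __name__ == "__main__":
    feats = [
        torch.randn(1, 64, 128, 128),  # x4
        torch.randn(1, 128, 64, 64),   # x8
        torch.randn(1, 320, 32, 32),   # x16
        torch.randn(1, 512, 16, 16),   # x32
    ]
    for dec_cls in (unetDecoder, unetCAB0Decoder, unetCAM0Decoder):
        out = dec_cls(nclass=2)(feats)
        print(dec_cls.__name__, tuple(out.shape))  # expect (1, 2, 256, 256)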