ai-station-code/wudingpv/taihuyuan_roof/manet/model/decoder.py

597 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
# Basic building block: depthwise-separable convolution.
class conv_dw(nn.Module):
    """Depthwise-separable 3x3 convolution.

    A depthwise 3x3 conv-BN-ReLU (one filter per input channel) is followed by
    a pointwise 1x1 conv-BN-ReLU that mixes channels.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise 3x3 convolution (default 1).
    """

    def __init__(self, inp, oup, stride=1):
        super().__init__()
        depthwise = [
            nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
        ]
        pointwise = [
            nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        ]
        # One flat Sequential keeps the parameter names basedw.0 ... basedw.5.
        self.basedw = nn.Sequential(*depthwise, *pointwise)

    def forward(self, x):
        return self.basedw(x)
class DepwithDoubleConv(nn.Module):
    """Two stacked depthwise-separable convolutions: in -> mid -> out channels.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        mid_channels: channels between the two blocks; any falsy value
            (None or 0) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        first = conv_dw(in_channels, hidden)
        second = conv_dw(hidden, out_channels)
        self.double_conv_dw = nn.Sequential(first, second)

    def forward(self, x):
        return self.double_conv_dw(x)
class up_fusionblock(nn.Module):
    """Upsample a deeper feature map, concatenate it with a skip feature, then conv-BN-ReLU.

    Args:
        in_channels: channels of the concatenation, i.e. C(x2) + C(x1).
        out_channels: channels produced by the 3x3 conv-BN-ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super(up_fusionblock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Kept for backward compatibility (attribute access / repr). forward()
        # resizes to the skip feature's exact size instead of a fixed 2x factor.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        """Fuse ``x1`` (deeper, lower resolution) into ``x2`` (skip connection).

        ``x1`` is bilinearly resized to ``x2``'s spatial size, concatenated
        after ``x2`` along the channel axis, and passed through ``self.cbr``.
        When ``x2`` is exactly twice ``x1``'s size this equals the former fixed
        2x upsampling. It also handles odd sizes (e.g. 33 vs 17), where a
        fixed 2x factor yields mismatched sizes and ``torch.cat`` fails.
        """
        x1 = F.interpolate(x1, size=x2.shape[-2:], mode='bilinear', align_corners=False)
        x = torch.cat([x2, x1], dim=1)
        return self.cbr(x)
# Original ("vanilla") CAB: channel attention block.
class CABy(nn.Module):
    """Channel attention block that fuses a high-level and a low-level feature map.

    The high-level map is resized to the low-level map's spatial size. Channel
    weights are computed from the globally average-pooled concatenation of
    both maps and applied to the low-level map. The resized high-level map is
    then added back.

    Args:
        in_channels: channels of the concatenation C(high) + C(low).
        out_channels: channels of the attention weights. Expected to equal the
            channels of both inputs so the element-wise multiply and add in
            forward() broadcast.
    """

    def __init__(self, in_channels, out_channels):
        super(CABy, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """``x`` is a pair ``(high, low)``; returns ``weights * low + resized_high``."""
        high, low = x
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        pooled = self.global_pooling(torch.cat([high, low], dim=1))
        weights = self.sigmod(self.conv2(self.relu(self.conv1(pooled))))
        return weights * low + high
class CAB(nn.Module):
    """Channel attention block: fuse high/low features, then re-weight channels.

    ``high`` is resized to ``low``'s spatial size. The two are concatenated and
    projected by a 3x3 conv-BN-ReLU to ``out_channels``. The result is scaled
    by channel weights from an average-pool -> 1x1 conv -> ReLU -> 1x1 conv ->
    sigmoid bottleneck.

    Args:
        in_channels: channels of the concatenation C(high) + C(low).
        out_channels: channels produced by the fusion conv (and of the output).
        ratio: reduction ratio of the channel-attention bottleneck.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The attention MLP runs on the *fused* map, which has out_channels
        # channels. Sizing it with in_channels crashed whenever
        # in_channels != out_channels. When the two are equal the layer
        # shapes (and checkpoint compatibility) are unchanged.
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, high, low):
        """Return the channel-reweighted fusion of ``high`` (resized) and ``low``."""
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        fused = self.cbr(torch.cat([high, low], dim=1))
        weights = self.global_pooling(fused)
        weights = self.sigmod(self.conv2(self.relu(self.conv1(weights))))
        return weights * fused
class CAM(nn.Module):
    """Channel attention module (squeeze-and-excitation style gating).

    Args:
        in_channels: channels of the input (and output) feature map.
        ratio: reduction ratio of the bottleneck between the two 1x1 convs.
    """

    def __init__(self, in_channels, ratio=8):
        super(CAM, self).__init__()
        hidden = in_channels // ratio
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, in_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x0):
        """Scale each channel of ``x0`` by a learned weight in (0, 1)."""
        squeezed = self.global_pooling(x0)
        excited = self.conv2(self.relu(self.conv1(squeezed)))
        return self.sigmod(excited) * x0
class PAM_Module(nn.Module):
    """Position attention module: self-attention over all spatial positions.

    Query and key projections reduce channels to ``in_dim // 8``. The value
    projection keeps ``in_dim`` channels. The attended features are scaled by
    a learnable ``gamma`` (initialised to 0) and added to the input, so the
    module starts out as an identity mapping.
    """
    # Reference: SAGAN

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        reduced = in_dim // 8
        self.chanel_in = in_dim
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: input feature map of shape (B, C, H, W).
        Returns:
            ``gamma * attended + x``, same shape as ``x``. The (B, H*W, H*W)
            attention matrix is computed internally and not returned.
        """
        batch, channels, h, w = x.size()
        positions = w * h
        query = self.query_conv(x).view(batch, -1, positions).transpose(1, 2)  # B x N x C'
        key = self.key_conv(x).view(batch, -1, positions)                      # B x C' x N
        attention = self.softmax(torch.bmm(query, key))                        # B x N x N
        value = self.value_conv(x).view(batch, -1, positions)                  # B x C x N
        attended = torch.bmm(value, attention.transpose(1, 2)).view(batch, channels, h, w)
        return self.gamma * attended + x
class PAM(nn.Module):
    """Channel reduction followed by position attention.

    A 3x3 conv-BN-ReLU reduces ``in_channels`` to ``in_channels // 4``, then
    ``PAM_Module`` applies spatial self-attention at that channel count.
    """
    # Reference: SAGAN

    def __init__(self, in_channels):
        super(PAM, self).__init__()
        inter_channels = in_channels // 4

        def _reduce():
            # 3x3 conv-BN-ReLU projecting in_channels -> inter_channels.
            return nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                 nn.BatchNorm2d(inter_channels),
                                 nn.ReLU())

        self.conv5a = _reduce()
        # NOTE(review): conv5c is never used by forward(). Its parameters still
        # appear in state_dict, presumably kept so existing checkpoints load.
        self.conv5c = _reduce()
        self.sa = PAM_Module(inter_channels)

    def forward(self, x):
        # e.g. a 2048-channel input is reduced 4x to 512 channels. Inside
        # PAM_Module, query/key use another 8x reduction; value keeps 512.
        return self.sa(self.conv5a(x))
class CARB(nn.Module):
    """Conv-BN-ReLU projection followed by average-pool channel attention.

    Args:
        in_channels: input channels.
        out_channels: channels after the 3x3 projection (and of the output).
        ratio: reduction ratio of the attention bottleneck.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB, self).__init__()
        squeezed = out_channels // ratio
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, squeezed, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(squeezed, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        features = self.cbr(x)
        gate = self.sigmod(self.conv2(self.relu(self.conv1(self.global_pooling(features)))))
        return gate * features
class unetDecoder(nn.Module):
    """U-Net style decoder: three upsample-concat-conv stages plus a 1x1 classifier.

    Args:
        nclass: number of output classes (channels of the returned logits).
        in_filters: input channels of the three fusion stages, shallow to deep.
            Each entry is the skip feature's channels plus the channels coming
            up from the deeper stage.
        out_filters: output channels of the three fusion stages, shallow to deep.

    The defaults fit encoder stages with 64, 128, 320 and 512 channels
    (192 = 64 + 128, 448 = 128 + 320, 832 = 320 + 512).
    """

    def __init__(self, nclass=2, in_filters=(192, 448, 832), out_filters=(64, 128, 320)):
        super(unetDecoder, self).__init__()
        # ResNet-50 reference: strides 2 4 8 16 32 -> channels 64 256 512 1024 2048
        self.nclas = nclass
        self.finnal_channel = 512  # not used inside this class
        # Tuple defaults copied into fresh lists, so instances never alias
        # (or mutate) a shared default list.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm/GroupNorm weight=1 and bias=0."""
        print("decoder从初始化执行")  # log line meaning "decoder initialisation executed"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode four encoder feature maps into class logits.

        Args:
            inputlist: ``(x4, x8, x16, x32)``, shallow to deep; presumably each
                map is half the spatial size of the previous one.

        Returns:
            Logits with ``nclass`` channels at twice the spatial size of ``x4``.
            NOTE(review): for a stride-4 ``x4`` this is half the input
            resolution; confirm callers expect that.

        Example feature sizes for a 512x512 input (H x W x C):
            MiT:    x4 128x128x64,  x8 64x64x128, x16 32x32x320,  x32 16x16x512
                    (in_filters=[192, 448, 832], out_filters=[64, 128, 320])
            ResNet: x4 128x128x256, x8 64x64x512, x16 32x32x1024, x32 16x16x2048
        """
        x4, x8, x16, x32 = inputlist
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        logits = self.classifer(x4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
class unetpamDecoder(nn.Module):
    """U-Net style decoder with position attention (PAM) on the deepest feature.

    ``x32`` passes through ``PAM(in_channels=2048)``, which reduces it to 512
    channels. Three upsample-concat-conv fusion stages and a 1x1 classifier
    follow.

    Args:
        nclass: number of output classes (channels of the returned logits).
        in_filters: input channels of the three fusion stages, shallow to deep
            (skip channels + channels from the deeper stage).
        out_filters: output channels of the three fusion stages, shallow to deep.

    The defaults fit ResNet stages with 256, 512, 1024 and 2048 channels
    (2048 -> 512 after PAM): 512 = 256 + 256, 1024 = 512 + 512, 1536 = 1024 + 512.
    """

    def __init__(self, nclass=2, in_filters=(512, 1024, 1536), out_filters=(128, 256, 512)):
        super(unetpamDecoder, self).__init__()
        # Alternative config: in_filters=[512, 1024, 3072], out_filters=[128, 256, 512]
        # Channels after PAM: 256 512 1024 512
        self.nclas = nclass
        self.finnal_channel = 512  # not used inside this class
        # Tuple defaults copied into fresh lists, so instances never alias
        # (or mutate) a shared default list.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.pam = PAM(in_channels=2048)
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm/GroupNorm weight=1 and bias=0."""
        print("decoder从初始化执行")  # log line meaning "decoder initialisation executed"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode four encoder feature maps into class logits.

        Args:
            inputlist: ``(x4, x8, x16, x32)``, shallow to deep; ``x32`` must have
                2048 channels for ``self.pam``.

        Returns:
            Logits with ``nclass`` channels at twice the spatial size of ``x4``.
            NOTE(review): for a stride-4 ``x4`` this is half the input
            resolution; confirm callers expect that.

        Example feature sizes for a 512x512 ResNet input (H x W x C):
            x4 128x128x256, x8 64x64x512, x16 32x32x1024, x32 16x16x2048
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.pam(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        logits = self.classifer(x4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
# class unetCAB0Decoder(nn.Module):
# def __init__(self,nclass=2,in_filters = [192, 448, 832],out_filters = [64, 128, 320]):
# super(unetCAB0Decoder, self).__init__()
#
# '''b1的中间层输出
# 64 128 320 512
# '''
#
# self.nclas=nclass
# self.finnal_channel=512
#
#
# self.in_filters = in_filters
# self.out_filters = out_filters
#
#
# self.up_concat3=CAB(in_channels=self.in_filters[2],out_channels= self.out_filters[2])
# self.up_concat2=CAB(in_channels=self.in_filters[1],out_channels= self.out_filters[1])
# self.up_concat1=CAB(in_channels=self.in_filters[0],out_channels= self.out_filters[0])
#
# self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
#
# self._init_weight()
#
#
#
# def _init_weight(self):
# print("decoder从初始化执行")
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# nn.init.kaiming_normal_(m.weight)
# elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
#
# def forward(self,inputlist): #512为例
# '''x32被下采样的 16*16*512
# 组成cbrconv1*1 降维)+ 上采样+concate+
# '''
# x4,x8,x16,x32=inputlist
# x16= self.up_concat3(x32,x16)
# x8= self.up_concat2(x16,x8)
# x4= self.up_concat1(x8,x4)
# x4=self.classifer(x4)
# out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
# return out
class unetCARBDecoder(nn.Module):
    """U-Net style decoder that refines every encoder feature with a CARB2 block first.

    ``x32`` (2048 channels) is reduced to 512 by a 3x3 conv-BN-ReLU. It is
    stored as ``pam`` although it is a plain conv block, not an attention
    module. CARB2 blocks then project x32/x16/x8/x4 to 512/256/128/64
    channels, followed by three upsample-concat-conv stages and a 1x1
    classifier.

    Args:
        nclass: number of output classes (channels of the returned logits).
        in_filters: input channels of the three fusion stages, shallow to deep
            (skip channels + channels from the deeper stage).
        out_filters: output channels of the three fusion stages, shallow to deep.

    Defaults: 192 = 64 + 128, 384 = 128 + 256, 768 = 256 + 512.
    """

    def __init__(self, nclass=2, in_filters=(192, 384, 768), out_filters=(64, 128, 256)):
        super(unetCARBDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not used inside this class
        # Tuple defaults copied into fresh lists, so instances never alias
        # (or mutate) a shared default list.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        # self.pam=PAM(in_channels=2048)
        self.pam = nn.Sequential(
            nn.Conv2d(2048, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)
        # Channel bookkeeping:
        #   resnet:            256 512 1024 2048
        #   after pam:         256 512 1024 512
        #   carb variant 1 ->  128 256 512 512
        #   carb variant 2 ->  64 128 256 512   (current)
        #   1. in_filters=[256, 512, 1024], out_filters=[64, 128, 256]
        #   2. in_filters=[192, 384, 768],  out_filters=[64, 128, 256]
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm/GroupNorm weight=1 and bias=0."""
        print("decoder从初始化执行")  # log line meaning "decoder initialisation executed"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode four encoder feature maps into class logits.

        Args:
            inputlist: ``(x4, x8, x16, x32)`` with 256, 512, 1024 and 2048
                channels respectively (as required by the layers built in
                ``__init__``).

        Returns:
            Logits with ``nclass`` channels at twice the spatial size of ``x4``.
            NOTE(review): for a stride-4 ``x4`` this is half the input
            resolution; confirm callers expect that.

        Example feature sizes for a 512x512 ResNet input (H x W x C):
            x4 128x128x256, x8 64x64x512, x16 32x32x1024, x32 16x16x2048
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x4 = self.carb4(x4)
        x8 = self.carb8(x8)
        x16 = self.carb16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        logits = self.classifer(x4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
class unetpamCARBDecoder(nn.Module):
    """U-Net style decoder with PAM on the deepest feature and CARB2 refinement on all.

    ``x32`` (2048 channels) goes through ``PAM(in_channels=2048)``, giving 512
    channels. CARB2 blocks then project x32/x16/x8/x4 to 512/256/128/64
    channels, followed by three upsample-concat-conv stages and a 1x1
    classifier.

    Args:
        nclass: number of output classes (channels of the returned logits).
        in_filters: input channels of the three fusion stages, shallow to deep
            (skip channels + channels from the deeper stage).
        out_filters: output channels of the three fusion stages, shallow to deep.

    Defaults: 192 = 64 + 128, 384 = 128 + 256, 768 = 256 + 512.
    """

    def __init__(self, nclass=2, in_filters=(192, 384, 768), out_filters=(64, 128, 256)):
        super(unetpamCARBDecoder, self).__init__()
        self.nclas = nclass
        self.finnal_channel = 512  # not used inside this class
        # Tuple defaults copied into fresh lists, so instances never alias
        # (or mutate) a shared default list.
        self.in_filters = list(in_filters)
        self.out_filters = list(out_filters)
        self.pam = PAM(in_channels=2048)
        # Previous alternative for self.pam:
        # self.pam = nn.Sequential(
        #     nn.Conv2d(2048, 512, 3, 1, 1),
        #     nn.BatchNorm2d(512),
        #     nn.ReLU()
        # )
        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)
        # Channel bookkeeping:
        #   resnet:            256 512 1024 2048
        #   after pam:         256 512 1024 512
        #   carb variant 1 ->  128 256 512 512
        #   carb variant 2 ->  64 128 256 512   (current)
        #   1. in_filters=[256, 512, 1024], out_filters=[64, 128, 256]
        #   2. in_filters=[192, 384, 768],  out_filters=[64, 128, 256]
        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])
        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm/GroupNorm weight=1 and bias=0."""
        print("decoder从初始化执行")  # log line meaning "decoder initialisation executed"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode four encoder feature maps into class logits.

        Args:
            inputlist: ``(x4, x8, x16, x32)`` with 256, 512, 1024 and 2048
                channels respectively (as required by the layers built in
                ``__init__``).

        Returns:
            Logits with ``nclass`` channels at twice the spatial size of ``x4``.
            NOTE(review): for a stride-4 ``x4`` this is half the input
            resolution; confirm callers expect that.

        Example feature sizes for a 512x512 ResNet input (H x W x C):
            x4 128x128x256, x8 64x64x512, x16 32x32x1024, x32 16x16x2048
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x4 = self.carb4(x4)
        x8 = self.carb8(x8)
        x16 = self.carb16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        logits = self.classifer(x4)
        return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
    """Build ``Conv2d -> norm_layer(out_channels) -> ReLU(inplace=True)``.

    Padding is ``kernel_size // 2``, which preserves spatial size for odd
    kernels at stride 1. The convolution has no bias because a normalization
    layer follows it.
    """
    pad = kernel_size // 2
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                     padding=pad, bias=False)
    layers = [conv, norm_layer(out_channels), nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)
# ##AFS
# # 抄的Semantic Segmentation for Remote Sensing Images
# # Based on Adaptive Feature Selection Network#letter的能work才怪#########################
# class afs():
#
class CARB2(nn.Module):
    """Conv-BN-ReLU projection followed by channel attention from avg + max pooling.

    The globally average-pooled and max-pooled descriptors are summed before
    the shared 1x1 conv -> ReLU -> 1x1 conv -> sigmoid bottleneck. The
    resulting per-channel weights scale the projected features.

    Args:
        in_channels: input channels.
        out_channels: channels after the 3x3 projection (and of the output).
        ratio: reduction ratio of the attention bottleneck.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CARB2, self).__init__()
        squeezed = out_channels // ratio
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
        self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, squeezed, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(squeezed, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        features = self.cbr(x)
        pooled = self.global_avgpooling(features) + self.global_maxpooling(features)
        gate = self.sigmod(self.conv2(self.relu(self.conv1(pooled))))
        return gate * features