597 lines
19 KiB
Python
597 lines
19 KiB
Python
# ---------------------------------------------------------------
|
||
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
|
||
#
|
||
# This work is licensed under the NVIDIA Source Code License
|
||
# ---------------------------------------------------------------
|
||
import numpy as np
|
||
import torch.nn as nn
|
||
import torch
|
||
|
||
import torch.nn.functional as F
|
||
|
||
|
||
|
||
# Basic depthwise-separable convolution block.
class conv_dw(nn.Module):
    """Depthwise-separable convolution: 3x3 depthwise conv followed by 1x1 pointwise conv.

    Each convolution is bias-free and followed by BatchNorm and an in-place ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 depthwise convolution (default 1).
    """

    def __init__(self, inp, oup, stride=1):
        super().__init__()
        layers = [
            # depthwise 3x3: one filter per input channel (groups=inp)
            nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
            # pointwise 1x1: mixes channels and maps inp -> oup
            nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        ]
        self.basedw = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the depthwise-separable stack to ``x``."""
        out = self.basedw(x)
        return out
|
||
|
||
class DepwithDoubleConv(nn.Module):
    """Two stacked ``conv_dw`` blocks: in_channels -> mid_channels -> out_channels.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        mid_channels: channel count between the two blocks; any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        first = conv_dw(in_channels, mid_channels)
        second = conv_dw(mid_channels, out_channels)
        self.double_conv_dw = nn.Sequential(first, second)

    def forward(self, x):
        """Run ``x`` through both depthwise-separable blocks."""
        out = self.double_conv_dw(x)
        return out
|
||
|
||
|
||
class up_fusionblock(nn.Module):
    """Upsample a coarse map, concatenate it with a skip feature, and fuse with conv-BN-ReLU.

    Args:
        in_channels: channel count of the concatenation, i.e. C(x2) + C(x1).
        out_channels: channel count produced by the fusing 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(up_fusionblock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        """Fuse ``x1`` (coarse) into the resolution of ``x2`` (skip connection).

        Args:
            x1: coarse feature map; it is resampled to ``x2``'s spatial size.
            x2: skip feature map; its spatial size defines the output size.

        Returns:
            ``cbr(cat([x2, resampled x1], dim=1))`` with ``out_channels`` channels.
        """
        target_size = x2.shape[2:]
        if x1.shape[2] * 2 == target_size[0] and x1.shape[3] * 2 == target_size[1]:
            # Exact 2x relationship: keep the original fixed-factor upsampling.
            x1 = self.up(x1)
        else:
            # Odd feature sizes (e.g. 17 -> 33) break a fixed 2x factor and make
            # torch.cat fail; resample directly to the skip feature's size instead.
            x1 = F.interpolate(x1, size=target_size, mode='bilinear', align_corners=False)
        x = torch.cat([x2, x1], dim=1)
        return self.cbr(x)
|
||
|
||
# Original CAB (channel attention block).
class CABy(nn.Module):
    """Channel attention block that gates the low-level feature and adds the high-level one.

    ``forward`` takes a pair ``(high, low)``. ``high`` is bilinearly resized to
    ``low``'s spatial size. Global average pooling of ``cat([high, low])`` then
    goes through 1x1 conv -> ReLU -> 1x1 conv -> sigmoid to give a per-channel
    gate. The result is ``gate * low + high``.

    The sum requires ``high`` and ``low`` to have the same channel count, and
    the gate (``out_channels`` channels) must broadcast against ``low``.

    Args:
        in_channels: channel count of ``cat([high, low])``.
        out_channels: channel count of the gate.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        high, low = x
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        pooled = self.global_pooling(torch.cat([high, low], dim=1))
        gate = self.sigmod(self.conv2(self.relu(self.conv1(pooled))))
        return gate * low + high
|
||
|
||
class CAB(nn.Module):
    """Channel attention block fusing a high-level and a low-level feature map.

    ``high`` is bilinearly resized to ``low``'s spatial size. The two maps are
    concatenated and projected to ``out_channels`` by a 3x3 conv-BN-ReLU. The
    projection is then reweighted channel-wise by a gate built from global
    average pooling -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid.

    Args:
        in_channels: channel count of ``cat([high, low])``.
        out_channels: channel count after the 3x3 projection (and of the output).
        ratio: reduction ratio of the gate's bottleneck.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super(CAB, self).__init__()

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The gate operates on the *projected* features (out_channels), so its
        # convs are sized on out_channels. Sizing them on in_channels crashed
        # whenever in_channels != out_channels.
        self.conv1 = nn.Conv2d(out_channels, out_channels // ratio, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels // ratio, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, high, low):
        """Fuse ``high`` into ``low``'s resolution and apply channel attention.

        Returns:
            Gated projection with ``out_channels`` channels at ``low``'s spatial size.
        """
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        x0 = torch.cat([high, low], dim=1)
        x0 = self.cbr(x0)
        x = self.global_pooling(x0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x = x * x0
        return x
|
||
|
||
class CAM(nn.Module):
    """Channel attention: rescale each input channel by a learned sigmoid gate.

    The gate is global average pooling followed by a 1x1 conv bottleneck
    (C -> C // ratio -> C) with a ReLU in between and a sigmoid at the end.

    Args:
        in_channels: channel count of the input (and of the output).
        ratio: bottleneck reduction ratio.
    """

    def __init__(self, in_channels, ratio=8):
        super().__init__()
        squeezed = in_channels // ratio
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, squeezed, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(squeezed, in_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x0):
        squeeze = self.global_pooling(x0)
        excite = self.conv2(self.relu(self.conv1(squeeze)))
        return self.sigmod(excite) * x0
|
||
|
||
class PAM_Module(nn.Module):
    """Position attention module: self-attention over spatial positions (after SAGAN).

    The query and key projections reduce channels to ``in_dim // 8``. The value
    projection keeps ``in_dim`` channels. The attended features are scaled by
    a learnable ``gamma`` (initialised to zero) and added back to the input.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.chanel_in = in_dim
        reduced = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: input feature maps of shape (B, C, H, W).

        Returns:
            ``gamma * attended + x`` with shape (B, C, H, W). The attention map
            over positions has shape (B, H*W, H*W).
        """
        b, c, h, w = x.size()
        n = h * w
        query = self.query_conv(x).view(b, -1, n).transpose(1, 2)  # (B, N, C')
        key = self.key_conv(x).view(b, -1, n)                        # (B, C', N)
        attention = self.softmax(torch.bmm(query, key))              # (B, N, N)
        value = self.value_conv(x).view(b, -1, n)                    # (B, C, N)
        attended = torch.bmm(value, attention.transpose(1, 2))       # (B, C, N)
        return self.gamma * attended.view(b, c, h, w) + x
|
||
|
||
|
||
class PAM(nn.Module):
    """Channel reduction (3x3 conv-BN-ReLU, C -> C // 4) followed by ``PAM_Module`` attention.

    NOTE(review): ``conv5c`` is constructed but never used in ``forward``.
    It is kept so existing checkpoints (state_dict keys) still load. Its
    parameters receive no gradient, which matters for DDP
    ``find_unused_parameters``.
    """

    def __init__(self, in_channels):
        super().__init__()
        inter_channels = in_channels // 4

        def _reduce():
            # 3x3 conv-BN-ReLU mapping in_channels -> in_channels // 4
            return nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                 nn.BatchNorm2d(inter_channels),
                                 nn.ReLU())

        self.conv5a = _reduce()
        self.conv5c = _reduce()
        self.sa = PAM_Module(inter_channels)

    def forward(self, x):
        # Reduce channels 4x (e.g. 2048 -> 512), then apply position attention;
        # PAM_Module keeps the channel count of its input.
        reduced = self.conv5a(x)
        return self.sa(reduced)
|
||
class CARB(nn.Module):
    """3x3 conv-BN-ReLU projection followed by average-pool channel attention.

    The projected features are multiplied by a per-channel gate computed as
    global average pooling -> 1x1 conv (C -> C // ratio) -> ReLU -> 1x1 conv
    (C // ratio -> C) -> sigmoid.

    Args:
        in_channels: input channel count.
        out_channels: channel count of the projection and of the output.
        ratio: bottleneck reduction ratio of the gate.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super().__init__()
        hidden = out_channels // ratio
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        feat = self.cbr(x)
        weights = self.global_pooling(feat)
        for layer in (self.conv1, self.relu, self.conv2, self.sigmod):
            weights = layer(weights)
        return weights * feat
|
||
|
||
|
||
class unetDecoder(nn.Module):
    """U-Net style decoder: three upsample-concat-fuse stages, a 1x1 classifier, then 2x upsampling.

    ``forward`` expects ``inputlist = (x4, x8, x16, x32)``, finest first
    (presumably encoder features at strides 4/8/16/32; confirm against the
    encoder). Each ``up_concat`` stage upsamples the coarser map and
    concatenates it with the next skip feature. ``in_filters[i]`` must
    therefore equal the channel sum of those two inputs.

    The default filters correspond to stage channels 64/128/320/512 (the
    MiT layout noted by the original author):
    832 = 512 + 320, 448 = 320 + 128, 192 = 128 + 64.

    Args:
        nclass: number of output classes.
        in_filters: input channel counts of up_concat1..3; ``None`` means
            ``[192, 448, 832]``.
        out_filters: output channel counts of up_concat1..3; ``None`` means
            ``[64, 128, 320]``.
    """

    def __init__(self, nclass=2, in_filters=None, out_filters=None):
        super(unetDecoder, self).__init__()
        # Build defaults per call. A list literal in the signature would be
        # shared by every instance through self.in_filters / self.out_filters.
        if in_filters is None:
            in_filters = [192, 448, 832]
        if out_filters is None:
            out_filters = [64, 128, 320]

        self.nclas = nclass
        self.finnal_channel = 512  # not used within this class

        self.in_filters = in_filters
        self.out_filters = out_filters

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Re-initialise submodules: Kaiming-normal conv weights, unit weight / zero bias for norms.

        Conv biases keep PyTorch's default initialisation.
        """
        print("decoder从初始化执行")  # "decoder re-initialisation executed"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode multi-scale features into class logits.

        Args:
            inputlist: sequence ``(x4, x8, x16, x32)`` of feature maps, finest first.

        Returns:
            Logits with ``nclass`` channels at twice the spatial size of ``x4``.
        """
        x4, x8, x16, x32 = inputlist
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out
|
||
|
||
class unetpamDecoder(nn.Module):
    """U-Net style decoder with position attention (``PAM``) on the coarsest feature.

    ``x32`` is passed through ``PAM(in_channels=2048)``, which reduces it to
    2048 // 4 = 512 channels. Three upsample-concat-fuse stages, a 1x1
    classifier and a final 2x upsampling follow.

    The defaults assume ResNet-style stage channels 256/512/1024/2048:
    1536 = 1024 + 512, 1024 = 512 + 512, 512 = 256 + 256.
    The PAM input width (2048) is hard-coded.

    Args:
        nclass: number of output classes.
        in_filters: input channel counts of up_concat1..3; ``None`` means
            ``[512, 1024, 1536]``.
        out_filters: output channel counts of up_concat1..3; ``None`` means
            ``[128, 256, 512]``.
    """

    def __init__(self, nclass=2, in_filters=None, out_filters=None):
        super(unetpamDecoder, self).__init__()
        # Build defaults per call. A list literal in the signature would be
        # shared by every instance through self.in_filters / self.out_filters.
        if in_filters is None:
            in_filters = [512, 1024, 1536]
        if out_filters is None:
            out_filters = [128, 256, 512]

        self.nclas = nclass
        self.finnal_channel = 512  # not used within this class

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.pam = PAM(in_channels=2048)

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Re-initialise submodules: Kaiming-normal conv weights, unit weight / zero bias for norms.

        Conv biases keep PyTorch's default initialisation.
        """
        print("decoder从初始化执行")  # "decoder re-initialisation executed"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode multi-scale features into class logits.

        Args:
            inputlist: sequence ``(x4, x8, x16, x32)`` of feature maps, finest
                first; ``x32`` must have 2048 channels.

        Returns:
            Logits with ``nclass`` channels at twice the spatial size of ``x4``.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.pam(x32)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out
|
||
|
||
# class unetCAB0Decoder(nn.Module):
|
||
# def __init__(self,nclass=2,in_filters = [192, 448, 832],out_filters = [64, 128, 320]):
|
||
# super(unetCAB0Decoder, self).__init__()
|
||
#
|
||
# '''b1的中间层输出:
|
||
# 64 128 320 512
|
||
# '''
|
||
#
|
||
# self.nclas=nclass
|
||
# self.finnal_channel=512
|
||
#
|
||
#
|
||
# self.in_filters = in_filters
|
||
# self.out_filters = out_filters
|
||
#
|
||
#
|
||
# self.up_concat3=CAB(in_channels=self.in_filters[2],out_channels= self.out_filters[2])
|
||
# self.up_concat2=CAB(in_channels=self.in_filters[1],out_channels= self.out_filters[1])
|
||
# self.up_concat1=CAB(in_channels=self.in_filters[0],out_channels= self.out_filters[0])
|
||
#
|
||
# self.classifer=nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)
|
||
#
|
||
# self._init_weight()
|
||
#
|
||
#
|
||
#
|
||
# def _init_weight(self):
|
||
# print("decoder从初始化执行")
|
||
# for m in self.modules():
|
||
# if isinstance(m, nn.Conv2d):
|
||
# nn.init.kaiming_normal_(m.weight)
|
||
# elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||
# nn.init.constant_(m.weight, 1)
|
||
# nn.init.constant_(m.bias, 0)
|
||
#
|
||
# def forward(self,inputlist): #512为例
|
||
# '''x:32被下采样的 16*16*512
|
||
# 组成:cbr(conv1*1 降维)+ 上采样+concate+
|
||
# '''
|
||
# x4,x8,x16,x32=inputlist
|
||
# x16= self.up_concat3(x32,x16)
|
||
# x8= self.up_concat2(x16,x8)
|
||
# x4= self.up_concat1(x8,x4)
|
||
# x4=self.classifer(x4)
|
||
# out= F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
|
||
# return out
|
||
|
||
|
||
class unetCARBDecoder(nn.Module):
    """U-Net style decoder that refines every encoder feature with ``CARB2`` channel attention.

    ``x32`` is first reduced by a 3x3 conv-BN-ReLU (2048 -> 512). The attribute
    is named ``pam``, but it is a plain convolution, not the ``PAM`` attention
    module. ``x32`` is then refined by ``carb32``; ``x16``, ``x8`` and ``x4``
    are refined by ``carb16``, ``carb8`` and ``carb4``.

    The CARB2 input widths (2048 via the reduction, then 1024, 512, 256) are
    hard-coded, so ResNet-style inputs are expected. ``in_filters`` and
    ``out_filters`` only configure the fusion stages and must agree with the
    CARB2 outputs (64/128/256/512):
    768 = 512 + 256, 384 = 256 + 128, 192 = 128 + 64.

    Args:
        nclass: number of output classes.
        in_filters: input channel counts of up_concat1..3; ``None`` means
            ``[192, 384, 768]``.
        out_filters: output channel counts of up_concat1..3; ``None`` means
            ``[64, 128, 256]``.
    """

    def __init__(self, nclass=2, in_filters=None, out_filters=None):
        super(unetCARBDecoder, self).__init__()
        # Build defaults per call. A list literal in the signature would be
        # shared by every instance through self.in_filters / self.out_filters.
        if in_filters is None:
            in_filters = [192, 384, 768]
        if out_filters is None:
            out_filters = [64, 128, 256]

        self.nclas = nclass
        self.finnal_channel = 512  # not used within this class

        self.in_filters = in_filters
        self.out_filters = out_filters
        # Plain channel reduction 2048 -> 512 (stands in for PAM in this variant).
        self.pam = nn.Sequential(
            nn.Conv2d(2048, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Re-initialise submodules: Kaiming-normal conv weights, unit weight / zero bias for norms.

        Conv biases keep PyTorch's default initialisation.
        """
        print("decoder从初始化执行")  # "decoder re-initialisation executed"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode multi-scale features into class logits.

        Args:
            inputlist: sequence ``(x4, x8, x16, x32)`` of feature maps, finest
                first, with 256/512/1024/2048 channels respectively.

        Returns:
            Logits with ``nclass`` channels at twice the spatial size of ``x4``.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x4 = self.carb4(x4)
        x8 = self.carb8(x8)
        x16 = self.carb16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out
|
||
|
||
|
||
class unetpamCARBDecoder(nn.Module):
    """U-Net style decoder with ``PAM`` on the coarsest feature and ``CARB2`` refinement of every feature.

    ``x32`` is passed through ``PAM(in_channels=2048)``, which reduces it to
    2048 // 4 = 512 channels, and then through ``carb32``. ``x16``, ``x8`` and
    ``x4`` are refined by ``carb16``, ``carb8`` and ``carb4``.

    The CARB2 input widths (1024, 512, 256) and the PAM width (2048) are
    hard-coded, so ResNet-style inputs are expected. ``in_filters`` and
    ``out_filters`` only configure the fusion stages and must agree with the
    CARB2 outputs (64/128/256/512):
    768 = 512 + 256, 384 = 256 + 128, 192 = 128 + 64.

    Args:
        nclass: number of output classes.
        in_filters: input channel counts of up_concat1..3; ``None`` means
            ``[192, 384, 768]``.
        out_filters: output channel counts of up_concat1..3; ``None`` means
            ``[64, 128, 256]``.
    """

    def __init__(self, nclass=2, in_filters=None, out_filters=None):
        super(unetpamCARBDecoder, self).__init__()
        # Build defaults per call. A list literal in the signature would be
        # shared by every instance through self.in_filters / self.out_filters.
        if in_filters is None:
            in_filters = [192, 384, 768]
        if out_filters is None:
            out_filters = [64, 128, 256]

        self.nclas = nclass
        self.finnal_channel = 512  # not used within this class

        self.in_filters = in_filters
        self.out_filters = out_filters
        # Position attention on x32: 2048 -> 512 channels.
        # (unetCARBDecoder uses a plain 3x3 conv-BN-ReLU here instead.)
        self.pam = PAM(in_channels=2048)
        self.carb32 = CARB2(512, out_channels=512)
        self.carb16 = CARB2(in_channels=1024, out_channels=256)
        self.carb8 = CARB2(in_channels=512, out_channels=128)
        self.carb4 = CARB2(in_channels=256, out_channels=64)

        self.up_concat3 = up_fusionblock(in_channels=self.in_filters[2], out_channels=self.out_filters[2])
        self.up_concat2 = up_fusionblock(in_channels=self.in_filters[1], out_channels=self.out_filters[1])
        self.up_concat1 = up_fusionblock(in_channels=self.in_filters[0], out_channels=self.out_filters[0])

        self.classifer = nn.Conv2d(self.out_filters[0], self.nclas, kernel_size=1)

        self._init_weight()

    def _init_weight(self):
        """Re-initialise submodules: Kaiming-normal conv weights, unit weight / zero bias for norms.

        Conv biases keep PyTorch's default initialisation.
        """
        print("decoder从初始化执行")  # "decoder re-initialisation executed"
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputlist):
        """Decode multi-scale features into class logits.

        Args:
            inputlist: sequence ``(x4, x8, x16, x32)`` of feature maps, finest
                first, with 256/512/1024/2048 channels respectively.

        Returns:
            Logits with ``nclass`` channels at twice the spatial size of ``x4``.
        """
        x4, x8, x16, x32 = inputlist
        x32 = self.carb32(self.pam(x32))
        x4 = self.carb4(x4)
        x8 = self.carb8(x8)
        x16 = self.carb16(x16)
        x16 = self.up_concat3(x32, x16)
        x8 = self.up_concat2(x16, x8)
        x4 = self.up_concat1(x8, x4)
        x4 = self.classifer(x4)
        out = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        return out
|
||
|
||
|
||
def conv_bn_relu(in_channels, out_channels, kernel_size=1, stride=1, norm_layer=nn.BatchNorm2d):
    """Build a bias-free Conv2d -> ``norm_layer`` -> in-place ReLU stack.

    Padding is ``kernel_size // 2``, so for odd kernels with ``stride=1`` the
    spatial size is preserved.

    Args:
        in_channels: input channel count.
        out_channels: output channel count (also passed to ``norm_layer``).
        kernel_size: square kernel size of the convolution.
        stride: convolution stride.
        norm_layer: callable taking the channel count and returning a norm module.

    Returns:
        An ``nn.Sequential`` of the three layers.
    """
    pad = kernel_size // 2
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                     stride=stride, padding=pad, bias=False)
    layers = (conv, norm_layer(out_channels), nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
|
||
|
||
|
||
|
||
# ##AFS
|
||
# # 抄的Semantic Segmentation for Remote Sensing Images
|
||
# # Based on Adaptive Feature Selection Network#(letter的)能work才怪#########################
|
||
# class afs():
|
||
#
|
||
|
||
|
||
|
||
|
||
class CARB2(nn.Module):
    """3x3 conv-BN-ReLU projection followed by avg+max pooled channel attention.

    Global average pooling and global max pooling of the projected features
    are summed. The sum then goes through a shared 1x1 conv bottleneck
    (C -> C // ratio -> C) with a ReLU in between and a sigmoid at the end.
    The resulting per-channel gate multiplies the projected features.

    Args:
        in_channels: input channel count.
        out_channels: channel count of the projection and of the output.
        ratio: bottleneck reduction ratio of the gate.
    """

    def __init__(self, in_channels, out_channels, ratio=8):
        super().__init__()
        hidden = out_channels // ratio
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.global_avgpooling = nn.AdaptiveAvgPool2d(1)
        self.global_maxpooling = nn.AdaptiveMaxPool2d(1)
        self.conv1 = nn.Conv2d(out_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        feat = self.cbr(x)
        pooled = self.global_avgpooling(feat) + self.global_maxpooling(feat)
        weights = pooled
        for layer in (self.conv1, self.relu, self.conv2, self.sigmod):
            weights = layer(weights)
        return weights * feat
|
||
|