#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File        : ceshi
@Author      : qiqq
@create_time : 2023/6/5 22:29
"""

import torch
import torch.nn as nn

# from torchvision.models.resnet import resnet18
from compared_experiment.pspnet.model.resnetceshi import resnet18

from torch.nn import BatchNorm2d
from torch.nn import Module, Conv2d, Parameter

def conv3otherRelu(in_planes, out_planes, kernel_size=None, stride=None, padding=None):
    # 3x3 convolution with padding and ReLU (defaults: kernel_size=3, stride=1, padding=1)
    if kernel_size is None:
        kernel_size = 3
    assert isinstance(kernel_size, (int, tuple)), 'kernel_size is not in (int, tuple)!'

    if stride is None:
        stride = 1
    assert isinstance(stride, (int, tuple)), 'stride is not in (int, tuple)!'

    if padding is None:
        padding = 1
    assert isinstance(padding, (int, tuple)), 'padding is not in (int, tuple)!'

    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        nn.ReLU(inplace=True)
    )

def l2_norm(x):
    # x is (B, C, N); torch.norm(x, p=2, dim=-2) is (B, N), so each spatial
    # position's channel vector is rescaled to unit L2 norm.
    return torch.einsum("bcn, bn->bcn", x, 1 / torch.norm(x, p=2, dim=-2))

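# A minimal sanity check for l2_norm; the helper name `_check_l2_norm` is
# illustrative only and is never called by the model.
def _check_l2_norm():
    x = torch.randn(2, 64, 10)  # (B, C, N)
    y = l2_norm(x)
    # every (b, :, n) column of the output should have unit L2 norm
    assert torch.allclose(torch.norm(y, p=2, dim=-2), torch.ones(2, 10), atol=1e-5)
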
class Attention(Module):
    """Linear (kernelized) attention: cost scales with the number of spatial
    positions N = H * W rather than N**2 as in dot-product attention."""

    def __init__(self, in_places, scale=8, eps=1e-6):
        super(Attention, self).__init__()
        self.gamma = Parameter(torch.zeros(1))
        self.in_places = in_places
        self.l2_norm = l2_norm
        self.eps = eps

        self.query_conv = Conv2d(in_channels=in_places, out_channels=in_places // scale, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_places, out_channels=in_places // scale, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_places, out_channels=in_places, kernel_size=1)

    def forward(self, x):
        # x: (B, C, H, W), e.g. a 128-channel 16x16 feature map
        batch_size, channels, height, width = x.shape
        # Flatten the spatial dims: Q and K are (B, C // scale, N), V is (B, C, N)
        Q = self.query_conv(x).view(batch_size, -1, width * height)
        K = self.key_conv(x).view(batch_size, -1, width * height)
        V = self.value_conv(x).view(batch_size, -1, width * height)

        Q = self.l2_norm(Q).permute(-3, -1, -2)  # (B, N, C // scale)
        K = self.l2_norm(K)

        # Normalizer of the linearized softmax ("tailor" is presumably a
        # misspelling of Taylor, from the first-order Taylor expansion)
        tailor_sum = 1 / (width * height + torch.einsum("bnc, bc->bn", Q, torch.sum(K, dim=-1) + self.eps))
        value_sum = torch.einsum("bcn->bc", V).unsqueeze(-1)
        value_sum = value_sum.expand(-1, channels, width * height)

        # (B, C // scale, C) key-value summary, then projected back through Q
        matrix = torch.einsum('bmn, bcn->bmc', K, V)
        matrix_sum = value_sum + torch.einsum("bnm, bmc->bcn", Q, matrix)

        weight_value = torch.einsum("bcn, bn->bcn", matrix_sum, tailor_sum)
        weight_value = weight_value.view(batch_size, channels, height, width)

        return (self.gamma * weight_value).contiguous()

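# A minimal usage sketch for the Attention block, assuming a 128-channel input;
# the helper name `_check_attention` is hypothetical and unused by the model.
def _check_attention():
    att = Attention(128)  # scale=8, so Q/K are projected to 16 channels
    y = att(torch.randn(2, 128, 16, 16))
    # the output keeps the input's (B, C, H, W) shape, and is all zeros at
    # initialization because gamma starts at 0
    assert y.shape == (2, 128, 16, 16)
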
class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                              out_chan,
                              kernel_size=ks,
                              stride=stride,
                              padding=padding,
                              bias=False)
        self.bn = BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

class UpSample(nn.Module):

    def __init__(self, n_chan, factor=2):
        super(UpSample, self).__init__()
        out_chan = n_chan * factor * factor
        self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
        self.up = nn.PixelShuffle(factor)
        self.init_weight()

    def forward(self, x):
        feat = self.proj(x)
        feat = self.up(feat)
        return feat

    def init_weight(self):
        nn.init.xavier_normal_(self.proj.weight, gain=1.)

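# Shape sketch for UpSample (the `_check_upsample` helper is hypothetical):
# the 1x1 conv expands channels by factor**2 and PixelShuffle rearranges those
# channels into space, so the channel count is preserved.
def _check_upsample():
    up = UpSample(128, factor=2)
    assert up(torch.randn(2, 128, 16, 16)).shape == (2, 128, 32, 32)
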
class Output(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs):
        super(Output, self).__init__()
        self.up_factor = up_factor
        out_chan = n_classes * up_factor * up_factor
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True)
        self.up = nn.PixelShuffle(up_factor)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        x = self.up(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params

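# Shape sketch for the Output head (the `_check_output_head` helper is
# hypothetical): the 1x1 conv emits n_classes * up_factor**2 channels, which
# PixelShuffle turns into an n_classes map at up_factor times the resolution.
def _check_output_head():
    head = Output(256, 256, 19, up_factor=8)
    assert head(torch.randn(2, 256, 64, 64)).shape == (2, 19, 512, 512)
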
class AttentionEnhancementModule(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(AttentionEnhancementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = Attention(out_chan)
        self.bn_atten = BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        att = self.conv_atten(feat)
        return self.bn_atten(att)

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = resnet18()
        self.arm16 = AttentionEnhancementModule(256, 128)
        self.arm32 = AttentionEnhancementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
        self.up32 = nn.Upsample(scale_factor=2.)
        self.up16 = nn.Upsample(scale_factor=2.)

        self.init_weight()

    def forward(self, x):
        # backbone features at strides 1/8, 1/16 and 1/32
        feat8, feat16, feat32 = self.resnet(x)

        # globally average-pooled context, projected to 128 channels
        avg = torch.mean(feat32, dim=(2, 3), keepdim=True)
        avg = self.conv_avg(avg)

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg
        feat32_up = self.up32(feat32_sum)
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = self.up16(feat16_sum)
        feat16_up = self.conv_head16(feat16_up)

        # (128 ch @ 1/8), (256 ch @ 1/8), (512 ch @ 1/16)
        return feat16_up, self.up16(feat16), self.up32(feat32)

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params

class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params

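# Shape sketch for SpatialPath (the `_check_spatial_path` helper is
# hypothetical): three stride-2 convolutions downsample by 8 overall, and
# conv_out maps to 128 channels.
def _check_spatial_path():
    sp = SpatialPath()
    assert sp(torch.randn(2, 3, 512, 512)).shape == (2, 128, 64, 64)
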
class FeatureAggregationModule(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(FeatureAggregationModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv_atten = Attention(out_chan)

        self.init_weight()

    def forward(self, fsp, fcp):
        # concatenate spatial-path and context-path features, then reweight the
        # fused features with linear attention (plus a residual connection)
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = self.conv_atten(feat)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params

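# Usage sketch for FeatureAggregationModule (the `_check_fam` helper is
# hypothetical), assuming 128-channel inputs from each path as in ABCNet below:
def _check_fam():
    fam = FeatureAggregationModule(256, 256)
    fsp = torch.randn(2, 128, 64, 64)  # spatial-path features
    fcp = torch.randn(2, 128, 64, 64)  # context-path features
    assert fam(fsp, fcp).shape == (2, 256, 64, 64)
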
class ABCNet(nn.Module):
    def __init__(self, band, n_classes):
        # `band` (number of input bands) is currently unused: SpatialPath is
        # hard-coded to 3-channel input.
        super(ABCNet, self).__init__()
        self.name = 'ABCNet'
        self.cp = ContextPath()
        self.sp = SpatialPath()
        self.fam = FeatureAggregationModule(256, 256)
        self.conv_out = Output(256, 256, n_classes, up_factor=8)
        # self.training is True at construction time, so the auxiliary heads
        # are always built; forward() only returns them in training mode
        if self.training:
            self.conv_out16 = Output(256, 64, n_classes, up_factor=8)
            self.conv_out32 = Output(512, 64, n_classes, up_factor=16)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_cp8, feat_cp16, feat_cp32 = self.cp(x)
        feat_sp = self.sp(x)
        feat_fuse = self.fam(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        if self.training:
            feat_out16 = self.conv_out16(feat_cp16)
            feat_out32 = self.conv_out32(feat_cp32)
            return feat_out, feat_out16, feat_out32

        # feat_out = feat_out.argmax(dim=1)
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureAggregationModule, Output)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params

if __name__ == "__main__":
    net = ABCNet(3, 19)
    net.eval()  # in training mode forward() returns a 3-tuple, so switch to eval

    in_ten = torch.randn(4, 3, 512, 512)
    with torch.no_grad():
        out = net(in_ten)
    print(out.shape)
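    # Training-mode interface sketch: forward() then also returns the two
    # auxiliary heads, each upsampled to the input resolution.
    net.train()
    with torch.no_grad():
        main_out, aux16, aux32 = net(in_ten)
    print(main_out.shape, aux16.shape, aux32.shape)  # each (4, 19, 512, 512)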