#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File        : ceshi
@Author      : qiqq
@create_time : 2023/6/5 22:29
"""
import torch
import torch.nn as nn

# from torchvision.models.resnet import resnet18
from compared_experiment.pspnet.model.resnetceshi import resnet18
from torch.nn import BatchNorm2d
from torch.nn import Module, Conv2d, Parameter


def conv3otherRelu(in_planes, out_planes, kernel_size=None, stride=None, padding=None):
    # 3x3 convolution with padding and ReLU
    if kernel_size is None:
        kernel_size = 3
    assert isinstance(kernel_size, (int, tuple)), 'kernel_size is not in (int, tuple)!'

    if stride is None:
        stride = 1
    assert isinstance(stride, (int, tuple)), 'stride is not in (int, tuple)!'

    if padding is None:
        padding = 1
    assert isinstance(padding, (int, tuple)), 'padding is not in (int, tuple)!'

    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        nn.ReLU(inplace=True)
    )


def l2_norm(x):
    return torch.einsum("bcn, bn->bcn", x, 1 / torch.norm(x, p=2, dim=-2))


class Attention(Module):
    def __init__(self, in_places, scale=8, eps=1e-6):
        super(Attention, self).__init__()
        self.gamma = Parameter(torch.zeros(1))
        self.in_places = in_places
        self.l2_norm = l2_norm
        self.eps = eps

        self.query_conv = Conv2d(in_channels=in_places, out_channels=in_places // scale, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_places, out_channels=in_places // scale, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_places, out_channels=in_places, kernel_size=1)

    def forward(self, x):
        # Project the feature map into queries, keys and values
        # (e.g. for a 128x16x16 input the flattened projections are 16x256 / 16x256 / 128x256).
        batch_size, channels, width, height = x.shape
        Q = self.query_conv(x).view(batch_size, -1, width * height)
        K = self.key_conv(x).view(batch_size, -1, width * height)
        V = self.value_conv(x).view(batch_size, -1, width * height)

        Q = self.l2_norm(Q).permute(-3, -1, -2)
        K = self.l2_norm(K)

        tailor_sum = 1 / (width * height + torch.einsum("bnc, bc->bn", Q, torch.sum(K, dim=-1) + self.eps))
        value_sum = torch.einsum("bcn->bc", V).unsqueeze(-1)
        value_sum = value_sum.expand(-1, channels, width * height)

        matrix = torch.einsum('bmn, bcn->bmc', K, V)
        matrix_sum = value_sum + torch.einsum("bnm, bmc->bcn", Q, matrix)

        weight_value = torch.einsum("bcn, bn->bcn", matrix_sum, tailor_sum)
        weight_value = weight_value.view(batch_size, channels, height, width)

        return (self.gamma * weight_value).contiguous()


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)


class UpSample(nn.Module):
    def __init__(self, n_chan, factor=2):
        super(UpSample, self).__init__()
        out_chan = n_chan * factor * factor
        self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
        self.up = nn.PixelShuffle(factor)
        self.init_weight()

    def forward(self, x):
        feat = self.proj(x)
        feat = self.up(feat)
        return feat

    def init_weight(self):
        nn.init.xavier_normal_(self.proj.weight, gain=1.)
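

# A minimal sketch of how the Attention block above behaves on its own; the
# 2x128x16x16 input and the standalone instantiation are illustrative
# assumptions, not something this file requires:
#
#     atten = Attention(128)
#     y = atten(torch.randn(2, 128, 16, 16))
#     assert y.shape == (2, 128, 16, 16)  # channels and spatial size are preserved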


class Output(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs):
        super(Output, self).__init__()
        self.up_factor = up_factor
        out_chan = n_classes * up_factor * up_factor
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True)
        self.up = nn.PixelShuffle(up_factor)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        x = self.up(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class AttentionEnhancementModule(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(AttentionEnhancementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = Attention(out_chan)
        self.bn_atten = BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        att = self.conv_atten(feat)
        return self.bn_atten(att)

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)


class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = resnet18()
        self.arm16 = AttentionEnhancementModule(256, 128)
        self.arm32 = AttentionEnhancementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
        self.up32 = nn.Upsample(scale_factor=2.)
        self.up16 = nn.Upsample(scale_factor=2.)
        self.init_weight()

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)

        avg = torch.mean(feat32, dim=(2, 3), keepdim=True)
        avg = self.conv_avg(avg)

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg
        feat32_up = self.up32(feat32_sum)
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = self.up16(feat16_sum)
        feat16_up = self.conv_head16(feat16_up)

        # 1/8-scale context feature, plus the upsampled 1/16 and 1/32 backbone
        # features used by the auxiliary heads.
        return feat16_up, self.up16(feat16), self.up32(feat32)

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureAggregationModule(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(FeatureAggregationModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv_atten = Attention(out_chan)
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = self.conv_atten(feat)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class ABCNet(nn.Module):
    def __init__(self, band, n_classes):
        super(ABCNet, self).__init__()
        self.name = 'ABCNet'
        self.cp = ContextPath()
        self.sp = SpatialPath()
        self.fam = FeatureAggregationModule(256, 256)
        self.conv_out = Output(256, 256, n_classes, up_factor=8)
        if self.training:
            self.conv_out16 = Output(256, 64, n_classes, up_factor=8)
            self.conv_out32 = Output(512, 64, n_classes, up_factor=16)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_cp8, feat_cp16, feat_cp32 = self.cp(x)
        feat_sp = self.sp(x)
        feat_fuse = self.fam(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        if self.training:
            feat_out16 = self.conv_out16(feat_cp16)
            feat_out32 = self.conv_out32(feat_cp32)
            return feat_out, feat_out16, feat_out32
        # feat_out = feat_out.argmax(dim=1)
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureAggregationModule, Output)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


if __name__ == "__main__":
    net = ABCNet(3, 19)
    in_ten = torch.randn(4, 3, 512, 512)
    # The freshly built network is in training mode, so forward() returns the
    # main output plus the two auxiliary outputs.
    out, out16, out32 = net(in_ten)
    print(out.shape, out16.shape, out32.shape)
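
    # A minimal sketch of how the four parameter groups from get_params() might
    # be fed to an optimizer; the lr / weight_decay / momentum values below are
    # illustrative assumptions, not settings prescribed by this file.
    wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = net.get_params()
    optimizer = torch.optim.SGD(
        [
            {'params': wd_params, 'weight_decay': 5e-4},                           # conv/linear weights: decayed
            {'params': nowd_params, 'weight_decay': 0.0},                          # biases and BN params: no decay
            {'params': lr_mul_wd_params, 'weight_decay': 5e-4, 'lr': 1e-2 * 10},   # FAM/Output weights: larger lr
            {'params': lr_mul_nowd_params, 'weight_decay': 0.0, 'lr': 1e-2 * 10},  # FAM/Output biases/BN: larger lr
        ],
        lr=1e-2, momentum=0.9)
    print(sum(len(g['params']) for g in optimizer.param_groups), 'parameter tensors registered')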