#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : ceshi
@Author : qiqq
@create_time : 2023/6/5 22:29
"""
import torch
import torch.nn as nn
# from torchvision.models.resnet import resnet18
from compared_experiment.pspnet.model.resnetceshi import resnet18
from torch.nn import BatchNorm2d
from torch.nn import Module, Conv2d, Parameter
def conv3otherRelu(in_planes, out_planes, kernel_size=None, stride=None, padding=None):
    # 3x3 convolution (by default) with padding, followed by ReLU
    if kernel_size is None:
        kernel_size = 3
    assert isinstance(kernel_size, (int, tuple)), 'kernel_size must be an int or a tuple!'
    if stride is None:
        stride = 1
    assert isinstance(stride, (int, tuple)), 'stride must be an int or a tuple!'
    if padding is None:
        padding = 1
    assert isinstance(padding, (int, tuple)), 'padding must be an int or a tuple!'
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        nn.ReLU(inplace=True)
    )
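
# Normalize the channel vector at every spatial position to unit L2 norm.
# x has shape (B, C, N); the norm is taken over the channel dimension (dim=-2).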
def l2_norm(x):
return torch.einsum("bcn, bn->bcn", x, 1 / torch.norm(x, p=2, dim=-2))
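
# Linear attention block: with L2-normalized queries and keys, the softmax is
# replaced by a first-order Taylor approximation, so attention reduces to
# einsum contractions whose cost grows linearly in N = H*W rather than N^2
# (the ABCNet-style linear attention this file appears to follow).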
class Attention(Module):
def __init__(self, in_places, scale=8, eps=1e-6):
super(Attention, self).__init__()
self.gamma = Parameter(torch.zeros(1))
self.in_places = in_places
self.l2_norm = l2_norm
self.eps = eps
self.query_conv = Conv2d(in_channels=in_places, out_channels=in_places // scale, kernel_size=1)
self.key_conv = Conv2d(in_channels=in_places, out_channels=in_places // scale, kernel_size=1)
self.value_conv = Conv2d(in_channels=in_places, out_channels=in_places, kernel_size=1)
    def forward(self, x):
        # x: (B, C, H, W)
        batch_size, channels, height, width = x.shape
        Q = self.query_conv(x).view(batch_size, -1, width * height)  # (B, C/scale, N)
        K = self.key_conv(x).view(batch_size, -1, width * height)    # (B, C/scale, N)
        V = self.value_conv(x).view(batch_size, -1, width * height)  # (B, C, N)
        Q = self.l2_norm(Q).permute(-3, -1, -2)  # (B, N, C/scale)
        K = self.l2_norm(K)
        # Normalization term of the Taylor-expanded softmax
        tailor_sum = 1 / (width * height + torch.einsum("bnc, bc->bn", Q, torch.sum(K, dim=-1) + self.eps))
        value_sum = torch.einsum("bcn->bc", V).unsqueeze(-1)
        value_sum = value_sum.expand(-1, channels, width * height)
        matrix = torch.einsum('bmn, bcn->bmc', K, V)
        matrix_sum = value_sum + torch.einsum("bnm, bmc->bcn", Q, matrix)
        weight_value = torch.einsum("bcn, bn->bcn", matrix_sum, tailor_sum)
        weight_value = weight_value.view(batch_size, channels, height, width)
        return (self.gamma * weight_value).contiguous()
class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
out_chan,
kernel_size=ks,
stride=stride,
padding=padding,
bias=False)
self.bn = BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
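
# Sub-pixel upsampling: a 1x1 conv expands channels by factor**2, then
# nn.PixelShuffle rearranges those channels into a factor-times larger map.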
class UpSample(nn.Module):
def __init__(self, n_chan, factor=2):
super(UpSample, self).__init__()
out_chan = n_chan * factor * factor
self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
self.up = nn.PixelShuffle(factor)
self.init_weight()
def forward(self, x):
feat = self.proj(x)
feat = self.up(feat)
return feat
def init_weight(self):
nn.init.xavier_normal_(self.proj.weight, gain=1.)
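
# Segmentation head: 3x3 ConvBNReLU, then a 1x1 classifier that emits
# n_classes * up_factor**2 channels, and PixelShuffle to restore full resolution.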
class Output(nn.Module):
def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs):
super(Output, self).__init__()
self.up_factor = up_factor
out_chan = n_classes * up_factor * up_factor
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True)
self.up = nn.PixelShuffle(up_factor)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
x = self.up(x)
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
                if module.bias is not None:
nowd_params.append(module.bias)
elif isinstance(module, nn.modules.batchnorm._BatchNorm):
nowd_params += list(module.parameters())
return wd_params, nowd_params
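
# Feature refinement with the linear Attention block (conv -> attention -> BN),
# in the spirit of BiSeNet's attention refinement module.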
class AttentionEnhancementModule(nn.Module):
def __init__(self, in_chan, out_chan):
super(AttentionEnhancementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = Attention(out_chan)
self.bn_atten = BatchNorm2d(out_chan)
self.init_weight()
def forward(self, x):
feat = self.conv(x)
att = self.conv_atten(feat)
return self.bn_atten(att)
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
nn.init.constant_(ly.bias, 0)
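
# Context branch: ResNet-18 backbone; its 1/16 and 1/32 features are refined by
# attention-enhancement modules and combined with a global average-pooled context.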
class ContextPath(nn.Module):
def __init__(self, *args, **kwargs):
super(ContextPath, self).__init__()
self.resnet = resnet18()
self.arm16 = AttentionEnhancementModule(256, 128)
self.arm32 = AttentionEnhancementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
self.up32 = nn.Upsample(scale_factor=2.)
self.up16 = nn.Upsample(scale_factor=2.)
self.init_weight()
def forward(self, x):
feat8, feat16, feat32 = self.resnet(x)
avg = torch.mean(feat32, dim=(2, 3), keepdim=True)
avg = self.conv_avg(avg)
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg
feat32_up = self.up32(feat32_sum)
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = self.up16(feat16_sum)
feat16_up = self.conv_head16(feat16_up)
        return feat16_up, self.up16(feat16), self.up32(feat32)  # fused x8 feature, plus x8/x16 backbone features for the auxiliary heads
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
                if module.bias is not None:
nowd_params.append(module.bias)
elif isinstance(module, nn.modules.batchnorm._BatchNorm):
nowd_params += list(module.parameters())
return wd_params, nowd_params
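
# Spatial branch: three stride-2 convolutions preserve spatial detail,
# yielding a 128-channel feature map at 1/8 of the input resolution.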
class SpatialPath(nn.Module):
def __init__(self, *args, **kwargs):
super(SpatialPath, self).__init__()
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
self.init_weight()
def forward(self, x):
feat = self.conv1(x)
feat = self.conv2(feat)
feat = self.conv3(feat)
feat = self.conv_out(feat)
return feat
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
                if module.bias is not None:
nowd_params.append(module.bias)
elif isinstance(module, nn.modules.batchnorm._BatchNorm):
nowd_params += list(module.parameters())
return wd_params, nowd_params
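
# Fuses spatial- and context-branch features: concatenate, project with a 1x1
# ConvBNReLU, then add an attention-weighted residual (feat * atten + feat).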
class FeatureAggregationModule(nn.Module):
def __init__(self, in_chan, out_chan):
super(FeatureAggregationModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv_atten = Attention(out_chan)
self.init_weight()
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = self.conv_atten(feat)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
                if module.bias is not None:
nowd_params.append(module.bias)
elif isinstance(module, nn.modules.batchnorm._BatchNorm):
nowd_params += list(module.parameters())
return wd_params, nowd_params
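
# Complete two-branch network: SpatialPath + ContextPath fused by the feature
# aggregation module; a main segmentation head, plus two auxiliary heads on the
# context-branch features that are only evaluated during training.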
class ABCNet(nn.Module):
    def __init__(self, band, n_classes):
        super(ABCNet, self).__init__()
        self.name = 'ABCNet'
        self.cp = ContextPath()
        self.sp = SpatialPath()  # note: `band` is unused; SpatialPath assumes 3-channel input
        self.fam = FeatureAggregationModule(256, 256)
        self.conv_out = Output(256, 256, n_classes, up_factor=8)
        # nn.Module is always in training mode inside __init__, so the auxiliary
        # heads are built unconditionally; forward() uses them only when training.
        self.conv_out16 = Output(256, 64, n_classes, up_factor=8)
        self.conv_out32 = Output(512, 64, n_classes, up_factor=16)
        self.init_weight()
def forward(self, x):
H, W = x.size()[2:]
feat_cp8, feat_cp16, feat_cp32 = self.cp(x)
feat_sp = self.sp(x)
feat_fuse = self.fam(feat_sp, feat_cp8)
feat_out = self.conv_out(feat_fuse)
if self.training:
feat_out16 = self.conv_out16(feat_cp16)
feat_out32 = self.conv_out32(feat_cp32)
return feat_out, feat_out16, feat_out32
# feat_out = feat_out.argmax(dim=1)
return feat_out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
def get_params(self):
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
for name, child in self.named_children():
child_wd_params, child_nowd_params = child.get_params()
if isinstance(child, (FeatureAggregationModule, Output)):
lr_mul_wd_params += child_wd_params
lr_mul_nowd_params += child_nowd_params
else:
wd_params += child_wd_params
nowd_params += child_nowd_params
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
if __name__ == "__main__":
net = ABCNet(3, 19)
in_ten = torch.randn(4, 3, 512, 512)
out = net(in_ten)
print(out.shape)
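
    # A minimal inference-mode sketch (assumes the same 512x512 input):
    # in eval mode forward() returns only the main output.
    net.eval()
    with torch.no_grad():
        pred = net(in_ten)
    print(pred.shape)  # expected: torch.Size([4, 19, 512, 512])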