import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import timm


def _conv_padding(kernel_size, dilation=1, stride=1):
    """Return the symmetric padding shared by every conv wrapper in this file.

    For ``stride == 1`` and an odd ``kernel_size`` this keeps the spatial size
    unchanged: ``dilation * (kernel_size - 1) // 2``.
    """
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2


class ConvBNReLU(nn.Sequential):
    """Conv2d -> normalization -> ReLU6.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        dilation: Convolution dilation.
        stride: Convolution stride.
        norm_layer: Callable that builds the normalization layer from a
            channel count (``nn.BatchNorm2d`` by default).
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 norm_layer=nn.BatchNorm2d, bias=False):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride,
                      padding=_conv_padding(kernel_size, dilation, stride)),
            norm_layer(out_channels),
            nn.ReLU6(),
        )


class ConvBN(nn.Sequential):
    """Conv2d -> normalization (no activation).

    Arguments have the same meaning as in :class:`ConvBNReLU`.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 norm_layer=nn.BatchNorm2d, bias=False):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride,
                      padding=_conv_padding(kernel_size, dilation, stride)),
            norm_layer(out_channels),
        )


class Conv(nn.Sequential):
    """A single padded Conv2d (no normalization, no activation).

    Arguments have the same meaning as in :class:`ConvBNReLU`.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 bias=False):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride,
                      padding=_conv_padding(kernel_size, dilation, stride)),
        )


class SeparableConvBN(nn.Sequential):
    """Depthwise conv -> normalization -> pointwise (1x1) conv.

    The normalization sits between the depthwise and pointwise convolutions,
    so it operates on ``in_channels`` feature maps.

    Args:
        in_channels: Number of input channels (also the depthwise output width).
        out_channels: Number of output channels of the pointwise conv.
        kernel_size: Depthwise kernel size.
        stride: Depthwise stride.
        dilation: Depthwise dilation.
        norm_layer: Callable that builds the normalization layer from a
            channel count (``nn.BatchNorm2d`` by default).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super().__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=_conv_padding(kernel_size, dilation, stride),
                      groups=in_channels, bias=False),
            # Must match the depthwise conv's output width (in_channels); using
            # out_channels here breaks whenever in_channels != out_channels.
            norm_layer(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
        )


class FeatureRefinementHead(nn.Module):
    """Fuse an upsampled decoder feature with a skip feature, then refine it.

    The decoder feature ``x`` is bilinearly resized to the skip feature's
    resolution. It is blended with a 1x1 projection of ``res`` using two
    learnable weights, which are passed through ReLU and normalized by their
    sum. The fused map then goes through a 3x3 ConvBNReLU. The result is gated
    by a depthwise-conv sigmoid branch (``pa``) plus a squeeze-and-excitation
    style channel branch (``ca``). Finally it is projected by a separable conv,
    added to a 1x1 ConvBN shortcut and passed through ReLU6.

    Args:
        in_channels: Channels of the skip feature ``res``.
        decoder_channels: Channels of the decoder feature ``x`` and of the output.
    """

    def __init__(self, in_channels=64, decoder_channels=64):
        super().__init__()
        self.pre_conv = Conv(in_channels, decoder_channels, kernel_size=1)

        # Learnable fusion weights: index 0 scales the skip branch, index 1 the
        # upsampled decoder branch (see forward).
        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = 1e-8
        # The fused map already has decoder_channels channels (pre_conv output),
        # so the refinement conv must consume decoder_channels, not in_channels.
        self.post_conv = ConvBNReLU(decoder_channels, decoder_channels, kernel_size=3)

        self.pa = nn.Sequential(
            nn.Conv2d(decoder_channels, decoder_channels, kernel_size=3, padding=1,
                      groups=decoder_channels),
            nn.Sigmoid(),
        )
        # Bottleneck width for the channel gate; clamped to 1 so that
        # decoder_channels < 16 does not yield a degenerate zero-width layer.
        hidden = max(1, decoder_channels // 16)
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Conv(decoder_channels, hidden, kernel_size=1),
            nn.ReLU6(),
            Conv(hidden, decoder_channels, kernel_size=1),
            nn.Sigmoid(),
        )

        self.shortcut = ConvBN(decoder_channels, decoder_channels, kernel_size=1)
        self.proj = SeparableConvBN(decoder_channels, decoder_channels, kernel_size=3)
        self.act = nn.ReLU6()

    def forward(self, x, res):
        """Return the refined, fused feature map.

        Args:
            x: Decoder feature with ``decoder_channels`` channels, shape (N, C_dec, h, w).
            res: Skip feature with ``in_channels`` channels, shape (N, C_in, H, W).
                Usually H == 2h and W == 2w.

        Returns:
            Tensor of shape (N, decoder_channels, H, W).
        """
        # Resize to the skip feature's exact spatial size. For H == 2h, W == 2w
        # this matches scale_factor=2; it also tolerates odd skip sizes.
        x = F.interpolate(x, size=res.shape[-2:], mode='bilinear', align_corners=False)
        weights = F.relu(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)

        shortcut = self.shortcut(x)
        pa = self.pa(x) * x
        ca = self.ca(x) * x
        x = pa + ca
        x = self.proj(x) + shortcut
        return self.act(x)


if __name__ == "__main__":
    # Decoder input tensor: 64 channels at spatial size (32, 32).
    x = torch.randn(1, 64, 32, 32)
    # Skip (reference) feature map at the upsampled resolution.
    res = torch.randn(1, 64, 64, 64)

    # Initialize the FeatureRefinementHead module.
    frh = FeatureRefinementHead()
    # Forward pass.
    output = frh(x, res)
    # Print the output shape.
    print(f"Output shape: {output.shape}")