91 lines
4.0 KiB
Python
91 lines
4.0 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from einops import rearrange, repeat
|
||
|
||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||
import timm
|
||
class ConvBNReLU(nn.Sequential):
    """Sequential stack of Conv2d -> normalization -> ReLU6.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel.
        dilation: Dilation of the convolution.
        stride: Stride of the convolution.
        norm_layer: Callable building the normalization layer; it is called
            with ``out_channels`` as its only argument.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 norm_layer=nn.BatchNorm2d, bias=False):
        # Padding derived from the dilated kernel extent; with stride 1 and an
        # odd kernel this keeps the spatial size unchanged.
        pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            bias=bias,
        )
        super().__init__(conv, norm_layer(out_channels), nn.ReLU6())
|
||
|
||
|
||
class ConvBN(nn.Sequential):
    """Sequential stack of Conv2d -> normalization, with no activation.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel.
        dilation: Dilation of the convolution.
        stride: Stride of the convolution.
        norm_layer: Callable building the normalization layer; it is called
            with ``out_channels`` as its only argument.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 norm_layer=nn.BatchNorm2d, bias=False):
        # Half of the dilated kernel extent (plus a stride correction).
        padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=padding),
            norm_layer(out_channels),
        ]
        super().__init__(*layers)
|
||
|
||
|
||
class Conv(nn.Sequential):
    """A single Conv2d wrapped in a Sequential, with padding derived automatically.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel.
        dilation: Dilation of the convolution.
        stride: Stride of the convolution.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        # Same padding rule as the other conv helpers in this file.
        auto_pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        super().__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=auto_pad,
                dilation=dilation,
                bias=bias,
            )
        )
|
||
|
||
class SeparableConvBN(nn.Sequential):
    """Depthwise conv -> normalization -> 1x1 pointwise conv.

    The layer order is: a depthwise ``Conv2d`` (``groups=in_channels``), a
    normalization layer, then a bias-free 1x1 ``Conv2d`` that maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the pointwise conv.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        norm_layer: Callable building the normalization layer; it is called
            with the channel count of the tensor it normalizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        super().__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=padding, groups=in_channels, bias=False),
            # The norm sits between the depthwise and pointwise convs, so it
            # sees in_channels channels. Sizing it with out_channels (as
            # before) fails whenever in_channels != out_channels.
            norm_layer(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
        )
|
||
|
||
|
||
class FeatureRefinementHead(nn.Module):
    """Fuse an upsampled feature map with a skip feature map and refine the result.

    ``forward`` upsamples ``x`` by 2x, blends it with a 1x1-projected ``res``
    using two learnable non-negative weights, then applies a 3x3
    conv/norm/ReLU6, two sigmoid gates (``pa`` from a depthwise 3x3 conv,
    ``ca`` from globally pooled features), a separable conv with a 1x1
    shortcut, and a final ReLU6.

    Args:
        in_channels: Channel count of ``res`` (input of ``pre_conv``).
        decoder_channels: Channel count of ``x`` and of the output.
    """

    def __init__(self, in_channels=64, decoder_channels=64):
        super().__init__()
        # 1x1 projection of the skip feature to decoder_channels.
        self.pre_conv = Conv(in_channels, decoder_channels, kernel_size=1)

        # Two learnable fusion weights; normalized to sum to ~1 in forward().
        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = 1e-8

        # The fused tensor has decoder_channels channels (it is the output of
        # pre_conv blended with x), so post_conv must accept decoder_channels.
        # Using in_channels here fails when in_channels != decoder_channels.
        self.post_conv = ConvBNReLU(decoder_channels, decoder_channels, kernel_size=3)

        # Per-position gate: depthwise 3x3 conv followed by a sigmoid.
        self.pa = nn.Sequential(
            nn.Conv2d(decoder_channels, decoder_channels, kernel_size=3, padding=1,
                      groups=decoder_channels),
            nn.Sigmoid(),
        )
        # Per-channel gate: global average pool, bottleneck (//16), sigmoid.
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Conv(decoder_channels, decoder_channels // 16, kernel_size=1),
            nn.ReLU6(),
            Conv(decoder_channels // 16, decoder_channels, kernel_size=1),
            nn.Sigmoid(),
        )
        self.shortcut = ConvBN(decoder_channels, decoder_channels, kernel_size=1)
        self.proj = SeparableConvBN(decoder_channels, decoder_channels, kernel_size=3)
        self.act = nn.ReLU6()

    def forward(self, x, res):
        """Fuse ``x`` with ``res`` and return the refined feature map.

        Args:
            x: Feature map with ``decoder_channels`` channels; it is upsampled
                by a factor of 2 before fusion.
            res: Feature map with ``in_channels`` channels whose spatial size
                must match the upsampled ``x`` for the addition to work.

        Returns:
            Tensor with ``decoder_channels`` channels at the upsampled
            resolution.
        """
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

        # Clamp the fusion weights to be non-negative, then normalize them;
        # eps guards against division by zero when both weights are clamped to 0.
        weights = F.relu(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)

        shortcut = self.shortcut(x)
        # Both gates are computed from the same features and summed.
        pa = self.pa(x) * x
        ca = self.ca(x) * x
        x = pa + ca
        x = self.proj(x) + shortcut
        x = self.act(x)
        return x
|
||
|
||
if __name__ == "__main__":
    # Smoke test, guarded so that importing this module does not run a
    # forward pass or print anything.

    # Create an input tensor with 64 channels and spatial size (32, 32).
    x = torch.randn(1, 64, 32, 32)
    # Reference feature map at the upsampled resolution (64, 64).
    res = torch.randn(1, 64, 64, 64)

    # Initialize the FeatureRefinementHead module.
    frh = FeatureRefinementHead()

    # Forward pass.
    output = frh(x, res)

    # Print the output shape.
    print(f"Output shape: {output.shape}")