# Tan_pytorch_segmentation/pytorch_segmentation/PV_Model/增强头.py
# ("增强头" means "enhancement head"). Python, 91 lines, 4.0 KiB.
# Timestamp: 2025-05-19 20:48:24 +08:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import timm
class ConvBNReLU(nn.Sequential):
    """Conv2d followed by a normalisation layer and a ReLU6 activation.

    Padding is ``((stride - 1) + dilation * (kernel_size - 1)) // 2``; with
    ``stride=1`` and an odd ``kernel_size`` the output keeps the input's height
    and width.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: (int) convolution kernel size.
        dilation: convolution dilation.
        stride: convolution stride.
        norm_layer: constructor called with ``out_channels`` to build the norm.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
        pad = (dilation * (kernel_size - 1) + stride - 1) // 2
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            bias=bias,
        )
        # Positional children keep the state_dict keys "0", "1", "2".
        super().__init__(conv, norm_layer(out_channels), nn.ReLU6())
class ConvBN(nn.Sequential):
    """Conv2d followed by a normalisation layer (no activation).

    Padding is ``((stride - 1) + dilation * (kernel_size - 1)) // 2``; with
    ``stride=1`` and an odd ``kernel_size`` the spatial size is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: (int) convolution kernel size.
        dilation: convolution dilation.
        stride: convolution stride.
        norm_layer: constructor called with ``out_channels`` to build the norm.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
        padding = (stride - 1 + dilation * (kernel_size - 1)) // 2
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, dilation=dilation, bias=bias),
            norm_layer(out_channels),
        ]
        super().__init__(*layers)
class Conv(nn.Sequential):
    """A single Conv2d wrapped in ``nn.Sequential`` (state_dict keys under ``0.``).

    Padding is ``((stride - 1) + dilation * (kernel_size - 1)) // 2``; with
    ``stride=1`` and an odd ``kernel_size`` the spatial size is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: (int) convolution kernel size.
        dilation: convolution dilation.
        stride: convolution stride.
        bias: whether the convolution has a bias term (off by default).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        same_pad = (dilation * (kernel_size - 1) + (stride - 1)) // 2
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=same_pad, dilation=dilation, bias=bias)
        super().__init__(conv)
class SeparableConvBN(nn.Sequential):
    """Depthwise-separable convolution: depthwise Conv2d -> norm -> pointwise 1x1 Conv2d.

    Args:
        in_channels: channels of the input; also the width of the depthwise stage.
        out_channels: channels produced by the final pointwise 1x1 convolution.
        kernel_size: (int) kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        dilation: dilation of the depthwise convolution. Padding is
            ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.
        norm_layer: constructor for the normalisation applied between the two convs.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBN, self).__init__(
            # Depthwise conv: one filter per input channel (groups=in_channels).
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            # The norm sits between the depthwise and pointwise convs, so it sees
            # in_channels channels. Sizing it with out_channels raises a shape
            # mismatch whenever in_channels != out_channels.
            norm_layer(in_channels),
            # Pointwise 1x1 conv mixes channels: in_channels -> out_channels.
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )
class FeatureRefinementHead(nn.Module):
    """Fuse a decoder feature with a skip feature and refine the result.

    The decoder tensor ``x`` is bilinearly resized to the spatial size of the
    skip tensor ``res`` and blended with a 1x1 projection of ``res`` using two
    learnable weights (ReLU-clamped, then normalised to sum to ~1). The blend is
    refined by a 3x3 ConvBNReLU, then by the sum of a pixel-wise gating branch
    (``pa``) and a channel-wise gating branch (``ca``), a separable projection and
    a 1x1 ConvBN shortcut, followed by ReLU6.

    Args:
        in_channels: number of channels of the skip tensor ``res``.
        decoder_channels: number of channels of the decoder tensor ``x`` and of
            the output.
    """

    def __init__(self, in_channels=64, decoder_channels=64):
        super(FeatureRefinementHead, self).__init__()
        # 1x1 projection of the skip feature onto the decoder width.
        self.pre_conv = Conv(in_channels, decoder_channels, kernel_size=1)
        # Learnable fusion weights for [skip, decoder]; normalised in forward().
        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = 1e-8
        # The fused tensor is a weighted sum of pre_conv(res) and x, so it already
        # has decoder_channels channels; using in_channels here broke any
        # configuration with in_channels != decoder_channels.
        self.post_conv = ConvBNReLU(decoder_channels, decoder_channels, kernel_size=3)
        # Pixel-wise gate: depthwise 3x3 conv + sigmoid, multiplied onto x in forward().
        self.pa = nn.Sequential(
            nn.Conv2d(decoder_channels, decoder_channels, kernel_size=3, padding=1,
                      groups=decoder_channels),
            nn.Sigmoid(),
        )
        # Channel-wise gate (squeeze-and-excitation style, reduction 16). The
        # bottleneck is clamped to >= 1 channel so decoder_channels < 16 does not
        # build a zero-channel convolution.
        reduced_channels = max(decoder_channels // 16, 1)
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Conv(decoder_channels, reduced_channels, kernel_size=1),
            nn.ReLU6(),
            Conv(reduced_channels, decoder_channels, kernel_size=1),
            nn.Sigmoid(),
        )
        self.shortcut = ConvBN(decoder_channels, decoder_channels, kernel_size=1)
        self.proj = SeparableConvBN(decoder_channels, decoder_channels, kernel_size=3)
        self.act = nn.ReLU6()

    def forward(self, x, res):
        """Fuse ``x`` (decoder_channels channels) with ``res`` (in_channels channels).

        Returns:
            A tensor with ``decoder_channels`` channels at the spatial size of ``res``.
        """
        # Resize to res's exact height/width. When res is exactly twice the size
        # of x this is identical to scale_factor=2, and it also handles odd sizes.
        x = F.interpolate(x, size=res.shape[-2:], mode='bilinear', align_corners=False)
        weights = F.relu(self.weights)
        # eps guards against division by zero if both weights are clamped to 0.
        fuse_weights = weights / (weights.sum() + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)
        shortcut = self.shortcut(x)
        pa = self.pa(x) * x
        ca = self.ca(x) * x
        x = pa + ca
        x = self.proj(x) + shortcut
        return self.act(x)
# Demo / smoke test. Guarded so that importing this module does not build a
# model, run a forward pass and print.
if __name__ == "__main__":
    # Decoder input: 64 channels at 32x32.
    x = torch.randn(1, 64, 32, 32)
    # Skip (reference) feature at twice the resolution; x is upsampled to match it.
    res = torch.randn(1, 64, 64, 64)
    # Initialise the FeatureRefinementHead module with its default widths (64/64).
    frh = FeatureRefinementHead()
    # Inference mode: BatchNorm uses running statistics and no autograd graph is built.
    frh.eval()
    with torch.no_grad():
        output = frh(x, res)
    # Print the output shape; expected torch.Size([1, 64, 64, 64]).
    print(f"Output shape: {output.shape}")