Tan_pytorch_segmentation/pytorch_segmentation/PV_Model/增强头.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import timm
class ConvBNReLU(nn.Sequential):
    """Conv2d -> norm -> ReLU6; the padding rule keeps the spatial size for stride=1."""
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 norm_layer=nn.BatchNorm2d, bias=False):
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels),
            nn.ReLU6()
        )


class ConvBN(nn.Sequential):
    """Conv2d -> norm, without an activation."""
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1,
                 norm_layer=nn.BatchNorm2d, bias=False):
        super(ConvBN, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels)
        )


class Conv(nn.Sequential):
    """A plain Conv2d with the same "size-preserving" padding rule as above."""
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        super(Conv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
        )


class SeparableConvBN(nn.Sequential):
    """Depthwise conv -> norm -> pointwise (1x1) conv.

    Note: the norm layer is applied to the depthwise output, so this block assumes
    in_channels == out_channels (which holds for its use in FeatureRefinementHead).
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBN, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            norm_layer(out_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )
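

# Optional sanity check (a minimal sketch): verifies that the padding formula
# ((stride - 1) + dilation * (kernel_size - 1)) // 2 used above keeps the spatial
# size unchanged for stride=1 convolutions with odd kernels. The helper name
# _check_same_padding is illustrative only and is not called anywhere in this module.
def _check_same_padding():
    t = torch.randn(1, 8, 32, 32)
    for k, d in [(3, 1), (3, 2), (5, 1)]:
        out = Conv(8, 8, kernel_size=k, dilation=d)(t)
        assert out.shape[-2:] == t.shape[-2:], (k, d, out.shape)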
class FeatureRefinementHead(nn.Module):
    """Fuses an upsampled decoder feature with a skip feature, then refines the result
    with a pixel (spatial) attention branch and a channel attention branch.

    Note: post_conv takes in_channels as its input width, so the module assumes
    in_channels == decoder_channels (as in the demo below).
    """
    def __init__(self, in_channels=64, decoder_channels=64):
        super(FeatureRefinementHead, self).__init__()
        self.pre_conv = Conv(in_channels, decoder_channels, kernel_size=1)
        # Learnable fusion weights for the two inputs (normalized in forward)
        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = 1e-8
        self.post_conv = ConvBNReLU(in_channels, decoder_channels, kernel_size=3)
        # Pixel attention: depthwise 3x3 conv + sigmoid produces a spatial gate
        self.pa = nn.Sequential(
            nn.Conv2d(decoder_channels, decoder_channels, kernel_size=3, padding=1, groups=decoder_channels),
            nn.Sigmoid())
        # Channel attention: squeeze-and-excitation style gate with reduction 16
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Conv(decoder_channels, decoder_channels // 16, kernel_size=1),
            nn.ReLU6(),
            Conv(decoder_channels // 16, decoder_channels, kernel_size=1),
            nn.Sigmoid())
        self.shortcut = ConvBN(decoder_channels, decoder_channels, kernel_size=1)
        self.proj = SeparableConvBN(decoder_channels, decoder_channels, kernel_size=3)
        self.act = nn.ReLU6()

    def forward(self, x, res):
        # Upsample the decoder feature to the resolution of the skip feature
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        # Non-negative, normalized fusion weights
        weights = F.relu(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)
        shortcut = self.shortcut(x)
        # Apply the two attention gates to the same feature and sum the results
        pa = self.pa(x) * x
        ca = self.ca(x) * x
        x = pa + ca
        x = self.proj(x) + shortcut
        x = self.act(x)
        return x

if __name__ == "__main__":
    # Create an input tensor: 64 channels, spatial size (32, 32)
    x = torch.randn(1, 64, 32, 32)
    res = torch.randn(1, 64, 64, 64)  # skip feature map at the upsampled resolution
    # Instantiate the FeatureRefinementHead module
    frh = FeatureRefinementHead()
    # Forward pass
    output = frh(x, res)
    # Print the output shape
    print(f"Output shape: {output.shape}")
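    # With the default FeatureRefinementHead(64, 64) and the tensors above, the expected
    # output is torch.Size([1, 64, 64, 64]): x is upsampled from 32x32 to 64x64 and the
    # channel count stays at decoder_channels = 64.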