# Tan_pytorch_segmentation/pytorch_segmentation/Plug-and-Play/ParNetAttention.py
#
# 42 lines
# 1.0 KiB
# Python

import numpy as np
import torch
from torch import nn
from torch.nn import init
class ParNetAttention(nn.Module):
    """Three-branch attention block that fuses its branches by summation.

    The input is fed in parallel through:

    * a 1x1 convolution followed by batch normalisation,
    * a 3x3 convolution (padding 1) followed by batch normalisation,
    * a channel gate (global average pool -> 1x1 convolution -> sigmoid)
      whose per-channel weights are multiplied with the input.

    The three results are added and passed through SiLU. Every branch keeps
    the channel count and spatial size, so the output shape equals the input
    shape.

    Args:
        channel: Number of input (and output) channels.
    """

    def __init__(self, channel: int = 512) -> None:
        super().__init__()
        # Channel gate: squeeze the spatial dims to 1x1, then produce one
        # weight in (0, 1) per channel.
        self.sse = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.Sigmoid(),
        )
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.BatchNorm2d(channel),
        )
        # padding=1 keeps the spatial size unchanged for the 3x3 kernel.
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(channel),
        )
        self.silu = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three branches to ``x`` and fuse them.

        Args:
            x: Tensor of shape ``(N, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (N, C, H, W) input, got {x.dim()}-D"
            )
        x1 = self.conv1x1(x)
        x2 = self.conv3x3(x)
        # The (N, C, 1, 1) gate broadcasts over the spatial dimensions.
        x3 = self.sse(x) * x
        return self.silu(x1 + x2 + x3)
# Input: (N, C, H, W); output: (N, C, H, W)
if __name__ == '__main__':
    # Smoke test: push a random batch through the block and print the
    # resulting shape.
    sample = torch.randn(3, 512, 7, 7)  # named `sample` to avoid shadowing builtin `input`
    pna = ParNetAttention(channel=512)
    output = pna(sample)
    print(output.shape)