42 lines
1.0 KiB
Python
42 lines
1.0 KiB
Python
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import init
|
|
|
|
|
|
class ParNetAttention(nn.Module):
    """Attention block that sums three parallel branches and applies SiLU.

    The branches, all mapping ``channel`` -> ``channel`` features, are:

    * ``conv1x1``: 1x1 convolution followed by batch normalization.
    * ``conv3x3``: 3x3 convolution (padding 1) followed by batch
      normalization.
    * ``sse``: global average pooling, 1x1 convolution and sigmoid, producing
      one gate per channel that rescales the input.

    Every convolution uses the default stride and the 3x3 branch pads by 1,
    so the output has the same shape as the input.

    Args:
        channel: Number of input (and output) channels.
    """

    def __init__(self, channel: int = 512) -> None:
        super().__init__()
        # Gating branch: pool each channel to 1x1, mix channels, squash to (0, 1).
        self.sse = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.Sigmoid()
        )

        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.BatchNorm2d(channel)
        )
        # padding=1 keeps the spatial size equal to that of the 1x1 branch.
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(channel)
        )
        self.silu = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three branches to ``x`` and fuse them.

        Args:
            x: Input tensor laid out as (N, C, H, W).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        # Explicit rank check; ValueError matches what the previous
        # tuple-unpacking of ``x.size()`` raised for non-4D inputs.
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4D (N, C, H, W) input, got {x.dim()}D"
            )

        pointwise = self.conv1x1(x)
        spatial = self.conv3x3(x)
        # The (N, C, 1, 1) gate broadcasts over the spatial dimensions.
        gated = self.sse(x) * x
        return self.silu(pointwise + spatial + gated)
|
|
|
|
|
|
# Input is N C H W; output is N C H W (same shape).
if __name__ == '__main__':
    # Smoke test: run a random batch through the block and print the output
    # shape. The tensor is not named ``input`` so the builtin is not shadowed.
    sample = torch.randn(3, 512, 7, 7)
    pna = ParNetAttention(channel=512)
    result = pna(sample)
    print(result.shape)
|