38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class LSKblock(nn.Module):
    """Spatial attention block that gates its input with a two-branch mix.

    Two depth-wise convolutions are applied in sequence (a 5x5, then a
    dilated 7x7), each branch is projected to ``dim // 2`` channels, and a
    pair of per-pixel sigmoid gates decides how much of each branch is kept.
    The blended result is projected back to ``dim`` channels and multiplied
    element-wise with the input.

    NOTE(review): "LSK" presumably stands for "Large Selective Kernel";
    confirm against the originating project.

    Args:
        dim: Number of channels of the input (and output) feature map.
    """

    def __init__(self, dim):
        super().__init__()
        half = dim // 2

        # Depth-wise 5x5 convolution; padding=2 keeps H and W unchanged.
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # Depth-wise 7x7 convolution with dilation 3 (spans 19 pixels);
        # padding=9 keeps H and W unchanged.
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
        )
        # Point-wise projections reducing each branch to dim // 2 channels.
        self.conv1 = nn.Conv2d(dim, half, 1)
        self.conv2 = nn.Conv2d(dim, half, 1)
        # Turns the (mean, max) channel statistics into two gate maps.
        self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
        # Point-wise projection restoring the original channel count.
        self.conv = nn.Conv2d(half, dim, 1)

    def forward(self, x):
        """Return ``x`` scaled element-wise by the computed attention map.

        Args:
            x: Tensor laid out as (N, C, H, W) with ``C == dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        local_feat = self.conv0(x)
        context_feat = self.conv_spatial(local_feat)

        branch_a = self.conv1(local_feat)
        branch_b = self.conv2(context_feat)

        # Channel-wise mean and max over both branches, stacked as 2 channels.
        stacked = torch.cat((branch_a, branch_b), dim=1)
        channel_mean = stacked.mean(dim=1, keepdim=True)
        channel_max = stacked.max(dim=1, keepdim=True).values
        stats = torch.cat((channel_mean, channel_max), dim=1)

        # One sigmoid gate per branch; the 0:1 / 1:2 slices keep the channel
        # axis so the gates broadcast over the branch channels.
        gates = torch.sigmoid(self.conv_squeeze(stats))
        gate_a = gates[:, 0:1, :, :]
        gate_b = gates[:, 1:2, :, :]

        blended = branch_a * gate_a + branch_b * gate_b
        return x * self.conv(blended)
|
|
|
|
|
|
# Input shape: (N, C, H, W); output shape: (N, C, H, W)
|
|
if __name__ == '__main__':
    # Smoke test: run one random batch through the block and print the input
    # and output sizes, which are expected to match.
    # Use the GPU when one is available, otherwise fall back to the CPU so the
    # script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    block = LSKblock(64).to(device)
    # Named `sample` rather than `input` to avoid shadowing the builtin.
    sample = torch.rand(1, 64, 64, 64).to(device)
    result = block(sample)
    print(sample.size(), result.size())
|