# LKA (Large Kernel Attention)
import torch
from torch import nn


class LKA(nn.Module):
    """Large Kernel Attention block.

    Computes an attention map from the input with three convolutions applied
    in sequence, then multiplies the input by that map element-wise:

    1. a 5x5 depthwise convolution,
    2. a 7x7 depthwise convolution with dilation 3,
    3. a 1x1 pointwise convolution.

    All convolutions use stride 1 and paddings that keep the spatial size
    unchanged, so the output has the same shape as the input.

    Args:
        dim: Number of input channels; the output has the same number.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Depthwise 5x5 conv (groups=dim: one filter per channel);
        # padding=2 keeps H and W unchanged.
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # Depthwise dilated 7x7 conv: with dilation 3 the kernel spans
        # (7 - 1) * 3 + 1 = 19 pixels per axis; padding = 3 * (7 - 1) // 2 = 9
        # keeps H and W unchanged.
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        # Pointwise 1x1 conv: mixes information across channels.
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied element-wise by an attention map built from it.

        Args:
            x: Input tensor, expected as (N, dim, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        # No defensive ``x.clone()`` is needed: nothing above modifies ``x``
        # in place, so using ``x`` directly gives the same values and
        # gradients without allocating a full copy of the input.
        return x * attn


# Smoke test: input is (N, C, H, W) and the output has the same (N, C, H, W) shape.
if __name__ == '__main__':
    block = LKA(64)
    # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
    x = torch.rand(1, 64, 64, 64)
    # Shape check only: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        output = block(x)
    print(x.size(), output.size())