Tan_pytorch_segmentation/pytorch_segmentation/PV_Attention/Selective-Kernel(SK)-Attent...

19 lines
1.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Selective Kernel Networks---CVPR2019
论文地址https://arxiv.org/pdf/1903.06586.pdf
这是CVPR2019的一篇文章致敬了SENet的思想。在传统的CNN中每一个卷积层都是用相同大小的卷积核限制了模型的表达能力而Inception这种“更宽”的模型结构也验证了用多个不同的卷积核进行学习确实可以提升模型的表达能力。作者借鉴了SENet的思想通过动态计算每个卷积核得到通道的权重动态的将各个卷积核的结果进行融合。
个人认为之所以所这篇文章也能够称之为lightweight是因为对不同kernel的特征进行通道注意力的时候是参数共享的i.e. 因为在做Attention之前首先将特征进行了融合所以不同卷积核的结果共享一个SE模块的参数
本文的方法分为三个部分Split,Fuse,Select。Split就是一个multi-branch的操作用不同的卷积核进行卷积得到不同的特征Fuse部分就是用SE的结构获取通道注意力的矩阵(N个卷积核就可以得到N个注意力矩阵这步操作对所有的特征参数共享)这样就可以得到不同kernel经过SE之后的特征Select操作就是将这几个特征进行相加。
"""
from attention.SKAttention import SKAttention
import torch
input = torch.randn(50, 512, 7, 7)
se = SKAttention(channel=512, reduction=8)
output = se(input)
print(output.shape)