18 lines
924 B
Python
18 lines
924 B
Python
|
"""
|
|||
|
ResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28
|
|||
|
|
|||
|
论文地址:https://arxiv.org/abs/2105.13677
|
|||
|
这是南大5月28日在arXiv上上传的一篇文章。本文解决的主要是SA的两个痛点问题:
|
|||
|
(1)Self-Attention的计算复杂度和n(n为空间维度的大小)呈平方关系;
|
|||
|
(2)每个head只有q,k,v的部分信息,如果q,k,v的维度太小,那么就会导致获取不到连续的信息,
|
|||
|
从而导致性能损失。这篇文章给出的思路也非常简单,在SA中,在FC之前,用了一个卷积来降低了空间的维度,从而得到空间维度上更小的K和V。
|
|||
|
"""
|
|||
|
from attention.EMSA import EMSA
|
|||
|
import torch
|
|||
|
from torch import nn
|
|||
|
from torch.nn import functional as F
|
|||
|
|
|||
|
input = torch.randn(50, 64, 512)
|
|||
|
emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8, H=8, W=8, ratio=2, apply_transform=True)
|
|||
|
output = emsa(input, input, input)
|
|||
|
print(output.shape)
|