Tan_pytorch_segmentation/pytorch_segmentation/PV_Attention/Efficient Multi-Head Self-A...

"""
ResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28
论文地址https://arxiv.org/abs/2105.13677
这是南大5月28日在arXiv上上传的一篇文章。本文解决的主要是SA的两个痛点问题
1Self-Attention的计算复杂度和nn为空间维度的大小呈平方关系
2每个head只有q,k,v的部分信息如果q,k,v的维度太小那么就会导致获取不到连续的信息
从而导致性能损失。这篇文章给出的思路也非常简单在SA中在FC之前用了一个卷积来降低了空间的维度从而得到空间维度上更小的K和V。
"""
import torch
from torch import nn
from torch.nn import functional as F

from attention.EMSA import EMSA

# Toy input: batch of 50 sequences, 64 tokens each (an 8x8 feature map), embedding dim 512.
x = torch.randn(50, 64, 512)
# 8 attention heads over an 8x8 spatial grid; ratio=2 halves each spatial side of K and V.
emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8, H=8, W=8, ratio=2, apply_transform=True)
output = emsa(x, x, x)  # self-attention: queries, keys, and values all come from the same input
print(output.shape)  # expected: torch.Size([50, 64, 512])
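

# ---------------------------------------------------------------------------
# Minimal sketch of the spatial-reduction idea described in the docstring.
# This is NOT the repo's EMSA implementation (imported above); the class name
# EMSASketch, its arguments, and its internals are illustrative assumptions.
# The key point it demonstrates: a strided convolution shrinks each spatial
# side of the K/V feature map by `ratio`, so the attention map has shape
# (n x n/ratio^2) instead of (n x n).
# ---------------------------------------------------------------------------
class EMSASketch(nn.Module):
    def __init__(self, d_model, h, H, W, ratio):
        super().__init__()
        self.h, self.d_head = h, d_model // h
        self.H, self.W = H, W
        self.q_proj = nn.Linear(d_model, d_model)
        self.kv_proj = nn.Linear(d_model, 2 * d_model)
        # Strided conv reduces each spatial side of the K/V feature map by `ratio`.
        self.sr = nn.Conv2d(d_model, d_model, kernel_size=ratio, stride=ratio)
        self.norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        b, n, c = x.shape  # n == H * W
        q = self.q_proj(x).view(b, n, self.h, self.d_head).transpose(1, 2)  # (b, h, n, d_head)
        # Spatial reduction: (b, n, c) -> (b, c, H, W) -> strided conv -> (b, n', c), n' = n / ratio^2
        x_ = x.transpose(1, 2).reshape(b, c, self.H, self.W)
        x_ = self.sr(x_).reshape(b, c, -1).transpose(1, 2)
        x_ = self.norm(x_)
        k, v = self.kv_proj(x_).chunk(2, dim=-1)
        k = k.view(b, -1, self.h, self.d_head).transpose(1, 2)  # (b, h, n', d_head)
        v = v.view(b, -1, self.h, self.d_head).transpose(1, 2)  # (b, h, n', d_head)
        attn = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)  # (b, h, n, n')
        out = attn.softmax(dim=-1) @ v                           # (b, h, n, d_head)
        out = out.transpose(1, 2).reshape(b, n, c)
        return self.out_proj(out)


# Quick shape check of the sketch with the same toy tensor (purely illustrative).
sketch = EMSASketch(d_model=512, h=8, H=8, W=8, ratio=2)
print(sketch(x).shape)  # expected: torch.Size([50, 64, 512])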