# File: Tan_pytorch_segmentation/pytorch_segmentation/PV_Attention/Simplified-Self-Attention.py

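import torch

# Assumes the repository's `attention` package is importable (e.g. run from the repo root).
from attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention

# Dummy input: batch of 50 sequences, 49 tokens each (e.g. a flattened 7x7 feature map), 512 channels.
x = torch.randn(50, 49, 512)  # renamed from `input` to avoid shadowing the built-in
# d_model must equal the last input dimension and be divisible by the number of heads h.
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
# Self-attention: the same tensor is passed as queries, keys and values.
output = ssa(x, x, x)
print(output.shape)  # expected: torch.Size([50, 49, 512]), same shape as the input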
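The script above only exercises the module's interface. For readers without the repository at hand, the following is a minimal sketch of what a simplified scaled dot-product attention could look like. It assumes (this file does not confirm it) that "simplified" means the learned query/key/value projections of standard multi-head attention are dropped: the inputs are split into `h` heads directly and only an output projection `fc_o` is kept. `SimplifiedSDPASketch` is a hypothetical name and is not the repository's class.

```python
import math

import torch
from torch import nn


class SimplifiedSDPASketch(nn.Module):
    """Hypothetical stand-in for SimplifiedScaledDotProductAttention.

    Assumption: no Q/K/V projections; heads are formed by splitting the
    input channels, and only an output linear layer is learned.
    """

    def __init__(self, d_model: int, h: int, dropout: float = 0.1):
        super().__init__()
        if d_model % h != 0:
            raise ValueError("d_model must be divisible by h")
        self.h = h
        self.d_k = d_model // h
        self.fc_o = nn.Linear(d_model, d_model)  # the only learned projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        b, nq, _ = queries.shape
        nk = keys.shape[1]
        # Split channels into h heads: (b, n, d_model) -> (b, h, n, d_k)
        q = queries.reshape(b, nq, self.h, self.d_k).transpose(1, 2)
        k = keys.reshape(b, nk, self.h, self.d_k).transpose(1, 2)
        v = values.reshape(b, nk, self.h, self.d_k).transpose(1, 2)
        # Scaled dot-product scores: (b, h, nq, nk)
        att = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        att = self.dropout(torch.softmax(att, dim=-1))
        # Weighted sum of values, then merge heads back: (b, nq, d_model)
        out = (att @ v).transpose(1, 2).reshape(b, nq, self.h * self.d_k)
        return self.fc_o(out)


if __name__ == "__main__":
    x = torch.randn(50, 49, 512)
    ssa = SimplifiedSDPASketch(d_model=512, h=8)
    out = ssa(x, x, x)
    print(out.shape)  # torch.Size([50, 49, 512])
```

Because heads are formed by slicing the channels rather than by learned projections, `d_model` must be divisible by `h`, and the only trainable parameters in this sketch are those of `fc_o` (d_model² + d_model weights and biases). Queries and keys/values may have different sequence lengths, so the same module also works for cross-attention; the output always has the query's sequence length.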