# 2025-05-19 20:48:24 +08:00
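import torch

from attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention

# Dummy input: a batch of 50 sequences, each with 49 tokens of dimension 512
x = torch.randn(50, 49, 512)
# d_model must match the last input dimension; h is the number of attention heads
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
# Self-attention: the same tensor is passed as queries, keys and values
output = ssa(x, x, x)
print(output.shape)  # expected layout: (batch, tokens, d_model)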
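
# Sketch (assumption): the source of attention.SimplifiedSelfAttention is not
# shown here, so the code below is a hypothetical, minimal re-implementation of
# what "simplified" multi-head scaled dot-product attention is assumed to mean:
# queries, keys and values are split into h heads WITHOUT their own linear
# projections, and only an output projection (fc_o) is learned. Dropout and
# masking are omitted. simplified_sdpa_sketch, fc_o and the demo tensors are
# illustrative names, not part of the imported module's API.
import math

import torch
import torch.nn as nn


def simplified_sdpa_sketch(queries, keys, values, h, fc_o):
    """Hypothetical helper: multi-head softmax(Q K^T / sqrt(d_k)) V with no q/k/v projections."""
    b_s, nq, d_model = queries.shape
    nk = keys.shape[1]
    assert d_model % h == 0, "d_model must be divisible by the number of heads"
    d_k = d_model // h
    # Split channels into heads: (b_s, n, d_model) -> (b_s, h, n, d_k)
    q = queries.reshape(b_s, nq, h, d_k).transpose(1, 2)
    k = keys.reshape(b_s, nk, h, d_k).transpose(1, 2)
    v = values.reshape(b_s, nk, h, d_k).transpose(1, 2)
    # Scaled dot-product scores, normalized over the key axis: (b_s, h, nq, nk)
    att = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
    # Weighted sum of values, then merge heads back: (b_s, nq, d_model)
    out = (att @ v).transpose(1, 2).reshape(b_s, nq, d_model)
    # The output projection is the only learned layer in this sketch
    return fc_o(out)


# Hypothetical usage with a smaller batch; fc_o is a plain d_model -> d_model linear layer
tokens = torch.randn(2, 49, 512)
sketch_out = simplified_sdpa_sketch(tokens, tokens, tokens, h=8, fc_o=nn.Linear(512, 512))
print(sketch_out.shape)  # torch.Size([2, 49, 512])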