# Tan_pytorch_segmentation/pytorch_segmentation/PV_Attention/MUSE-Attention.py
# 2025-05-19 20:48:24 +08:00

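# Minimal usage demo of MUSEAttention from the local `attention` package.
import torch

from attention.MUSEAttention import MUSEAttention

# Dummy batch: 50 sequences of 49 tokens, each a 512-dim embedding.
x = torch.randn(50, 49, 512)

# d_model: embedding size; d_k / d_v: query-key / value sizes per head; h: heads.
sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)

# Self-attention: the same tensor serves as queries, keys and values.
output = sa(x, x, x)
print(output.shape)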
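
# --- Illustrative sketch (hypothetical; NOT the `attention` package's code) ---
# MUSE (parallel multi-scale attention) runs a global self-attention branch
# alongside local convolution branches and adds the results. The classes below
# are a simplified sketch under these assumptions: nn.MultiheadAttention stands
# in for the attention branch, depthwise-separable Conv1d layers with kernel
# sizes 1, 3 and 5 form the local branches, and the local outputs are mixed
# with softmax-normalised learnable weights. MuseSketch and
# DepthwiseSeparableConv1d are hypothetical names chosen for this example.
import torch
from torch import nn


class DepthwiseSeparableConv1d(nn.Module):
    """Depthwise Conv1d (one filter per channel) followed by a 1x1 pointwise Conv1d."""

    def __init__(self, channels: int, kernel_size: int):
        super().__init__()
        # padding = kernel_size // 2 keeps the sequence length unchanged for odd kernels
        self.depthwise = nn.Conv1d(channels, channels, kernel_size,
                                   padding=kernel_size // 2, groups=channels)
        self.pointwise = nn.Conv1d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, seq_len)
        return self.pointwise(self.depthwise(x))


class MuseSketch(nn.Module):
    def __init__(self, d_model: int = 512, h: int = 8, kernel_sizes=(1, 3, 5)):
        super().__init__()
        # Global branch: standard multi-head self-attention (assumed stand-in)
        self.attn = nn.MultiheadAttention(d_model, h, batch_first=True)
        # Local branches: one depthwise-separable conv per kernel size
        self.convs = nn.ModuleList(
            DepthwiseSeparableConv1d(d_model, k) for k in kernel_sizes
        )
        # One learnable logit per kernel size, turned into mixing weights by softmax
        self.branch_logits = nn.Parameter(torch.zeros(len(kernel_sizes)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        global_out, _ = self.attn(x, x, x)
        x_c = x.transpose(1, 2)  # (batch, d_model, seq_len) as Conv1d expects
        weights = torch.softmax(self.branch_logits, dim=0)
        local_out = sum(w * conv(x_c) for w, conv in zip(weights, self.convs))
        # Back to (batch, seq_len, d_model) and add the two branches
        return global_out + local_out.transpose(1, 2)


sketch = MuseSketch(d_model=512, h=8)
y = sketch(torch.randn(2, 49, 512))
print(y.shape)  # torch.Size([2, 49, 512])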