"""Smoke test for MUSEAttention: push one random batch through the layer.

Feeds a (batch=50, seq_len=49, d_model=512) tensor as query, key and value
(self-attention) and prints the resulting output shape.
"""
from attention.MUSEAttention import MUSEAttention
import torch
from torch import nn
from torch.nn import functional as F


def main() -> None:
    """Build a MUSEAttention layer and run a random self-attention pass."""
    # (batch, seq_len, d_model); named `x` to avoid shadowing the builtin `input`
    x = torch.randn(50, 49, 512)
    sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)
    # Self-attention: query, key and value are the same tensor
    output = sa(x, x, x)
    print(output.shape)  # presumably torch.Size([50, 49, 512]) — confirm against MUSEAttention


if __name__ == "__main__":
    main()