"""Smoke test for UFOAttention: run one random batch through the layer and print the output shape."""
import torch
from torch import nn  # noqa: F401  (kept from the original script; unused here)
from torch.nn import functional as F  # noqa: F401  (kept from the original script; unused here)

from model.attention.UFOAttention import UFOAttention


def main(batch_size=50, seq_len=49, d_model=512, heads=8):
    """Build a UFOAttention layer, feed it a random tensor, and print the output shape.

    Args:
        batch_size: First dimension of the random input tensor.
        seq_len: Second dimension of the random input tensor.
        d_model: Last dimension of the input; also passed as d_model, d_k and d_v,
            matching the original demo.
        heads: Value passed as the layer's ``h`` argument.
    """
    x = torch.randn(batch_size, seq_len, d_model)
    ufo = UFOAttention(d_model=d_model, d_k=d_model, d_v=d_model, h=heads)
    # Only the output shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        # The same tensor is passed as all three inputs (presumably queries, keys, values).
        output = ufo(x, x, x)
    # Per the original demo's note, the defaults yield [50, 49, 512].
    print(output.shape)


if __name__ == '__main__':
    main()