"""Usage example for AFT_FULL: run a random tensor through it and print the output shape."""

import torch
from torch import nn
from torch.nn import functional as F

from attention.AFT import AFT_FULL

# Random input of shape (50, 49, 512). The last two dimensions line up with the
# constructor arguments below (n=49, d_model=512); presumably the layout is
# (batch, sequence length, feature dim) -- confirm against AFT_FULL.
# Named `inputs` rather than `input` so the builtin `input()` is not shadowed.
inputs = torch.randn(50, 49, 512)

aft_full = AFT_FULL(d_model=512, n=49)
output = aft_full(inputs)

# Only the shape is printed, so this serves as a quick smoke test of the module.
print(output.shape)