"""Smoke-test script for ``TripletAttention``.

Builds a random 4-D input tensor, runs one forward pass through a
default-constructed ``TripletAttention`` module, and prints the shape of
the result. Intended to be run directly as a script.
"""
import torch
from torch import nn  # noqa: F401  (kept: pre-existing file-level import)
from torch.nn import functional as F  # noqa: F401  (kept: pre-existing file-level import)

from attention.TripletAttention import TripletAttention


if __name__ == "__main__":
    # Random batch laid out as (50, 512, 7, 7); presumably (N, C, H, W) as
    # TripletAttention expects — TODO confirm against the module.
    x = torch.randn(50, 512, 7, 7)  # renamed from `input` to avoid shadowing the builtin
    triplet = TripletAttention()
    # Only the output shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        output = triplet(x)
    # NOTE(review): output is expected to keep the input's shape — confirm.
    print(output.shape)