"""Smoke test for CrissCrossAttention: push a random tensor through it and print the output shape."""
import torch

from model.attention.CrissCrossAttention import CrissCrossAttention

if __name__ == '__main__':
    # Random 4-D input of shape (3, 64, 7, 7). Presumably (batch, channels,
    # height, width) -- TODO confirm against CrissCrossAttention's forward().
    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.randn(3, 64, 7, 7)
    # NOTE(review): 64 appears to be the channel count and must match x's
    # second dimension -- confirm against the constructor signature.
    model = CrissCrossAttention(64)
    # Only the output shape is inspected, so skip autograd graph construction.
    with torch.no_grad():
        outputs = model(x)
    print(outputs.shape)