from model.attention.Axial_attention import AxialImageTransformer
import torch

if __name__ == '__main__':
    # Batch of 3 feature maps with 128 channels and 7x7 spatial resolution
    input = torch.randn(3, 128, 7, 7)
    model = AxialImageTransformer(
        dim=128,         # must match the channel dimension of the input
        depth=12,        # number of stacked axial transformer layers
        reversible=True  # use reversible blocks to reduce memory usage
    )
    outputs = model(input)
    print(outputs.shape)