"""Smoke test for the PiT ``PoolingTransformer`` backbone.

Builds a ``PoolingTransformer`` with a fixed configuration, runs one forward
pass on a random input tensor, and prints the shape of the result.
"""
import torch
from torch import nn

from model.backbone.PIT import PoolingTransformer


def main():
    """Run a single forward pass on random data and print the output shape."""
    # Random batch of one 3-channel 224x224 image; presumably NCHW layout
    # (batch, channels, height, width) -- TODO confirm against PoolingTransformer.
    sample = torch.randn(1, 3, 224, 224)
    model = PoolingTransformer(
        image_size=224,
        patch_size=14,
        stride=7,
        base_dims=[64, 64, 64],
        depth=[3, 6, 4],
        heads=[4, 8, 16],
        mlp_ratio=4,
    )
    # Only the output shape is inspected, so skip autograd graph construction.
    with torch.no_grad():
        output = model(sample)
    print(output.shape)


if __name__ == '__main__':
    main()