"""Smoke-test script: build a ConViT VisionTransformer and print its output shape."""
from functools import partial

import torch
from torch import nn

from model.backbone.ConViT import VisionTransformer


def main():
    """Run one forward pass on a random image batch and print the output shape."""
    # Single random RGB image at 224x224 resolution, laid out as (batch, channels, H, W).
    x = torch.randn(1, 3, 224, 224)

    # `partial` pre-binds eps so the model can construct LayerNorm(dim) itself.
    # Previously `partial` was used without being imported (NameError).
    model = VisionTransformer(
        num_heads=16,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )

    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        output = model(x)

    # NOTE(review): the exact shape depends on VisionTransformer's defaults
    # (e.g. number of classes), which are not visible here — confirm.
    print(output.shape)


if __name__ == '__main__':
    main()