from functools import partial  # was used below but never imported (NameError)

import torch
from torch import nn

from model.backbone.Container import VisionTransformer

if __name__ == '__main__':
    # Smoke test: push one ImageNet-sized RGB tensor through the
    # multi-stage Container VisionTransformer and print the output shape.
    x = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
    model = VisionTransformer(
        img_size=[224, 56, 28, 14],     # per-stage input resolutions
        patch_size=[4, 2, 2, 2],        # per-stage patch/downsample factors
        embed_dim=[64, 128, 320, 512],  # per-stage embedding widths
        depth=[3, 4, 8, 3],             # transformer blocks per stage
        num_heads=16,
        mlp_ratio=[8, 8, 4, 4],         # per-stage MLP expansion ratios
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6))
    output = model(x)
    print(output.shape)