"""Smoke test for ``DistilledVisionTransformer``.

Builds the model with a fixed hyper-parameter set, runs one random
224x224 RGB batch through it, and prints the shape of ``output[0]``.
"""
from functools import partial  # was used below but never imported -> NameError

import torch
from torch import nn

from model.backbone.DeiT import DistilledVisionTransformer


if __name__ == '__main__':
    # One random image: (batch=1, channels=3, height=224, width=224).
    x = torch.randn(1, 3, 224, 224)

    # NOTE(review): embed_dim=384 / depth=12 / num_heads=6 presumably matches
    # the DeiT-Small configuration -- confirm against the backbone definition.
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )

    # Shape check only, so skip autograd bookkeeping; values and shapes are
    # identical to a grad-enabled forward pass.
    with torch.no_grad():
        output = model(x)

    # NOTE(review): output is indexed with [0]; presumably the distilled model
    # returns a tuple (e.g. class and distillation heads) in this mode --
    # TODO confirm against DistilledVisionTransformer.forward.
    print(output[0].shape)