"""Smoke test: build a CaiT backbone and print the shape of its output."""

from functools import partial

import torch
from torch import nn

from model.backbone.CaiT import CaiT

if __name__ == '__main__':
    # One random tensor of shape (1, 3, 224, 224); the spatial size matches
    # the img_size passed to the model below.
    dummy_input = torch.randn(1, 3, 224, 224)

    model = CaiT(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=24,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        # `partial` pre-binds eps so CaiT can instantiate the norm layer
        # itself; it must be imported from functools (it was missing before,
        # which raised NameError at this line).
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_scale=1e-5,
        depth_token_only=2,
    )

    # Run a single forward pass and report the resulting tensor shape.
    output = model(dummy_input)
    print(output.shape)