from functools import partial

import torch
from torch import nn

from model.backbone.CrossViT import VisionTransformer

if __name__ == "__main__":
    input = torch.randn(1, 3, 224, 224)

    # Two-branch CrossViT configuration: the branches use 12x12 and 16x16 patches
    # with embedding dims 192 and 384; each multi-scale block runs 1 and 4
    # transformer layers respectively before the cross-attention fusion stage.
    model = VisionTransformer(
        img_size=[240, 224],
        patch_size=[12, 16],
        embed_dim=[192, 384],
        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
        num_heads=[6, 6],
        mlp_ratio=[4, 4, 1],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )

    output = model(input)
    print(output.shape)