"""Smoke test for the CeIT backbone: build a model and run one forward pass."""
from functools import partial

import torch
from torch import nn

# NOTE(review): Image2Tokens was used but never imported in the original
# script; it is assumed to live alongside CeIT in model/backbone/CeiT.py --
# confirm the import path.
from model.backbone.CeiT import CeIT, Image2Tokens


def main():
    """Build a CeIT model, feed it one random input, and print the output shape."""
    # Random dummy input; presumably (batch, channels, height, width) -- the
    # layout CeIT expects is defined outside this file, so confirm there.
    input_tensor = torch.randn(1, 3, 224, 224)
    model = CeIT(
        hybrid_backbone=Image2Tokens(),
        patch_size=4,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        # partial defers construction so CeIT can build each LayerNorm with
        # its own normalized shape while sharing eps=1e-6.
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    output = model(input_tensor)
    print(output.shape)


if __name__ == '__main__':
    main()