"""Smoke test for ``DeepVisionTransformer``.

Builds a DeepVisionTransformer with a fixed configuration, runs one forward
pass on a random 1x3x224x224 tensor, and prints the shape of the output.
"""
from functools import partial

import torch
from torch import nn

from model.backbone.DViT import DeepVisionTransformer


def main() -> None:
    """Build the model, run a single forward pass, and print the output shape."""
    # Random batch containing one 3-channel 224x224 image. Named ``x`` so the
    # ``input`` builtin is not shadowed.
    x = torch.randn(1, 3, 224, 224)

    model = DeepVisionTransformer(
        patch_size=16,
        embed_dim=384,
        # NOTE(review): ``depth`` is passed as a list, not an int. Presumably
        # it holds one flag per block, so its length sets the block count.
        # Confirm against model/backbone/DViT.py.
        depth=[False] * 16,
        # NOTE(review): this has 32 flags while ``depth`` has 16 entries.
        # Presumably only the leading entries are consumed; confirm against
        # model/backbone/DViT.py.
        apply_transform=[False] * 0 + [True] * 32,
        num_heads=12,
        mlp_ratio=3,
        qkv_bias=True,
        # LayerNorm factory with eps=1e-6. ``partial`` comes from functools
        # (imported above); without that import this line raises NameError.
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )

    # Inference only: switch to eval mode and skip autograd bookkeeping,
    # since the output is only inspected for its shape.
    model.eval()
    with torch.no_grad():
        output = model(x)
    print(output.shape)


if __name__ == "__main__":
    main()