"""Smoke test for MlpMixer: run a random batch through the model and print the output shape."""
import torch

from model.mlp.mlp_mixer import MlpMixer


def main() -> None:
    """Build an MlpMixer, feed it a random input batch, and print the output shape."""
    mlp_mixer = MlpMixer(
        num_classes=1000,
        num_blocks=10,
        patch_size=10,
        tokens_hidden_dim=32,
        channels_hidden_dim=1024,
        tokens_mlp_dim=16,
        channels_mlp_dim=1024,
    )
    # NOTE(review): tokens_mlp_dim=16 presumably has to match the patch count,
    # (40 // 10) ** 2 == 16 for this input size — confirm against MlpMixer.
    # Renamed from `input` to avoid shadowing the builtin.
    images = torch.randn(50, 3, 40, 40)  # assumes (batch, channels, height, width) — TODO confirm
    # Forward pass only; no gradients are needed for a shape check.
    with torch.no_grad():
        output = mlp_mixer(images)
    print(output.shape)  # presumably torch.Size([50, 1000]) — verify


if __name__ == "__main__":
    main()