"""Smoke test for the Vision Permutator (ViP) model.

Builds a ViP configuration, pushes one random tensor of shape
(1, 3, 224, 224) through it and prints the shape of the result.
"""
import importlib

import torch
from torch import nn
from torch.nn import functional as F

# The model is defined in ``model/mlp/vip-mlp.py``. A hyphen is not legal in a
# dotted ``import`` statement (``from model.mlp.vip-mlp import ...`` is a
# SyntaxError), so the module is loaded by its string name instead.
_vip_mlp = importlib.import_module("model.mlp.vip-mlp")
VisionPermutator = _vip_mlp.VisionPermutator
# NOTE(review): the original script passed ``WeightedPermuteMLP`` without
# importing it (NameError); this assumes vip-mlp.py defines it — confirm.
WeightedPermuteMLP = _vip_mlp.WeightedPermuteMLP


def main():
    """Build the ViP model, run one forward pass on random data, print the output shape."""
    # Batch of one 3x224x224 input (NCHW layout assumed by the model — TODO confirm).
    # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
    x = torch.randn(1, 3, 224, 224)
    model = VisionPermutator(
        layers=[4, 3, 8, 3],
        embed_dims=[384, 384, 384, 384],
        patch_size=14,
        transitions=[False, False, False, False],
        segment_dim=[16, 16, 16, 16],
        mlp_ratios=[3, 3, 3, 3],
        mlp_fn=WeightedPermuteMLP,
    )
    # Inference only: no autograd graph is needed to inspect the output shape.
    with torch.no_grad():
        output = model(x)
    print(output.shape)


if __name__ == '__main__':
    main()