import importlib

import torch
from torch import nn
from torch.nn import functional as F

# The module file is model/mlp/vip-mlp.py; the hyphen in the name makes
# "from model.mlp.vip-mlp import ..." a SyntaxError, so load it dynamically.
vip_mlp = importlib.import_module("model.mlp.vip-mlp")
VisionPermutator = vip_mlp.VisionPermutator
WeightedPermuteMLP = vip_mlp.WeightedPermuteMLP  # required by the mlp_fn argument below


if __name__ == '__main__':
    # A batch of one 3-channel 224x224 image.
    input = torch.randn(1, 3, 224, 224)
    model = VisionPermutator(
        layers=[4, 3, 8, 3],                       # Permutator blocks per stage
        embed_dims=[384, 384, 384, 384],           # embedding dimension per stage
        patch_size=14,                             # 224 / 14 = 16 patches per side
        transitions=[False, False, False, False],  # no downsampling between stages
        segment_dim=[16, 16, 16, 16],              # channel segments for height/width mixing
        mlp_ratios=[3, 3, 3, 3],                   # hidden-dim expansion of the channel MLP
        mlp_fn=WeightedPermuteMLP,                 # token-mixing MLP variant
    )
    output = model(input)
    print(output.shape)
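    # Expected result: assuming the ViP implementation's default classification
    # head (1000 classes), this should print torch.Size([1, 1000]).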