# Source listing: 6 lines, 278 B, Python.
"""Smoke test for ``MlpMixer``: run one forward pass and print the output shape."""

import torch

from model.mlp.mlp_mixer import MlpMixer

mlp_mixer = MlpMixer(
    num_classes=1000,
    num_blocks=10,
    patch_size=10,
    tokens_hidden_dim=32,
    channels_hidden_dim=1024,
    # NOTE(review): 16 == (40 // 10) ** 2, i.e. presumably the number of
    # 10x10 patches in the 40x40 input below; confirm against MlpMixer
    # before changing patch_size or the input resolution.
    tokens_mlp_dim=16,
    channels_mlp_dim=1024,
)
# Inference only: put submodules in eval mode (affects e.g. dropout /
# batch-norm if the model uses them; the printed shape is unaffected).
mlp_mixer.eval()

# Random batch; named `x` rather than `input` to avoid shadowing the builtin.
# Shape (50, 3, 40, 40) — presumably (batch, channels, height, width); confirm.
x = torch.randn(50, 3, 40, 40)

# No autograd graph is needed just to inspect the output shape.
with torch.no_grad():
    output = mlp_mixer(x)

print(output.shape)