# Tan_pytorch_segmentation/pytorch_segmentation/PV_MLP/MLP-Mixer.py
# Last commit: 2025-05-19 20:48:24 +08:00
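import torch
from model.mlp.mlp_mixer import MlpMixer

# A 40x40 image cut into 10x10 patches yields (40 // 10) ** 2 = 16 tokens;
# the token-mixing MLP acts along that axis, so tokens_mlp_dim must be 16.
mlp_mixer = MlpMixer(num_classes=1000, num_blocks=10, patch_size=10,
                     tokens_hidden_dim=32, channels_hidden_dim=1024,
                     tokens_mlp_dim=16, channels_mlp_dim=1024)
x = torch.randn(50, 3, 40, 40)  # (batch, RGB channels, height, width)
output = mlp_mixer(x)
print(output.shape)  # expected (batch, num_classes), i.e. (50, 1000)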
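
# Sketch, not part of the MlpMixer module: mixer_num_tokens is a hypothetical
# helper that derives the tokens_mlp_dim value the example above needs from
# the input size. It assumes the patch embedding tiles the image with
# non-overlapping patch_size x patch_size patches. Rejecting sizes that are
# not divisible by patch_size is a conservative choice made only in this sketch.
def mixer_num_tokens(height, width, patch_size):
    """Return the number of patch tokens for an image of the given size."""
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    # One token per patch: patches along height times patches along width.
    return (height // patch_size) * (width // patch_size)


# For the 40x40 input with patch_size=10 used above: 4 * 4 = 16 tokens.
print(mixer_num_tokens(40, 40, 10))  # 16

# Changing the input resolution changes the token count, so tokens_mlp_dim has
# to be recomputed (e.g. 224x224 with 16x16 patches gives 14 * 14 = 196 tokens).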