Tan_pytorch_segmentation/pytorch_segmentation/PV_MLP/Gmlp.py

2025-05-19 20:48:24 +08:00
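import torch

from model.mlp.g_mlp import gMLP

num_tokens = 10000  # vocabulary size
bs = 50             # batch size
len_sen = 49        # sequence length (tokens per sample)
num_layers = 6      # defined here but not passed to gMLP below

# Random token IDs in [0, num_tokens), shape (bs, len_sen)
tokens = torch.randint(num_tokens, (bs, len_sen))
gmlp = gMLP(num_tokens=num_tokens, len_sen=len_sen, dim=512, d_ff=1024)
output = gmlp(tokens)
print(output.shape)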