11 lines
258 B
Python
11 lines
258 B
Python
|
from model.mlp.g_mlp import gMLP
|
||
|
import torch
|
||
|
|
||
|
num_tokens=10000
|
||
|
bs=50
|
||
|
len_sen=49
|
||
|
num_layers=6
|
||
|
input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
|
||
|
gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
|
||
|
output=gmlp(input)
|
||
|
print(output.shape)
|