# 31 lines, 864 B, Python
|
from model.mlp.repmlp import RepMLP
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
# Demo configuration for the RepMLP re-parameterization check.
# Input geometry: N samples of C channels, each H x W.
N, C, H, W = 4, 512, 14, 14  # batch size, input dim, image height, image width
O = 1024  # output dim

# Patch geometry (patch height, patch width).
h, w = 7, 7

fc1_fc2_reduction = 1  # reduction ratio
fc3_groups = 8  # groups
repconv_kernels = [1, 3, 5, 7]  # kernel list
# Build the RepMLP module under test from the configuration above, plus a
# random input batch of shape (N, C, H, W).
repmlp = RepMLP(
    C, O, H, W,
    h, w,
    fc1_fc2_reduction,
    fc3_groups,
    repconv_kernels=repconv_kernels,
)
x = torch.randn((N, C, H, W))
# Eval mode: BatchNorm layers normalize with their running statistics instead
# of per-batch statistics. The output computed in this mode is later compared
# against the output after switch_to_deploy().
repmlp.eval()

# PyTorch initializes BatchNorm with running_mean=0, running_var=1, weight=1,
# bias=0, which makes eval-mode BN nearly an identity. Randomize these so any
# BatchNorm folding done by switch_to_deploy() is actually exercised.
for module in repmlp.modules():
    if not isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        continue
    # Running buffers are None when built with track_running_stats=False.
    if module.running_mean is not None:
        nn.init.uniform_(module.running_mean, 0, 0.1)
        nn.init.uniform_(module.running_var, 0, 0.1)
    # weight/bias are None when built with affine=False.
    if module.affine:
        nn.init.uniform_(module.weight, 0, 0.1)
        nn.init.uniform_(module.bias, 0, 0.1)
# Reference output: the original (not yet re-parameterized) module, evaluated
# in eval mode. Inference only, so autograd is disabled; this also keeps the
# printed difference free of a grad_fn.
with torch.no_grad():
    out = repmlp(x)

# Deploy output: convert the module in place, then run the same input again.
repmlp.switch_to_deploy()

with torch.no_grad():
    deployout = repmlp(x)

# Sum of squared differences between the two outputs; a value near zero
# (float32 rounding) indicates the conversion preserved the module's output.
print(((deployout - out) ** 2).sum())