# Listing metadata: 31 lines, 864 B, Python
from model.mlp.repmlp import RepMLP
import torch
from torch import nn

# Configuration handed to the RepMLP constructor and used to size the input.
N = 4  # batch size
C = 512  # input dim
O = 1024  # output dim
H = 14  # image height
W = 14  # image width
h = 7  # patch height
w = 7  # patch width
fc1_fc2_reduction = 1  # reduction ratio
fc3_groups = 8  # groups
repconv_kernels = [1, 3, 5, 7]  # kernel list

# Build the model first, then the random input, then switch to eval mode.
repmlp = RepMLP(
    C,
    O,
    H,
    W,
    h,
    w,
    fc1_fc2_reduction,
    fc3_groups,
    repconv_kernels=repconv_kernels,
)
x = torch.randn(N, C, H, W)
repmlp.eval()

# Overwrite the statistics and affine parameters of every BatchNorm layer
# with uniform samples from [0, 0.1).
for module in repmlp.modules():
    if not isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
        continue
    for _param in (
        module.running_mean,
        module.running_var,
        module.weight,
        module.bias,
    ):
        nn.init.uniform_(_param, 0, 0.1)

# Forward pass before conversion (the "training result").
out = repmlp(x)

# Forward pass after switch_to_deploy() (the "inference result").
repmlp.switch_to_deploy()
deployout = repmlp(x)

# Sum of squared differences between the two outputs.
print((deployout - out).pow(2).sum())