# Source: Tan_pytorch_segmentation/pytorch_segmentation/PV_Backbone/LeViT.py
# (13 lines, 308 B, Python)

from model.backbone.LeViT import *
import torch
from torch import nn
if __name__ == '__main__':
    # Smoke test: instantiate every LeViT variant named in `specification`
    # and run one forward pass, printing the output shape of each.
    # `specification` and the per-variant constructors are not defined here;
    # they are presumably exported by `model.backbone.LeViT` through the star
    # import above (TODO confirm). Iterating `specification` yields the
    # variant names used as keys into globals().
    x = torch.randn(1, 3, 224, 224)  # dummy input of shape (1, 3, 224, 224)
    for name in specification:
        # Look the constructor up by name among module globals; this only works
        # because the star import placed it there.
        # NOTE(review): `fuse=True` presumably fuses layers for inference —
        # confirm its meaning in model.backbone.LeViT.
        model = globals()[name](fuse=True, pretrained=False)
        model.eval()
        # Inference only: disable autograd tracking to avoid building a graph.
        with torch.no_grad():
            output = model(x)
        print(output.shape)