18 lines
538 B
Python
18 lines
538 B
Python
|
|
from model.backbone.VOLO import VOLO
|
|
import torch
|
|
from torch import nn
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a VOLO backbone and run a single forward pass on a
    # random batch of shape (1, 3, 224, 224), then print a result's shape.
    # NOTE(review): the meaning of the positional list [4, 4, 8, 2] and of each
    # keyword argument is defined in model.backbone.VOLO, which is not visible
    # here -- presumably per-stage layer counts / per-stage settings; confirm.
    x = torch.randn(1, 3, 224, 224)  # named `x` to avoid shadowing builtin `input`

    model = VOLO(
        [4, 4, 8, 2],
        embed_dims=[192, 384, 384, 384],
        num_heads=[6, 12, 12, 12],
        mlp_ratios=[3, 3, 3, 3],
        downsamples=[True, False, False, False],
        outlook_attention=[True, False, False, False],
        post_layers=['ca', 'ca'],
    )

    # The model is intentionally left in whatever mode its constructor sets
    # (for an nn.Module that is training mode), exactly as the original script
    # did: the structure of VOLO's return value may depend on `model.training`,
    # so calling eval() here could change what `output[0]` refers to.
    # no_grad() only disables autograd bookkeeping for this inference-only pass.
    with torch.no_grad():
        output = model(x)

    # NOTE(review): indexing with [0] assumes the forward returns an indexable
    # result (e.g. a tuple whose first element is the class logits) -- confirm
    # against VOLO.forward.
    print(output[0].shape)