Tan_pytorch_segmentation/pytorch_segmentation/PV_Backbone/VOLO.py

18 lines
538 B
Python
Raw Permalink Normal View History

2025-05-19 20:48:24 +08:00
from model.backbone.VOLO import VOLO
import torch
from torch import nn
def build_model():
    """Construct the VOLO model with the configuration exercised by this script.

    All arguments are passed straight through to ``VOLO`` (imported from
    ``model.backbone.VOLO``). Their exact meaning is defined there. The first
    positional list is presumably the number of layers per stage — confirm
    against the VOLO constructor.
    """
    return VOLO(
        [4, 4, 8, 2],
        embed_dims=[192, 384, 384, 384],
        num_heads=[6, 12, 12, 12],
        mlp_ratios=[3, 3, 3, 3],
        downsamples=[True, False, False, False],
        outlook_attention=[True, False, False, False],
        post_layers=['ca', 'ca'],
    )


def main(batch_size: int = 1, image_size: int = 224) -> None:
    """Run one forward pass on random input and print the shape of ``output[0]``.

    The defaults reproduce the original 1x3x224x224 smoke test.

    Args:
        batch_size: Number of samples in the random input batch.
        image_size: Height and width of the square random 3-channel input.
            NOTE(review): non-default sizes must be supported by the VOLO
            configuration (e.g. its patch/stem setup) — confirm before relying
            on them.
    """
    model = build_model()
    # Presumably (batch, channels, height, width). Named `x` rather than
    # `input` so the built-in is not shadowed.
    x = torch.randn(batch_size, 3, image_size, image_size)
    # Only the output shape is inspected, so skip building an autograd graph.
    # This does not switch the module between training and eval mode.
    with torch.no_grad():
        output = model(x)
    # NOTE(review): `[0]` assumes forward returns an indexable result (a
    # tensor or a tuple). Its exact structure is defined by VOLO.forward and
    # may depend on the module's mode — confirm in model/backbone/VOLO.py.
    print(output[0].shape)


if __name__ == '__main__':
    main()