Tan_pytorch_segmentation/pytorch_segmentation/PV_Backbone/DViT.py

19 lines
480 B
Python
Raw Normal View History

2025-05-19 20:48:24 +08:00
from model.backbone.DViT import DeepVisionTransformer
import torch
from torch import nn
if __name__ == '__main__':
    # Smoke test: build a DeepVisionTransformer and push one random batch
    # through it, printing the shape of whatever the model returns.

    # `partial` is needed for the norm_layer factory below; imported here so
    # the script no longer fails with NameError when run directly.
    from functools import partial

    # One random image-like tensor of shape (1, 3, 224, 224).
    # Named `dummy_input` to avoid shadowing the builtin `input`.
    dummy_input = torch.randn(1, 3, 224, 224)

    model = DeepVisionTransformer(
        patch_size=16,
        embed_dim=384,
        # NOTE(review): a 16-element list of False is passed as `depth`;
        # presumably the model accepts a per-block flag list here — confirm
        # against DeepVisionTransformer's signature.
        depth=[False] * 16,
        # `[False] * 0` is an empty list, so this evaluates to 32 True flags.
        # NOTE(review): 32 flags vs. 16 depth entries — verify this length
        # mismatch is intended.
        apply_transform=[False] * 0 + [True] * 32,
        num_heads=12,
        mlp_ratio=3,
        qkv_bias=True,
        # LayerNorm factory with eps fixed to 1e-6.
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )

    output = model(dummy_input)
    print(output.shape)