# Example script: build a DAT attention model and run a single forward pass
# on a dummy ImageNet-sized input, printing the shape of the first output.
import torch

from model.attention.DAT import DAT

if __name__ == '__main__':
    # Dummy input: one RGB image at ImageNet resolution (batch, channels, H, W).
    x = torch.randn(1, 3, 224, 224)

    # Four-stage hierarchical model. Per-stage lists are indexed by stage;
    # -1 entries mean "not used in that stage" (e.g. groups/strides/sr_ratios
    # only apply to the deformable-attention ('D') stages 3 and 4).
    model = DAT(
        img_size=224,
        patch_size=4,
        num_classes=1000,
        expansion=4,                  # MLP expansion ratio
        dim_stem=96,                  # stem embedding dim
        dims=[96, 192, 384, 768],     # channel dims per stage
        depths=[2, 2, 6, 2],          # blocks per stage
        # Block types per stage: 'L' local window, 'S' shifted window,
        # 'D' deformable attention.
        stage_spec=[['L', 'S'], ['L', 'S'],
                    ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],
        heads=[3, 6, 12, 24],
        window_sizes=[7, 7, 7, 7],
        groups=[-1, -1, 3, 6],        # offset groups for 'D' stages
        use_pes=[False, False, True, True],
        dwc_pes=[False, False, False, False],
        strides=[-1, -1, 1, 1],
        sr_ratios=[-1, -1, -1, -1],
        offset_range_factor=[-1, -1, 2, 2],
        no_offs=[False, False, False, False],
        fixed_pes=[False, False, False, False],
        use_dwc_mlps=[False, False, False, False],
        use_conv_patches=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
    )

    # Forward pass. The model returns a sequence; element 0 is presumably the
    # classification logits of shape (1, num_classes) — confirm against DAT.
    output = model(x)
    print(output[0].shape)