import torch

from model.attention.DAT import DAT


def main():
    """Build a DAT model, run one random batch through it, and print a shape.

    This is a smoke test: it constructs ``DAT`` with the fixed configuration
    below, feeds it a random ``(1, 3, 224, 224)`` tensor, and prints the shape
    of the first element of whatever the model returns.
    """
    # Named ``dummy_input`` rather than ``input`` so the builtin is not shadowed.
    dummy_input = torch.randn(1, 3, 224, 224)

    # The per-stage lists below each have four entries, matching the four
    # entries of ``dims`` / ``depths``.
    # NOTE(review): each ``stage_spec`` entry has as many items as the
    # corresponding ``depths`` value (2, 2, 6, 2) -- presumably one block type
    # per layer; confirm against the DAT implementation.
    model = DAT(
        img_size=224,
        patch_size=4,
        num_classes=1000,
        expansion=4,
        dim_stem=96,
        dims=[96, 192, 384, 768],
        depths=[2, 2, 6, 2],
        stage_spec=[
            ['L', 'S'],
            ['L', 'S'],
            ['L', 'D', 'L', 'D', 'L', 'D'],
            ['L', 'D'],
        ],
        heads=[3, 6, 12, 24],
        window_sizes=[7, 7, 7, 7],
        groups=[-1, -1, 3, 6],
        use_pes=[False, False, True, True],
        dwc_pes=[False, False, False, False],
        strides=[-1, -1, 1, 1],
        sr_ratios=[-1, -1, -1, -1],
        offset_range_factor=[-1, -1, 2, 2],
        no_offs=[False, False, False, False],
        fixed_pes=[False, False, False, False],
        use_dwc_mlps=[False, False, False, False],
        use_conv_patches=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
    )

    # Only the output shape is inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        output = model(dummy_input)

    # The model's return value is indexed, so it is a sequence-like object.
    # NOTE(review): element 0 is presumably the classification output --
    # confirm against DAT.forward.
    print(output[0].shape)


if __name__ == '__main__':
    main()