"""Smoke test for the TNT backbone.

Builds a ``TNT`` model with the configuration below, runs one random
tensor through it, and prints the shape of the result.
"""
import torch
from torch import nn

from model.backbone.TnT import TNT


def main():
    """Run a single forward pass through TNT and print the output shape."""
    # Random dummy batch. The layout is presumably (batch, channels, height, width).
    # Its spatial size matches img_size=224 below; confirm against TNT.
    x = torch.randn(1, 3, 224, 224)
    model = TNT(
        img_size=224,
        patch_size=16,
        outer_dim=384,
        inner_dim=24,
        depth=12,
        outer_num_heads=6,
        inner_num_heads=4,
        qkv_bias=False,
        inner_stride=4,
    )
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        output = model(x)
    print(output.shape)


if __name__ == '__main__':
    main()