"""Smoke test: build a HATNet backbone on a random image batch and print the output shape."""
from model.backbone.HATNet import HATNet
import torch
from torch import nn
from torch.nn import functional as F


def main():
    """Instantiate HATNet with a fixed configuration, run one forward pass, print the output shape.

    The input is a single random tensor of shape (1, 3, 224, 224). Only the
    shape of the result is reported, so the pass runs without gradient tracking.
    """
    # Renamed from `input` to avoid shadowing the Python builtin.
    x = torch.randn(1, 3, 224, 224)
    # Each per-stage list below has 4 entries; presumably one per backbone
    # stage — NOTE(review): confirm against HATNet's constructor.
    hat = HATNet(
        dims=[48, 96, 240, 384],
        head_dim=48,
        expansions=[8, 8, 4, 4],
        grid_sizes=[8, 7, 7, 1],
        ds_ratios=[8, 4, 2, 1],
        depths=[2, 2, 6, 3],
    )
    # No backward pass is needed just to inspect the output shape, so skip
    # building the autograd graph.
    with torch.no_grad():
        output = hat(x)
    print(output.shape)


if __name__ == '__main__':
    main()