"""Smoke test: run one forward pass of GFNet on a random input and print its shape.

Builds a GFNet (embed_dim=384, img_size=224, patch_size=16, num_classes=1000),
feeds it a single random 1x3x224x224 tensor, and prints the output's shape.
"""
import torch
from torch import nn  # noqa: F401 -- unused in this script; kept from the original
from torch.nn import functional as F  # noqa: F401 -- unused in this script; kept from the original

from attention.gfnet import GFNet

# Random input batch of shape (1, 3, 224, 224); presumably NCHW with 3 colour
# channels, matching img_size=224 below -- confirm against GFNet's docs.
x = torch.randn(1, 3, 224, 224)

gfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)

# Forward-only demo: disable autograd so no computation graph is recorded.
with torch.no_grad():
    out = gfnet(x)

# Presumably torch.Size([1, 1000]) if GFNet returns per-class logits -- confirm.
print(out.shape)