"""Smoke-test script: run one random batch through CoAtNet and print the output shape."""

from attention.CoAtNet import CoAtNet
import torch
from torch import nn
from torch.nn import functional as F

# One random sample with 3 channels and 224x224 spatial size, matching the
# in_ch / image_size arguments given to the model below.
# Named `dummy_input` rather than `input` so the builtin `input()` is not shadowed.
dummy_input = torch.randn(1, 3, 224, 224)

# Build the network with the same channel count and image size as the tensor above.
model = CoAtNet(in_ch=3, image_size=224)

# Single forward pass; only the shape of the result is reported.
output = model(dummy_input)
print(output.shape)