# Smoke-test script: run a random tensor through SpatialGroupEnhance and
# print the shape of the result.

import torch
from torch import nn
from torch.nn import functional as F

from attention.SGE import SpatialGroupEnhance

# Random 4-D tensor of shape (50, 512, 7, 7).
# NOTE(review): presumably (batch, channels, height, width) -- confirm against
# the SpatialGroupEnhance implementation.
# NOTE(review): the name `input` shadows the builtin of the same name; it is
# kept unchanged here in case other code refers to this module-level name.
input = torch.randn(50, 512, 7, 7)

# Build the module with 8 groups and apply it to the tensor.
sge = SpatialGroupEnhance(groups=8)
output = sge(input)

# Report the shape of the module's output.
print(output.shape)