"""Smoke test for the ShuffleAttention module.

Runs one random batch through ``ShuffleAttention`` and prints the output
shape, so the forward pass can be checked by eye.
"""
import torch
from torch import nn  # noqa: F401  (kept from the original file)
from torch.nn import functional as F  # noqa: F401  (kept from the original file)

from attention.ShuffleAttention import ShuffleAttention


def main() -> None:
    """Build a ShuffleAttention layer, run one forward pass, and print the output shape."""
    # Random input of shape (50, 512, 7, 7). Dim 1 (512) equals ``channel=``
    # below; presumably the layer expects NCHW input with C == channel -- TODO confirm.
    x = torch.randn(50, 512, 7, 7)  # named ``x`` to avoid shadowing builtin ``input``
    attention = ShuffleAttention(channel=512, G=8)
    output = attention(x)
    print(output.shape)


if __name__ == "__main__":
    main()