# Smoke-test script: run a random tensor through SequentialPolarizedSelfAttention
# and print the shape of the result.

import torch
from torch import nn
from torch.nn import functional as F

from attention.PolarizedSelfAttention import (
    ParallelPolarizedSelfAttention,
    SequentialPolarizedSelfAttention,
)

# Random 4-D tensor; the second dimension (512) matches the `channel` argument
# given to the module below.
# presumably laid out as (batch, channels, height, width) -- TODO confirm
# against the attention module.
# NOTE(review): `input` shadows the builtin of the same name. The name is kept
# so that any code relying on this module-level variable keeps working.
input = torch.randn(1, 512, 7, 7)

# ParallelPolarizedSelfAttention is imported above but not exercised here;
# only the sequential variant is instantiated.
psa = SequentialPolarizedSelfAttention(channel=512)

output = psa(input)
print(output.shape)