"""Smoke test for ``WeightedPermuteMLP`` from ``attention.ViP``.

Builds a random input tensor, runs it once through the module, and prints
the shape of the result.
"""
from attention.ViP import WeightedPermuteMLP
import torch
from torch import nn
from torch.nn import functional as F


def main():
    """Run a single forward pass of ``WeightedPermuteMLP`` and print the output shape."""
    # NOTE(review): presumably the module expects channels-last input
    # (batch, height, width, channels) whose last axis equals the first
    # constructor argument (512) -- TODO confirm against attention/ViP.py.
    x = torch.randn(64, 8, 8, 512)
    # Second constructor argument; its exact meaning and any divisibility
    # constraint on the channel count are defined in attention/ViP.py.
    seg_dim = 8
    vip = WeightedPermuteMLP(512, seg_dim)
    # Inference-only demo: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        out = vip(x)
    print(out.shape)


if __name__ == "__main__":
    main()