"""Smoke-test script for ``CoTAttention``.

Builds a ``CoTAttention`` module, runs one random tensor through it, and
prints the shape of the result.
"""

import torch
from torch import nn
from torch.nn import functional as F

from attention.CoTAttention import CoTAttention

# Random input of shape (50, 512, 7, 7).
# Presumably (batch, channels, height, width) with channels matching `dim`
# below -- TODO confirm against the CoTAttention implementation.
# NOTE(review): the name `input` shadows the builtin; it is kept unchanged so
# that any code relying on this module-level name keeps working.
input = torch.randn(50, 512, 7, 7)

cot = CoTAttention(dim=512, kernel_size=3)
output = cot(input)

print(output.shape)