from model.attention.CrissCrossAttention import CrissCrossAttention
import torch
def main(batch_size=3, channels=64, height=7, width=7):
    """Smoke-test CrissCrossAttention on a random feature map.

    Builds a random tensor of shape ``(batch_size, channels, height, width)``,
    runs it through ``CrissCrossAttention(channels)`` and prints the shape of
    the result. The defaults reproduce the original hard-coded run
    (3 x 64 x 7 x 7).

    Args:
        batch_size: Size of the first dimension of the random input.
        channels: Size of the second dimension of the input; also passed to
            the ``CrissCrossAttention`` constructor.
        height: Size of the third dimension of the input.
        width: Size of the fourth dimension of the input.

    Returns:
        Whatever the module's forward call returns, so callers can inspect
        it beyond the printed shape.
    """
    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.randn(batch_size, channels, height, width)
    # The original passed 64 both as the constructor argument and as the
    # input's second dimension; presumably the module's in_dim must match
    # the channel count -- NOTE(review): confirm against CrissCrossAttention.
    model = CrissCrossAttention(channels)
    outputs = model(x)
    print(outputs.shape)
    return outputs


if __name__ == '__main__':
    main()