2025-05-19 20:48:24 +08:00
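import torch
from model.attention.CrissCrossAttention import CrissCrossAttention

if __name__ == '__main__':
    # Dummy batch: 3 feature maps with 64 channels on a 7x7 grid, laid out as (N, C, H, W).
    x = torch.randn(3, 64, 7, 7)
    # in_dim must match the channel count of the input.
    model = CrissCrossAttention(64)
    outputs = model(x)
    # The module adds its result back onto the input, so the expected shape is torch.Size([3, 64, 7, 7]).
    print(outputs.shape)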
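The demo above imports `CrissCrossAttention` from `model.attention`, which is not part of this file. For reference, here is a self-contained sketch of what such a module might look like. It assumes the CCNet-style design: 1×1 query/key convolutions with channels reduced by 8, a 1×1 value convolution, one softmax over each pixel's row plus column, and a zero-initialised residual scale `gamma`. The `reduction` parameter and the `einsum` formulation are illustrative assumptions and are not taken from the repository's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrissCrossAttention(nn.Module):
    """Sketch of CCNet-style criss-cross attention (assumed design, not the repo's code)."""

    def __init__(self, in_dim: int, reduction: int = 8):
        super().__init__()
        # Reduced query/key width; `reduction` is a hypothetical knob (CCNet uses 8).
        mid = max(in_dim // reduction, 1)
        self.query_conv = nn.Conv2d(in_dim, mid, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, mid, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        # Learnable residual scale; starting at 0 makes the block an identity at init.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, _, h, _ = x.shape
        q, k, v = self.query_conv(x), self.key_conv(x), self.value_conv(x)

        # Column energies: pixel (i, j) vs every (g, j) in its column -> (B, H, W, H).
        e_col = torch.einsum("bcij,bcgj->bijg", q, k)
        # Mask g == i so the pixel itself is counted only once (by the row branch).
        self_mask = torch.eye(h, dtype=torch.bool, device=x.device).view(1, h, 1, h)
        e_col = e_col.masked_fill(self_mask, float("-inf"))
        # Row energies: pixel (i, j) vs every (i, u) in its row -> (B, H, W, W).
        e_row = torch.einsum("bcij,bciu->biju", q, k)

        # A single softmax over the joint H + W "cross" neighbourhood of each pixel.
        attn = F.softmax(torch.cat([e_col, e_row], dim=-1), dim=-1)
        a_col, a_row = attn[..., :h], attn[..., h:]

        # Aggregate values along the column and the row, then add the residual.
        out_col = torch.einsum("bijg,bcgj->bcij", a_col, v)
        out_row = torch.einsum("biju,bciu->bcij", a_row, v)
        return self.gamma * (out_col + out_row) + x


if __name__ == "__main__":
    x = torch.randn(3, 64, 7, 7)
    print(CrissCrossAttention(64)(x).shape)  # torch.Size([3, 64, 7, 7])
```

A single call only mixes information along each pixel's row and column. CCNet therefore applies the module twice in a row (recurrent criss-cross attention), so that every pixel can draw on every other pixel while the cost stays at O(H + W) per pixel instead of O(H·W) for full self-attention.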