# Tan_pytorch_segmentation/pytorch_segmentation/PV_Attention/Axial-attention.py
# (original listing: 13 lines, 297 B, Python)

from model.attention.Axial_attention import AxialImageTransformer
import torch
if __name__ == '__main__':
    # Smoke test: push a random image-shaped tensor through an axial-attention
    # transformer and print the resulting shape.
    # Input layout is (batch, channels, height, width); `dim` below must match
    # the channel axis (128).
    x = torch.randn(3, 128, 7, 7)  # renamed from `input` to avoid shadowing the builtin
    model = AxialImageTransformer(
        dim=128,          # channel dimension, must equal the input's channel count
        depth=12,         # number of stacked axial-attention layers
        reversible=True,  # reversible blocks trade compute for activation memory
    )
    outputs = model(x)
    # Expected to preserve the input shape: torch.Size([3, 128, 7, 7])
    # (assumption from typical axial-attention usage — TODO confirm)
    print(outputs.shape)