from model.attention.Axial_attention import AxialImageTransformer
import torch

if __name__ == '__main__':
    # Dummy feature map: batch of 3, 128 channels, 7x7 spatial resolution
    input = torch.randn(3, 128, 7, 7)
    model = AxialImageTransformer(
        dim=128,          # embedding dimension, must match the input channel count
        depth=12,         # number of stacked axial transformer blocks
        reversible=True   # use reversible residual blocks
    )
    outputs = model(input)
    # Axial attention preserves the spatial layout, so this should print
    # torch.Size([3, 128, 7, 7])
    print(outputs.shape)
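Note: setting `reversible=True` swaps the plain residual stack for reversible blocks (as in the lucidrains `axial_attention` package this module appears to be adapted from; that lineage is an assumption here), trading extra recomputation in the backward pass for activation memory that stays roughly constant as `depth` grows.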