# File: Tan_pytorch_segmentation/pytorch_segmentation/PV_Attention/HaloNet-Attention.py (Python)
#
# Demo: run HaloAttention on a random feature map and print the output shape.

from attention.HaloAttention import HaloAttention
import torch
from torch import nn
from torch.nn import functional as F
def demo_halo_attention(
    dim: int = 512,
    height: int = 8,
    width: int = 8,
    block_size: int = 2,
    halo_size: int = 1,
    batch_size: int = 1,
    module: "nn.Module | None" = None,
) -> torch.Size:
    """Run an attention module on a random feature map and return the output shape.

    Args:
        dim: Channel count of the random input; also passed to HaloAttention
            as ``dim`` when ``module`` is not supplied.
        height: Spatial height of the random input.
        width: Spatial width of the random input.
        block_size: Passed to HaloAttention when ``module`` is not supplied.
            NOTE(review): HaloAttention implementations usually require
            height/width to be divisible by this; confirm against
            attention/HaloAttention.py.
        halo_size: Passed to HaloAttention when ``module`` is not supplied.
        batch_size: Leading dimension of the random input.
        module: Optional pre-built module to run instead of constructing a
            HaloAttention. Useful for smoke-testing this driver.

    Returns:
        The ``.shape`` of the module's output tensor.
    """
    if module is None:
        module = HaloAttention(dim=dim, block_size=block_size, halo_size=halo_size)
    # Named ``feature_map`` rather than ``input`` to avoid shadowing the builtin.
    feature_map = torch.randn(batch_size, dim, height, width)
    # Only the output shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        output = module(feature_map)
    return output.shape


if __name__ == "__main__":
    print(demo_halo_attention())