Tan_pytorch_segmentation/pytorch_segmentation/PV_Attention/Polarized-Self-Attention.py

2025-05-19 20:48:24 +08:00
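# Demo: run Polarized Self-Attention (PSA) on a dummy feature map and check the output shape.
import torch

from attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention, SequentialPolarizedSelfAttention

# Dummy backbone feature map: (batch, channels, height, width)
x = torch.randn(1, 512, 7, 7)

# Sequential PSA: the channel-only branch is applied first, then the spatial-only branch
psa = SequentialPolarizedSelfAttention(channel=512)
output = psa(x)
print(output.shape)  # PSA only reweights features, so this should equal x.shape

# Parallel PSA: both branches see the same input and their outputs are summed
ppsa = ParallelPolarizedSelfAttention(channel=512)
print(ppsa(x).shape)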
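The script above depends on a local `attention` package, so it only runs inside that repository. The following is a minimal, self-contained sketch of what the two PSA variants compute, assuming the channel-only / spatial-only branch design from the Polarized Self-Attention paper: the channel branch collapses space with a softmax over positions, and the spatial branch collapses channels with global pooling plus a softmax over channels. The class name `PSASketch` and its `mode` argument are hypothetical; this is not the repository's `SequentialPolarizedSelfAttention`, and details such as LayerNorm placement may differ from it.

```python
import torch
from torch import nn


class PSASketch(nn.Module):
    """Hypothetical PSA sketch; not the repository's implementation."""

    def __init__(self, channel: int = 512, mode: str = "sequential"):
        super().__init__()
        assert channel % 2 == 0, "channel must be even (branches use channel // 2)"
        assert mode in ("sequential", "parallel")
        self.mode = mode
        half = channel // 2
        # Channel-only branch: values C -> C/2, a single query map C -> 1, then C/2 -> C
        self.ch_wv = nn.Conv2d(channel, half, kernel_size=1)
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=1)
        self.ch_wz = nn.Conv2d(half, channel, kernel_size=1)
        self.ln = nn.LayerNorm(channel)
        # Spatial-only branch: values and queries both C -> C/2
        self.sp_wv = nn.Conv2d(channel, half, kernel_size=1)
        self.sp_wq = nn.Conv2d(channel, half, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def channel_weight(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        v = self.ch_wv(x).reshape(b, c // 2, h * w)            # (b, c/2, hw)
        q = self.ch_wq(x).reshape(b, h * w, 1).softmax(dim=1)  # softmax over positions
        z = torch.matmul(v, q).unsqueeze(-1)                   # (b, c/2, 1, 1)
        z = self.ch_wz(z).reshape(b, c)                        # back to c channels
        return torch.sigmoid(self.ln(z)).reshape(b, c, 1, 1)   # per-channel weight in (0, 1)

    def spatial_weight(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        v = self.sp_wv(x).reshape(b, c // 2, h * w)                         # (b, c/2, hw)
        q = self.pool(self.sp_wq(x)).reshape(b, 1, c // 2).softmax(dim=-1)  # softmax over channels
        z = torch.matmul(q, v).reshape(b, 1, h, w)                          # (b, 1, h, w)
        return torch.sigmoid(z)                                             # per-pixel weight in (0, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ch_out = self.channel_weight(x) * x
        if self.mode == "parallel":
            # Both branches reweight the same input; results are summed
            return ch_out + self.spatial_weight(x) * x
        # Sequential: the spatial branch reweights the channel branch's output
        return self.spatial_weight(ch_out) * ch_out


if __name__ == "__main__":
    x = torch.randn(1, 512, 7, 7)
    for mode in ("sequential", "parallel"):
        out = PSASketch(channel=512, mode=mode)(x)
        print(mode, out.shape)
```

Both weights pass through a sigmoid and are multiplied back onto the features, which is why the output shape always matches the input shape and the module can be dropped between backbone stages without other changes.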