"""Example usage of the ``CoordAtt`` (coordinate attention) module.

Builds a random input tensor, passes it through a ``CoordAtt`` layer and
prints the shape of the resulting tensor.
"""
import torch
from torch import nn
from torch.nn import functional as F

from attention.CoordAttention import CoordAtt

# Random input of shape [2, 96, 56, 56]; presumably (batch, channels,
# height, width) -- TODO confirm against CoordAtt's expected layout.
inp = torch.rand([2, 96, 56, 56])

# Input/output channel counts handed to CoordAtt. They presumably must
# match dim 1 of `inp` (96) -- NOTE(review): confirm in CoordAtt.
inp_dim, oup_dim = 96, 96

# Reduction factor forwarded to CoordAtt as a keyword argument.
reduction = 32

coord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)

# Calling the module runs its forward pass on the random input.
output = coord_attention(inp)

print(output.shape)