# Viewer metadata carried over from the source page: 12 lines, 303 B, Python
|
from attention.CoordAttention import CoordAtt
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
# Smoke test for CoordAtt: push one random tensor through the module and
# print the shape of whatever comes back.

# Random input; torch.rand samples uniformly from [0, 1).
# Presumably laid out as (batch, channels, height, width) -- confirm against
# the CoordAtt implementation.
inp = torch.rand(2, 96, 56, 56)

# First two positional arguments handed to CoordAtt.
inp_dim, oup_dim = 96, 96

# Forwarded to CoordAtt through its `reduction` keyword.
reduction = 32

# Build the module, run a single forward call, and report the result's shape.
coord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)
output = coord_attention(inp)
print(output.shape)
|