|
from attention.MobileViTAttention import MobileViTAttention
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build MobileViTAttention with its default constructor
    # arguments, push one random 3-channel 49x49 tensor through it, and print
    # the output shape.
    model = MobileViTAttention()
    # Only the output shape is inspected, so run in inference mode and skip
    # autograd bookkeeping.
    model.eval()
    x = torch.randn(1, 3, 49, 49)  # renamed from `input` to avoid shadowing the builtin
    with torch.no_grad():
        output = model(x)
    # NOTE(review): the original script expected (1, 3, 49, 49), i.e. the
    # default config takes 3 input channels and preserves spatial size —
    # confirm against attention/MobileViTAttention.py.
    print(output.shape)
|