11 lines
294 B
Python
11 lines
294 B
Python
|
from attention.MobileViTAttention import MobileViTAttention
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
if __name__ == '__main__':
    # Smoke test: build MobileViTAttention with its default constructor
    # arguments, run one random batch through it, and print the output shape.
    # NOTE(review): the (1, 3, 49, 49) input assumes the module's default
    # arguments accept 3 input channels and a 49x49 spatial size — confirm
    # against attention/MobileViTAttention.py.
    m = MobileViTAttention()
    # Named `x` rather than `input` so the builtin input() is not shadowed.
    x = torch.randn(1, 3, 49, 49)
    # Only the output shape is inspected, so autograd tracking is unnecessary.
    with torch.no_grad():
        output = m(x)
    print(output.shape)  # output:(1,3,49,49)
|