19 lines
402 B
Python
19 lines
402 B
Python
|
|
||
|
from model.backbone.PIT import PoolingTransformer
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
if __name__ == '__main__':
    # Smoke test for PoolingTransformer: build a model with the settings
    # below, run one forward pass on random data and print the result's shape.

    # Single source of truth for the resolution, so the dummy input and the
    # model's `image_size` cannot drift apart.
    image_size = 224

    # Random dummy batch. Assumes the model expects (batch, channels, height,
    # width) with height == width == image_size -- TODO confirm against
    # PoolingTransformer.forward.
    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.randn(1, 3, image_size, image_size)

    model = PoolingTransformer(
        image_size=image_size,
        patch_size=14,
        stride=7,
        base_dims=[64, 64, 64],
        depth=[3, 6, 4],
        heads=[4, 8, 16],
        mlp_ratio=4,
    )
    # Switch off training-only behaviour (e.g. dropout, if the model has any);
    # this is a shape check, not a training step.
    model.eval()

    # No gradients are needed here. Disabling autograd avoids building the
    # backward graph and keeping activations alive for it.
    with torch.no_grad():
        output = model(x)

    # Presumably a tensor (the original also reads `.shape`).
    print(output.shape)