#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Smoke test for SwinUnet: build the model, load weights, run one forward pass.

@project:
@File : modeltest
@Author : qiqq
@create_time : 2023/6/6 14:50

The script parses a few CLI options and builds a config with ``get_config``.
It then instantiates ``SwinUnet``, calls ``net.load_from(config)`` and pushes a
random batch through the network in eval / no-grad mode. Finally it prints
the type and shape of the output.
"""
import argparse
import os

import torch

from taihuyuan_pv.compared_experiment.mySwinUnet.model.vision_transformer import SwinUnet
from taihuyuan_pv.compared_experiment.mySwinUnet.configs.config import get_config

parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int,
                    default=2, help='output channel of network')
parser.add_argument('--img_size', type=int,
                    default=512, help='input patch size of network input')
# NOTE(review): machine-specific absolute default; it points into
# ``pv_recognition`` while the imports use ``taihuyuan_pv`` -- confirm it is
# still the intended config file.
parser.add_argument('--cfg', type=str, default="/home/qiqq/q3dl/code/rooftoprecognition/pv_recognition/compared_experiment/mySwinUnet/configs/swin_tiny_patch4_window7_224_lite.yaml", metavar="FILE", help='path to config file', )
parser.add_argument(
    "--opts",
    help="Modify config options by adding 'KEY VALUE' pairs. ",
    default=None,
    nargs='+',
)

args = parser.parse_args()

# Fail early with a readable CLI error instead of an opaque failure deep
# inside get_config when the (hard-coded) config path does not exist here.
if not os.path.isfile(args.cfg):
    parser.error(f"config file not found: {args.cfg}")

# get_config presumably reads args.cfg / args.opts -- confirm in configs/config.py.
config = get_config(args)

net = SwinUnet(config, img_size=args.img_size, num_classes=args.num_classes)
# NOTE(review): load_from is defined in vision_transformer (not visible here);
# presumably it loads a pretrained checkpoint referenced by the config -- confirm.
net.load_from(config)
print("权重加载")  # "weights loaded"

# Inference mode: makes dropout-style layers deterministic; no_grad below
# avoids building an autograd graph that a shape check never uses.
net.eval()

batch_size = 2
in_channels = 3
# Match the spatial size the model was constructed with (--img_size) rather
# than a hard-coded 512, so non-default --img_size values are actually tested.
inpp = torch.randn(batch_size, in_channels, args.img_size, args.img_size)
with torch.no_grad():
    out = net(inpp)

print(type(out))
print(out.shape)