ai-station-code/wudingpv/taihuyuan_pv/compared_experiment/mySwinUnet/model/modeltest.py

45 lines
1.3 KiB
Python
Raw Normal View History

2025-05-06 11:18:48 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : modeltest
@Author : qiqq
@create_time : 2023/6/6 14:50
"""
import os
import torch
import argparse
from taihuyuan_pv.compared_experiment.mySwinUnet.model.vision_transformer import SwinUnet
from taihuyuan_pv.compared_experiment.mySwinUnet.configs.config import get_config
# Default model config, resolved relative to this file (<mySwinUnet>/configs/)
# instead of an absolute path on one developer's machine.
# NOTE(review): assumes swin_tiny_patch4_window7_224_lite.yaml sits next to
# configs/config.py in this repo — confirm, or pass --cfg explicitly.
DEFAULT_CFG = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    os.pardir,
    "configs",
    "swin_tiny_patch4_window7_224_lite.yaml",
)

# Kept at module level (as before) so existing references to `parser` still work.
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int,
                    default=2, help='output channel of network')
parser.add_argument('--img_size', type=int,
                    default=512, help='input patch size of network input')
parser.add_argument('--cfg', type=str, default=DEFAULT_CFG, metavar="FILE", help='path to config file', )
parser.add_argument(
    "--opts",
    help="Modify config options by adding 'KEY VALUE' pairs. ",
    default=None,
    nargs='+',
)


def main(argv=None):
    """Build SwinUnet, load its weights and run one forward pass on random data.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Whatever ``net`` returns for the random batch (its type and ``.shape``
        are printed).
    """
    args = parser.parse_args(argv)
    config = get_config(args)
    net = SwinUnet(config, img_size=args.img_size, num_classes=args.num_classes)
    net.load_from(config)
    print("权重加载")  # "weights loaded"

    # Size the dummy batch from --img_size (previously hard-coded to 512) so it
    # always matches the resolution the network was built for. Batch of 2,
    # 3 channels, as in the original script.
    inpp = torch.randn(2, 3, args.img_size, args.img_size)

    # Only the output type/shape is inspected: use inference mode and skip
    # building the autograd graph to save memory.
    net.eval()
    with torch.no_grad():
        out = net(inpp)
    print(type(out))
    print(out.shape)
    return out


if __name__ == "__main__":
    main()