#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Smoke test for SwinUnet: build the model, load its weights, run one forward pass.

@project:
@File    : modeltest
@Author  : qiqq
@create_time : 2023/6/6 14:50

Usage::

    python modeltest.py [--num_classes N] [--img_size S] [--cfg FILE] [--opts KEY VALUE ...]

The script prints the type and shape of the network output for a random
batch of two 3-channel ``img_size x img_size`` images.
"""
import os
import argparse

import torch

from taihuyuan_pv.compared_experiment.mySwinUnet.model.vision_transformer import SwinUnet
from taihuyuan_pv.compared_experiment.mySwinUnet.configs.config import get_config


def build_parser():
    """Return the command-line parser used by :func:`main`.

    Returns:
        argparse.ArgumentParser with ``--num_classes``, ``--img_size``,
        ``--cfg`` and ``--opts`` options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int,
                        default=2, help='output channel of network')
    parser.add_argument('--img_size', type=int,
                        default=512, help='input patch size of network input')
    parser.add_argument('--cfg', type=str,
                        default="/home/qiqq/q3dl/code/rooftoprecognition/pv_recognition/compared_experiment/mySwinUnet/configs/swin_tiny_patch4_window7_224_lite.yaml",
                        metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
    return parser


def main(argv=None):
    """Build SwinUnet from the CLI/config, load weights and run a dummy forward pass.

    Args:
        argv: Optional list of command-line arguments; ``None`` means
            ``sys.argv[1:]`` (standard argparse behaviour).

    Returns:
        The network output for the random input batch.
    """
    args = build_parser().parse_args(argv)
    # NOTE(review): get_config receives the whole argparse namespace; confirm
    # it only reads attributes defined above (cfg / opts), otherwise it may
    # raise AttributeError.
    config = get_config(args)

    net = SwinUnet(config, img_size=args.img_size, num_classes=args.num_classes)
    # NOTE(review): presumably loads pretrained weights referenced by the
    # config; behaviour of load_from is defined in the project, not here.
    net.load_from(config)
    print("权重加载")  # "weights loaded"

    # Inference-only check: disable dropout/train-time behaviour and autograd
    # so the forward pass does not keep activation graphs alive for nothing.
    net.eval()
    # Size the dummy input from --img_size instead of a hard-coded 512 so the
    # test stays consistent with the size the model was built for.
    inpp = torch.randn(2, 3, args.img_size, args.img_size)
    with torch.no_grad():
        out = net(inpp)
    print(type(out))
    print(out.shape)
    return out


if __name__ == "__main__":
    main()