ai-station-code/wudingpv/taihuyuan_pv/compared_experiment/mySwinUnet/tools/testfps.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : testfps
@Author : qiqq
@create_time : 2023/11/6 16:40
"""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import argparse
import torch
import numpy as np
from taihuyuan_pv.compared_experiment.mySwinUnet.model.vision_transformer import SwinUnet
from taihuyuan_pv.compared_experiment.mySwinUnet.configs.config import get_config
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=200, help='Number of epochs')
    parser.add_argument('--train_batch-size', '-tb', dest='train_batch_size', metavar='TB', type=int, default=8,
                        help='Train batch size')
    parser.add_argument('--val_batch-size', '-vb', dest='val_batch_size', metavar='VB', type=int, default=1,
                        help='Validation batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-3,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str,
                        default="/home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth",
                        help='Load model from a .pth file')  # pretrained weights, if any
    parser.add_argument('--ignore_index', '-i', type=int, dest='ignore_index', default=255,
                        help='Ignore index (default 255)')
    parser.add_argument('--origin_shape', action='store_true', default=(512, 512), help='Original input size')
    parser.add_argument('--resume', '-r', type=str, default="", help='Resume from a checkpoint')
    parser.add_argument('--useDice', '-ud', type=str, default=False, help='Whether to use Dice loss during training')
    parser.add_argument('--valIndex', '-vI', type=str, default=["Valloss", "IouMiouP"],
                        help='Which evaluation metrics to use. Note: IouMiouP = acc_global, acc, iu, precision, '
                             'recall, f1, miou. Dice is only recommended for binary classification, '
                             'otherwise it may fail.')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
    parser.add_argument('--num_classes', type=int,
                        default=2, help='output channel of network')
    parser.add_argument('--img_size', type=int,
                        default=512, help='input patch size of network input')
    parser.add_argument('--cfg', type=str,
                        default="/home/qiqq/q3dl/code/rooftoprecognition/pv_recognition/compared_experiment/mySwinUnet/configs/swin_tiny_patch4_window7_224_lite.yaml",
                        metavar="FILE", help='path to config file')
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs.",
        default=None,
        nargs='+',
    )
    return parser.parse_args()
args = get_args()
config = get_config(args)
model = SwinUnet(config, img_size=args.img_size, num_classes=args.num_classes)
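# Optional (not in the original script): report the parameter count, a common
# companion metric when comparing model efficiency alongside FPS.
n_params = sum(p.numel() for p in model.parameters())
print("Params: {:.2f}M".format(n_params / 1e6))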
model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/compared_experiment/mySwinUnet/train_val_test/checkpoints/2023-08-09-00.08/best.pth"
model_dict = torch.load(model_path)
model.load_state_dict(model_dict['net'])
print("权重加载")
model.eval()
device = torch.device("cuda")
model.to(device)
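
# Optional (not in the original script): for fixed input shapes, cuDNN autotuning
# can reduce run-to-run variance in the measured latency. Left commented out so the
# benchmark matches the original configuration.
# torch.backends.cudnn.benchmark = True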
dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float).to(device)
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
repetitions = 300
timings = np.zeros((repetitions, 1))
# GPU warm-up so that CUDA context creation and kernel caching do not pollute the timed runs
with torch.no_grad():
    for _ in range(10):
        _ = model(dummy_input)
# Measure performance with CUDA events (timings are in milliseconds)
with torch.no_grad():
    for rep in range(repetitions):
        starter.record()
        _ = model(dummy_input)
        ender.record()
        # Wait for the GPU to finish before reading the event timers
        torch.cuda.synchronize()
        curr_time = starter.elapsed_time(ender)
        timings[rep] = curr_time
mean_syn = np.sum(timings) / repetitions
std_syn = np.std(timings)
mean_fps = 1000. / mean_syn
print(' * Mean@1 {mean_syn:.3f}ms Std@5 {std_syn:.3f}ms FPS@1 {mean_fps:.2f}'.format(mean_syn=mean_syn, std_syn=std_syn, mean_fps=mean_fps))
print(mean_syn)
# * Mean@1 13.584ms Std@5 1.811ms FPS@1 73.61
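
# Optional extension (not in the original script): the mean latency can be skewed
# by occasional scheduling spikes, so the median and 95th-percentile latency are
# also worth reporting alongside it.
median_syn = np.median(timings)
p95_syn = np.percentile(timings, 95)
print(' * Median {median_syn:.3f}ms p95 {p95_syn:.3f}ms'.format(median_syn=median_syn, p95_syn=p95_syn))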