#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
|
||
|
@project:
|
||
|
@File : testfps
|
||
|
@Author : qiqq
|
||
|
@create_time : 2023/11/6 16:40
|
||
|
"""
|
||
|
|
||
|
import os
|
||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||
|
import argparse
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
|
||
|
from taihuyuan_pv.compared_experiment.mySwinUnet.model.vision_transformer import SwinUnet
|
||
|
from taihuyuan_pv.compared_experiment.mySwinUnet.configs.config import get_config
|
||
|
|
||
|
|
||
|
def _str2bool(value):
    """Convert a command-line string such as "true"/"false" into a bool.

    With ``type=str`` (or ``type=bool``) argparse treats every non-empty
    string as truthy, so ``--useDice False`` would silently enable Dice.

    Args:
        value: The raw option value (``str``) or an already-parsed ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args() -> argparse.Namespace:
    """Parse the command-line options used by this FPS benchmark.

    Most options appear to be shared with the training script and are not
    read in this file: only ``img_size`` and ``num_classes`` are used directly.
    The whole namespace is also passed to ``get_config``, which presumably
    reads ``cfg`` and ``opts`` (confirm against ``configs/config.py``).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=200, help='Number of epochs')
    parser.add_argument('--train_batch-size', '-tb', dest='train_batch_size', metavar='TB', type=int, default=8,
                        help='Train_Batch size')
    parser.add_argument('--val_batch-size', '-vb', dest='val_batch_size', metavar='VB', type=int, default=1,
                        help='Val_Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-3,
                        help='Learning rate', dest='lr')
    # Optional pretrained-weights file (whether to start from pretraining).
    parser.add_argument('--load', '-f', type=str,
                        default="/home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth",
                        help='Load model from a .pth file')
    # The old help text claimed a default of 100; the actual default is 255.
    parser.add_argument('--ignore_index', '-i', type=int, dest='ignore_index', default=255,
                        help='ignore index, default 255')
    # Original input size (help text: "原始输入尺寸").
    # NOTE(review): with action='store_true', passing --origin_shape sets it to
    # True instead of a shape; only the (512, 512) default is meaningful.
    parser.add_argument('--origin_shape', action='store_true', default=(512, 512), help='原始输入尺寸')
    parser.add_argument('--resume', '-r', type=str, default="", help='is use Resume')
    # Whether to use Dice loss during training. Parsed as a real boolean so
    # that "--useDice False" is False (type=str made any value truthy).
    parser.add_argument('--useDice', '-ud', type=_str2bool, default=False, help='训练的时候是否使用dice')
    # Which evaluation metrics to compute. The help text says
    # IouMiouP = acc_global, acc, iu, precision, recall, f1, miou, and that
    # dice should only be used for 2-class problems or it will fail.
    # nargs='+' keeps command-line values a list, matching the list default.
    parser.add_argument('--valIndex', '-vI', type=str, nargs='+', default=["Valloss", "IouMiouP"],
                        help='评价指标要使用哪些,注意IouMiouP= acc_global, acc, iu,precion,recall,f1,miou,并且建议旨在2分类的时候用dice否则会出错')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
    parser.add_argument('--num_classes', type=int,
                        default=2, help='output channel of network')
    parser.add_argument('--img_size', type=int,
                        default=512, help='input patch size of network input')
    parser.add_argument('--cfg', type=str,
                        default="/home/qiqq/q3dl/code/rooftoprecognition/pv_recognition/compared_experiment/mySwinUnet/configs/swin_tiny_patch4_window7_224_lite.yaml",
                        metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    return parser.parse_args()


# ---------------------------------------------------------------------------
# Benchmark: single-image GPU inference latency / FPS of a trained SwinUnet.
# ---------------------------------------------------------------------------
args = get_args()
config = get_config(args)
model = SwinUnet(config, img_size=args.img_size, num_classes=args.num_classes)

device = torch.device("cuda")

model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/compared_experiment/mySwinUnet/train_val_test/checkpoints/2023-08-09-00.08/best.pth"
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# Loading onto the CPU first avoids keeping a second copy of the weights on
# the GPU; the model is moved to `device` after the state dict is applied.
model_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(model_dict['net'])  # weights are stored under the 'net' key
print("权重加载")  # "weights loaded"
model.eval()
model.to(device)

# Batch of one 3-channel image whose spatial size follows --img_size (the size
# the model was built with); the default 512 reproduces a 1x3x512x512 input.
dummy_input = torch.randn(1, 3, args.img_size, args.img_size, dtype=torch.float, device=device)

starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
repetitions = 300
timings = np.zeros(repetitions)  # per-iteration latency in milliseconds

with torch.no_grad():
    # GPU warm-up. Runs under no_grad too, so no autograd graph is built and
    # warm-up memory use matches the timed runs.
    for _ in range(10):
        _ = model(dummy_input)
    # Drain all warm-up work before the timed loop starts.
    torch.cuda.synchronize()

    # Measure performance.
    for rep in range(repetitions):
        starter.record()
        _ = model(dummy_input)
        ender.record()
        # Wait for the GPU: elapsed_time needs both events to have completed.
        torch.cuda.synchronize()
        timings[rep] = starter.elapsed_time(ender)

mean_syn = np.mean(timings)
std_syn = np.std(timings)
mean_fps = 1000. / mean_syn  # batch size is 1, so this is images per second
print(' * Mean@1 {mean_syn:.3f}ms Std@5 {std_syn:.3f}ms FPS@1 {mean_fps:.2f}'.format(mean_syn=mean_syn, std_syn=std_syn, mean_fps=mean_fps))
print(mean_syn)
# Example output from an earlier run (hardware not recorded):
# * Mean@1 13.584ms Std@5 1.811ms FPS@1 73.61