ai-station-code/wudingpv/taihuyuan_pv/calculate_parameter_flops.py

34 lines
1.1 KiB
Python
Raw Normal View History

2025-05-06 11:18:48 +08:00
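'''
计算网络（model）的参数量和计算量。
Compute the parameter count and computational cost of a network (model).
注意：thop.profile 返回的计算量实际是 MACs（乘加次数），FLOPs 约为 MACs 的 2 倍。
Note: thop.profile actually reports MACs (multiply-accumulates); FLOPs are roughly 2 x MACs.
'''
'''参考 / References:
https://blog.csdn.net/qq_35407318/article/details/109359006
http://t.csdn.cn/prmSk
'''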
import torch
from thop import profile
# Measure the parameter count and compute cost of a network.


def count_macs_and_params(model, input_size=(1, 3, 512, 512), device=None):
    """Run one profiled forward pass of *model* and return ``(macs, params)``.

    ``thop.profile`` counts multiply-accumulate operations (MACs), not FLOPs;
    one MAC corresponds to roughly two FLOPs (one multiply plus one add).

    Args:
        model: the ``torch.nn.Module`` to profile; it should already live on
            *device*.
        input_size: shape of the random dummy input tensor. The default
            matches the 512x512 RGB single-image input this script profiled.
        device: device on which to create the dummy input. Defaults to the
            device of the model's first parameter, or CPU if it has none.

    Returns:
        The ``(macs, params)`` pair exactly as returned by ``thop.profile``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    # Random values are fine: only the tensor shape affects the op/param count.
    dummy_input = torch.randn(*input_size, device=device)
    macs, params = profile(model, inputs=(dummy_input,))
    return macs, params


def format_report(macs, params):
    """Return human-readable summary lines for a ``(macs, params)`` measurement.

    FLOPs are reported as an approximation of ``2 * macs``.
    """
    return [
        f"MACs: {macs}",
        f"params: {params}",
        f"Total MACs: {macs / 1e9:.2f} G",
        f"Total FLOPs (~2 x MACs): {2 * macs / 1e9:.2f} G",
        f"Total params: {params / 1e6:.2f} M",
    ]


if __name__ == "__main__":
    # Imported inside the guard so that importing this module does not
    # require the model package.
    from taihuyuan_pv.compared_experiment.deeplabv3Plus.model.modeling import (
        deeplabv3plus_resnet50,
    )

    # Pick the device the network (and the dummy input) runs on.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = deeplabv3plus_resnet50(
        num_classes=2, output_stride=16, pretrained_backbone=False
    ).to(device)
    # To profile the baseline model instead, use:
    # model = baselinemodel_resnet50(modelname="baseline32_ms", num_classes=2,
    #                                output_stride=16, pretrained_backbone=False).to(device)
    macs, params = count_macs_and_params(model, input_size=(1, 3, 512, 512), device=device)
    for line in format_report(macs, params):
        print(line)