# Scraped file-view metadata: 34 lines, 1.1 KiB, Python
"""Compute the parameter count and computational cost of a network/model.

Builds DeepLabV3+ with a ResNet-50 backbone, runs a single random
1x3x512x512 input through ``thop.profile`` and prints the totals.

Note: ``thop.profile`` counts multiply-accumulate operations (MACs), not
FLOPs. Under the common convention of one MAC = two FLOPs (one multiply plus
one add), the FLOP count is roughly ``2 * MACs``.

Adapted from:
    https://blog.csdn.net/qq_35407318/article/details/109359006
    http://t.csdn.cn/prmSk
"""

import torch
from thop import profile


if __name__ == "__main__":
    # Project-local import kept inside the guard so that merely importing this
    # module does not require the model package.
    from taihuyuan_pv.compared_experiment.deeplabv3Plus.model.modeling import (
        deeplabv3plus_resnet50,
    )

    # The device decides whether the network is profiled on the GPU or the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = deeplabv3plus_resnet50(
        num_classes=2, output_stride=16, pretrained_backbone=False
    ).to(device)
    # To profile the baseline model instead, swap in:
    # model = baselinemodel_resnet50(modelname="baseline32_ms", num_classes=2,
    #                                output_stride=16,
    #                                pretrained_backbone=False).to(device)

    # Profile in inference mode. In training mode a BatchNorm layer that sees a
    # 1x1 feature map with a batch of one (presumably DeepLabV3+'s ASPP
    # image-pooling branch -- verify against the model code) raises
    # "Expected more than 1 value per channel when training". eval() switches
    # BatchNorm to running statistics; the counted operations are unchanged.
    model.eval()

    # Named dummy_input rather than `input` to avoid shadowing the builtin.
    dummy_input = torch.randn(1, 3, 512, 512, device=device)
    with torch.no_grad():
        macs, params = profile(model, inputs=(dummy_input,))

    print("MACs:", macs)
    print("params:", params)

    print(
        "Total MACs: %.2fG (~%.2f GFLOPs at 2 FLOPs per MAC)"
        % (macs / 1e9, 2 * macs / 1e9)
    )
    print("Total params: %.2fM" % (params / 1e6))