# (Scraped source-listing metadata: 34 lines, 1.1 KiB, Python.)
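'''
计算网络或者model的参数量和FLOPS(貌似是计算量)
Compute the parameter count and the computational cost of a network/model.
Note: thop's ``profile`` counts multiply-accumulate operations (MACs);
FLOPs are conventionally taken as about 2 x MACs.
'''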
'''Credit / references:
https://blog.csdn.net/qq_35407318/article/details/109359006
http://t.csdn.cn/prmSk
'''
import torch

from thop import profile

# Compute a network's parameter count and computational cost.

if __name__ == "__main__":
    # Project-local model factory; imported inside the guard so that merely
    # importing this file does not require the project package.
    from taihuyuan_pv.compared_experiment.deeplabv3Plus.model.modeling import (
        deeplabv3plus_resnet50,
    )

    # `device` selects whether the network runs on the GPU or the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = deeplabv3plus_resnet50(
        num_classes=2, output_stride=16, pretrained_backbone=False
    ).to(device)
    # model = baselinemodel_resnet50(modelname="baseline32_ms", num_classes=2, output_stride=16, pretrained_backbone=False).to(device)

    # Profile in inference mode explicitly so BatchNorm/Dropout behave as at
    # test time and no autograd graph is built.
    # NOTE(review): recent thop versions may also do this internally; doing
    # it here keeps the measurement independent of the installed version.
    model.eval()

    # Dummy input of shape (batch=1, channels=3, 512, 512), created directly
    # on the target device. Named to avoid shadowing the `input` builtin.
    dummy_input = torch.randn(1, 3, 512, 512, device=device)
    with torch.no_grad():
        # thop's `profile` returns (MACs, params): the first value counts
        # multiply-accumulate operations, not floating-point operations.
        macs, params = profile(model, inputs=(dummy_input,))

    # One MAC is one multiply plus one add, so FLOPs are conventionally ~2 x MACs.
    flops = 2 * macs

    print("MACs:", macs)
    print("FLOPs (~2 x MACs):", flops)
    print("params:", params)

    print("Total MACs: %.2fG" % (macs / 1e9))
    print("Total FLOPs: %.2fG" % (flops / 1e9))
    print("Total params: %.2fM" % (params / 1e6))

    # torch.save(model.state_dict(),"test1.pth")