"""Measure the parameter count and compute cost of a network using thop.

Run this file as a script: it builds the model, pushes one dummy 1x3x512x512
tensor through it under ``thop.profile`` and prints the raw and the scaled
(G / M) totals.

Credits:
    https://blog.csdn.net/qq_35407318/article/details/109359006
    http://t.csdn.cn/prmSk
"""
import torch
from thop import profile


if __name__ == "__main__":
    # Imported here so that merely importing this module does not pull in the
    # project model package.
    from taihuyuan_pv.compared_experiment.deeplabv3Plus.model.modeling import deeplabv3plus_resnet50

    # The device decides whether the network is profiled on the GPU or the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = deeplabv3plus_resnet50(num_classes=2, output_stride=16, pretrained_backbone=False).to(device)
    # Alternative model to profile instead:
    # model = baselinemodel_resnet50(modelname="baseline32_ms", num_classes=2, output_stride=16, pretrained_backbone=False).to(device)

    # Profile in inference mode. In training mode a batch of one sample makes
    # BatchNorm layers update their running statistics from the dummy input,
    # and any BatchNorm that sees a single value per channel (e.g. after a
    # global-pooling branch -- presumably present in this model, confirm
    # against the model definition) raises
    # "Expected more than 1 value per channel when training".
    model.eval()

    # Named dummy_input rather than `input` to avoid shadowing the builtin.
    dummy_input = torch.randn(1, 3, 512, 512).to(device)

    # No gradients are needed just to count operations.
    with torch.no_grad():
        flops, params = profile(model, inputs=(dummy_input,))

    # NOTE(review): thop's documentation names the first return value "macs"
    # (multiply-accumulate operations); the labels below are kept unchanged
    # for output compatibility -- confirm which unit is expected downstream.
    print("FLOPS:", flops)
    print("params:", params)
    print("Total FLOPS: %.2fGflops" % (flops / 1e9))
    print("Total params: %.2fM" % (params / 1e6))

    # torch.save(model.state_dict(),"test1.pth")