#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File   : modeling
@Author : qiqq
@create_time : 2023/1/17 9:00

Factory functions that assemble DeepLabV3+ style segmentation models
(standard and "IM" variants) on top of a ResNet backbone.
"""
from taihuyuan_pv.compared_experiment.imdeeplab3p.model.utils import IntermediateLayerGetter
from taihuyuan_pv.compared_experiment.imdeeplab3p.model._deeplab import DeepLabHeadV3Plus, DeepLabV3, IMDeepLabHeadV3Plus
import torch
from taihuyuan_pv.compared_experiment.imdeeplab3p.model import resnet


# Main model-assembly routine.
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Build a segmentation model from a ResNet backbone and a DeepLabV3+ head.

    Args:
        name (str): head type, either ``'imdeeplabv3plus'`` or ``'deeplabv3plus'``.
        backbone_name (str): name of a constructor in the local ``resnet`` module
            (looked up via ``resnet.__dict__``), e.g. ``'resnet50'``.
        num_classes (int): number of output classes.
        output_stride (int): 8 replaces the stride of layer3 and layer4 with
            dilation; any other value is treated as 16 (only layer4 dilated).
        pretrained_backbone (bool): forwarded as ``pretrained`` to the backbone
            constructor.

    Returns:
        DeepLabV3: the backbone (wrapped in ``IntermediateLayerGetter``) plus
        the selected classifier head.

    Raises:
        NotImplementedError: if ``name`` is not a supported head type.
    """
    supported_heads = ('imdeeplabv3plus', 'deeplabv3plus')
    # Validate before building the backbone so a typo fails fast instead of
    # surfacing later as an UnboundLocalError (after a possible weight download).
    if name not in supported_heads:
        raise NotImplementedError(
            f"Unsupported model head {name!r}; expected one of {supported_heads}")

    if output_stride == 8:
        replace_stride_with_dilation = [False, True, True]
        aspp_dilate = [12, 24, 36]
        # Other ASPP rates tried during ablations:
        # [5, 11, 17], [3, 6, 12, 24], [12, 18, 24] (first ablation)
    else:
        replace_stride_with_dilation = [False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)

    # Channel widths of the 'layer4' and 'layer1' outputs. Assumes a
    # bottleneck ResNet (e.g. resnet50/101) - TODO confirm for other backbones.
    inplanes = 2048
    low_level_planes = 256

    if name == 'imdeeplabv3plus':
        print("启用imdeeplabv3plus")
        # NOTE(review): 'relu111' must be a submodule name in the local resnet
        # implementation (IntermediateLayerGetter taps it) - verify in resnet.py.
        return_layers = {'layer4': 'out', 'layer1': 'low_level1', 'relu111': 'low_level2'}
        classifier = IMDeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    else:  # name == 'deeplabv3plus' (guaranteed by the check above)
        return_layers = {'layer4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
    model = DeepLabV3(backbone, classifier)
    return model


def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
    """Dispatch model construction based on the backbone family.

    Args:
        arch_type (str): head type passed through to ``_segm_resnet``.
        backbone (str): backbone constructor name; must start with ``'resnet'``.
        num_classes (int): number of output classes.
        output_stride (int): output stride for the DeepLab head.
        pretrained_backbone (bool): whether to load pretrained backbone weights.

    Returns:
        The assembled segmentation model.

    Raises:
        NotImplementedError: if ``backbone`` is not a ResNet variant.
    """
    if not backbone.startswith('resnet'):
        raise NotImplementedError(
            f"Unsupported backbone {backbone!r}; only ResNet variants are implemented")
    return _segm_resnet(arch_type, backbone, num_classes,
                        output_stride=output_stride,
                        pretrained_backbone=pretrained_backbone)


# Deeplab v3+
def imdeeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a segmentation model with a ResNet-50 backbone.

    NOTE(review): despite its name, this currently builds the standard
    ``'deeplabv3plus'`` head, not ``'imdeeplabv3plus'``. Confirm whether that
    is intentional (e.g. a baseline run) before relying on it.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3plus', 'resnet50', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)


def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3+ model with a ResNet-101 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3plus', 'resnet101', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)


if __name__ == '__main__':
    # Smoke test: build the model and run one forward pass on random input.
    model = imdeeplabv3plus_resnet50(num_classes=2, output_stride=8, pretrained_backbone=False)
    inputt = torch.rand(2, 3, 512, 512)
    # No gradients are needed just to check the output shape.
    with torch.no_grad():
        out = model(inputt)
    print(out.shape)