99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
|
#!/usr/bin/env python
|
||
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
@project:
|
||
|
@File : modeling
|
||
|
@Author : qiqq
|
||
|
@create_time : 2023/1/17 9:00
|
||
|
"""
|
||
|
|
||
|
from taihuyuan_pv.compared_experiment.imdeeplab3p.model.utils import IntermediateLayerGetter
|
||
|
from taihuyuan_pv.compared_experiment.imdeeplab3p.model._deeplab import DeepLabHeadV3Plus, DeepLabV3,IMDeepLabHeadV3Plus
|
||
|
import torch
|
||
|
|
||
|
from taihuyuan_pv.compared_experiment.imdeeplab3p.model import resnet
|
||
|
|
||
|
|
||
|
#主要的
|
||
|
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
|
||
|
if output_stride == 8:
|
||
|
replace_stride_with_dilation = [False, True, True]
|
||
|
aspp_dilate = [12, 24, 36]
|
||
|
# aspp_dilate = [5, 11, 17]
|
||
|
# aspp_dilate = [3,6,12,24] #
|
||
|
# aspp_dilate = [12,18,24] #第一个消融#
|
||
|
else:
|
||
|
replace_stride_with_dilation = [False, False, True]
|
||
|
aspp_dilate = [6, 12, 18]
|
||
|
|
||
|
backbone = resnet.__dict__[backbone_name](
|
||
|
pretrained=pretrained_backbone,
|
||
|
replace_stride_with_dilation=replace_stride_with_dilation)
|
||
|
|
||
|
inplanes = 2048
|
||
|
low_level_planes = 256
|
||
|
|
||
|
if name == 'imdeeplabv3plus':
|
||
|
print("启用imdeeplabv3plus")
|
||
|
|
||
|
return_layers = {'layer4': 'out', 'layer1': 'low_level1','relu111': 'low_level2',}
|
||
|
classifier = IMDeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
|
||
|
elif name == 'deeplabv3plus':
|
||
|
return_layers = {'layer4': 'out', 'layer1': 'low_level'}
|
||
|
classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
|
||
|
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||
|
|
||
|
model = DeepLabV3(backbone, classifier)
|
||
|
return model
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
|
||
|
|
||
|
if backbone.startswith('resnet'):
|
||
|
model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride,
|
||
|
pretrained_backbone=pretrained_backbone)
|
||
|
|
||
|
else:
|
||
|
raise NotImplementedError
|
||
|
return model
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
# DeepLab v3+ model constructors


def imdeeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs an IM-DeepLabV3+ model with a ResNet-50 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    # Previously passed 'deeplabv3plus' here, which silently built the plain
    # DeepLabV3+ head and left the 'imdeeplabv3plus' branch unreachable.
    return _load_model('imdeeplabv3plus', 'resnet50', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
|
||
|
|
||
|
|
||
|
def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Return a standard DeepLabV3+ segmentation model on a ResNet-101 backbone.

    Args:
        num_classes (int): how many segmentation classes the head predicts.
        output_stride (int): DeepLab output stride (8 or 16).
        pretrained_backbone (bool): when True, request pretrained backbone weights.
    """
    return _load_model(
        'deeplabv3plus',
        'resnet101',
        num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Smoke test: build a 2-class model without pretrained weights, push a random
    # batch of two 3x512x512 images through it, and print the output shape.
    net = imdeeplabv3plus_resnet50(num_classes=2, output_stride=8, pretrained_backbone=False)
    dummy_batch = torch.rand(2, 3, 512, 512)
    prediction = net(dummy_batch)
    print(prediction.shape)
|