#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : utils
@Author : qiqq
@create_time : 2023/1/16 9:41
"""

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict


class _SimpleSegmentationModel(nn.Module):
    """Generic segmentation model: a feature-extracting backbone followed by a
    dense classifier head, with the logits bilinearly upsampled back to the
    input resolution."""

    def __init__(self, backbone, classifier):
        super(_SimpleSegmentationModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        input_shape = x.shape[-2:]  # spatial size (H, W) of the input
        features = self.backbone(x)
        x = self.classifier(features)
        # restore the original spatial resolution for per-pixel predictions
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        return x


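# --- Usage sketch (illustrative only, not part of the original project) -----
# A minimal example of how _SimpleSegmentationModel stitches a backbone and a
# classifier head together. The toy backbone/head below are hypothetical
# stand-ins; any backbone returning a feature map and any head producing
# per-class logits would be wired up the same way.
def _demo_simple_segmentation_model(num_classes=21):
    # toy backbone: downsamples by 4 and outputs 64 channels
    backbone = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(inplace=True),
    )
    # toy head: 1x1 convolution producing per-class logits
    classifier = nn.Conv2d(64, num_classes, kernel_size=1)
    model = _SimpleSegmentationModel(backbone, classifier)
    logits = model(torch.rand(1, 3, 224, 224))
    return logits.shape  # torch.Size([1, num_classes, 224, 224])

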
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict whose keys are the names
            of the modules for which the activations will be returned, and
            whose values are the names under which those activations are
            returned (which the user can specify).

    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving them the names `feat1` and `feat2`
        >>> new_m = IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    def __init__(self, model, return_layers, hrnet_flag=False):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")

        self.hrnet_flag = hrnet_flag

        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}  # work on a copy
        layers = OrderedDict()
        # keep child modules in registration order, up to the last requested layer
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.named_children():
            if self.hrnet_flag and name.startswith('transition'):  # HRNet: transition modules need special handling
                if name == 'transition1':  # transition1 splits the single input into two parallel streams
                    x = [trans(x) for trans in module]
                else:  # every later transition just adds one extra stream
                    x.append(module(x[-1]))
            else:  # other backbones (e.g. resnet, mobilenet) run their modules in series
                x = module(x)

            if name in self.return_layers:
                out_name = self.return_layers[name]
                if name == 'stage4' and self.hrnet_flag:  # HRNetV2: upsample all output streams and concatenate them
                    output_h, output_w = x[0].size(2), x[0].size(3)  # size of the highest-resolution stream
                    x1 = F.interpolate(x[1], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x2 = F.interpolate(x[2], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x3 = F.interpolate(x[3], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x = torch.cat([x[0], x1, x2, x3], dim=1)
                    out[out_name] = x
                else:
                    out[out_name] = x
        return out


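# --- Usage sketch (illustrative only, not part of the original project) -----
# Shows how IntermediateLayerGetter and _SimpleSegmentationModel compose into a
# small segmentation network. The toy backbone and head below are hypothetical
# stand-ins; in practice the backbone would be e.g. a ResNet whose 'layer4'
# activation is exposed under the key 'out' and consumed by a DeepLab-style head.
def _demo_intermediate_layer_getter(num_classes=21):
    backbone = nn.Sequential(OrderedDict([
        ('stem',   nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)),
        ('layer1', nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)),
        ('layer2', nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)),
    ]))
    # keep child modules up to 'layer2' and expose its activation as 'out'
    feature_extractor = IntermediateLayerGetter(backbone, return_layers={'layer2': 'out'})

    class _DictHead(nn.Module):
        """Toy head: reads the 'out' feature map and predicts per-class logits."""
        def __init__(self, in_channels, n_classes):
            super(_DictHead, self).__init__()
            self.conv = nn.Conv2d(in_channels, n_classes, kernel_size=1)

        def forward(self, features):
            return self.conv(features['out'])

    model = _SimpleSegmentationModel(feature_extractor, _DictHead(64, num_classes))
    logits = model(torch.rand(1, 3, 224, 224))
    return logits.shape  # torch.Size([1, num_classes, 224, 224])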