#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : utils
@Author : qiqq
@create_time : 2023/1/16 9:41
"""

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict


class ChannelAttention2(nn.Module):
    """Channel-attention variant: global average pooling, a 1x1 conv + BN, then sigmoid."""

    def __init__(self, in_planes):
        super(ChannelAttention2, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cb = nn.Sequential(
            nn.Conv2d(in_planes, in_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_planes)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.avg_pool(x)
        out = self.cb(out)
        out = self.sigmoid(out)
        return out


class ChannelAttention(nn.Module):
    """CBAM-style channel attention: a shared bottleneck MLP over average- and max-pooled descriptors."""

    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Use 1x1 convolutions in place of fully connected layers
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """CBAM-style spatial attention: a single conv over channel-wise average- and max-pooled maps."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class cbam_block(nn.Module):
    """CBAM block: channel attention followed by spatial attention."""

    def __init__(self, channel, ratio=8, kernel_size=7):
        super(cbam_block, self).__init__()
        self.channelattention = ChannelAttention(channel, ratio=ratio)
        # self.channelattention = ChannelAttention2(channel)
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = x * self.channelattention(x)
        x = x * self.spatialattention(x)
        return x


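# A minimal usage sketch (not part of the original training code; shapes are
# illustrative). CBAM re-weights a feature map without changing its shape:
# channel weights are B x C x 1 x 1, spatial weights are B x 1 x H x W.
def _cbam_usage_example():
    feats = torch.randn(2, 64, 32, 32)                            # hypothetical feature map
    channel_w = ChannelAttention(64, ratio=8)(feats)              # -> (2, 64, 1, 1)
    spatial_w = SpatialAttention(kernel_size=7)(feats)            # -> (2, 1, 32, 32)
    out = cbam_block(channel=64, ratio=8, kernel_size=7)(feats)   # -> (2, 64, 32, 32)
    return channel_w.shape, spatial_w.shape, out.shape

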
class CAM_Module(nn.Module):
    """Channel attention module (self-attention over channels, as in DANet)."""

    def __init__(self, in_dim=256):
        super(CAM_Module, self).__init__()
        self.chanel_in = in_dim

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X H X W)
        returns :
            out : attention value + input feature
            attention : B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        # Following the DANet implementation: subtract each energy from its row-wise max before the softmax
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out


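# A minimal usage sketch (illustrative only). CAM_Module builds a C x C channel
# affinity per sample; because gamma is initialised to zero, the module starts out
# as an identity mapping and learns how much attention to mix in during training.
def _cam_usage_example():
    cam = CAM_Module(in_dim=64)
    feats = torch.randn(2, 64, 32, 32)               # hypothetical feature map
    out = cam(feats)                                  # -> (2, 64, 32, 32)
    identical_at_init = torch.allclose(out, feats)    # True while gamma == 0
    return out.shape, identical_at_init

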
class cbamselfcam(nn.Module):
    """Spatial attention followed by the self-attention channel module (CAM_Module).

    Note: `ratio` is kept for interface compatibility with cbam_block but is unused here.
    """

    def __init__(self, channel, ratio=8, kernel_size=7):
        super(cbamselfcam, self).__init__()
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)
        self.cam = CAM_Module(channel)

    def forward(self, x):
        x = x * self.spatialattention(x)
        x = self.cam(x)
        return x


class _SimpleSegmentationModel(nn.Module):
    """Backbone + classifier wrapper; upsamples the main and auxiliary outputs to the input size."""

    def __init__(self, backbone, classifier):
        super(_SimpleSegmentationModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        input_shape = x.shape[-2:]
        features = self.backbone(x)

        x, aux = self.classifier(features)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        aux = F.interpolate(aux, size=input_shape, mode='bilinear', align_corners=False)
        return x, aux

        # Original DeepLabV3+ version:
        # x = self.classifier(features)
        # x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        # return x

        # Enhanced decoder version:
        # output_feature, aux0, aux1 = self.classifier(features)
        # output_feature = F.interpolate(output_feature, size=input_shape, mode='bilinear', align_corners=False)
        # aux0 = F.interpolate(aux0, size=input_shape, mode='bilinear', align_corners=False)
        # aux1 = F.interpolate(aux1, size=input_shape, mode='bilinear', align_corners=False)
        # return output_feature, aux0, aux1


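# A minimal smoke-test sketch with a hypothetical toy backbone and head (not the
# project's real DeepLab components), only to illustrate the (main, aux) output
# contract: both outputs are upsampled back to the input resolution.
def _segmentation_wrapper_example():
    backbone = nn.Conv2d(3, 16, 3, stride=4, padding=1)            # toy "backbone", 1/4 resolution
    head = nn.Conv2d(16, 2, 1)                                      # toy 2-class head
    model = _SimpleSegmentationModel(backbone, lambda f: (head(f), head(f)))
    main, aux = model(torch.randn(1, 3, 64, 64))
    return main.shape, aux.shape                                    # both (1, 2, 64, 64)

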
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    def __init__(self, model, return_layers, hrnet_flag=False):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")

        self.hrnet_flag = hrnet_flag

        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.named_children():
            if self.hrnet_flag and name.startswith('transition'):  # HRNet needs special handling of the transition modules
                if name == 'transition1':  # transition1 splits the input into two streams
                    x = [trans(x) for trans in module]
                else:  # every other transition just adds one extra stream
                    x.append(module(x[-1]))
            else:  # other backbones (e.g. resnet, mobilenet) are convolutions in series
                x = module(x)

            if name in self.return_layers:
                out_name = self.return_layers[name]
                if name == 'stage4' and self.hrnet_flag:  # in HRNetV2, upsample and concat all output streams
                    output_h, output_w = x[0].size(2), x[0].size(3)  # upsample to the size of the highest-resolution stream
                    x1 = F.interpolate(x[1], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x2 = F.interpolate(x[2], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x3 = F.interpolate(x[3], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x = torch.cat([x[0], x1, x2, x3], dim=1)
                    out[out_name] = x
                else:
                    out[out_name] = x
        return out

# ______________ roughly modified CBAM