93 lines
2.6 KiB
Python
93 lines
2.6 KiB
Python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File   : parts
@Author : qiqq
@create_time : 2023/4/14 22:10
"""
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
import numpy as np
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class SpatialGather_Module(nn.Module):
    """Pool pixel features into one context vector per class.

    Each class score map is turned into a distribution over spatial positions
    (softmax across all pixels), and that distribution is used to take a
    weighted sum of the pixel features.  The original author's note describes
    this as generating the foreground feature, i.e. a length-n vector that
    represents the foreground (photovoltaic) class.

    Args:
        cls_num: Number of classes. Stored on the module only; ``forward``
            reads the class count from ``probs`` itself.
        scale: Factor multiplied into ``probs`` before the spatial softmax.
    """

    def __init__(self, cls_num=2, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        """Return the class-wise context vectors.

        Args:
            feats: Feature map of shape ``(B, C, H, W)``.
            probs: Per-class score map of shape ``(B, K, H, W)`` with the same
                number of spatial positions as ``feats``.

        Returns:
            Tensor of shape ``(B, C, K, 1)``: for every class, the
            softmax-weighted sum of the pixel features.
        """
        # Author's example: for a 64*64*512 input, feats is 512*64*64 and
        # probs is 2*64*64.
        batch_size, num_classes = probs.size(0), probs.size(1)

        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. permuted or channels_last tensors), whereas ``reshape``
        # copies only when it has to and is otherwise identical.
        probs = probs.reshape(batch_size, num_classes, -1)          # (B, K, N)
        feats = feats.reshape(batch_size, feats.size(1), -1)        # (B, C, N)
        feats = feats.permute(0, 2, 1)                              # (B, N, C)

        # Normalise every class map over the N spatial positions.
        probs = F.softmax(self.scale * probs, dim=2)

        # (B, K, N) @ (B, N, C) -> (B, K, C) -> (B, C, K) -> (B, C, K, 1)
        ocr_context = torch.matmul(probs, feats).permute(0, 2, 1).unsqueeze(3)

        return ocr_context  # shape: batch, num_channels, num_classes, 1
|
|
|
|
class ForegroundRelation(nn.Module):
    """Gate a feature map by its similarity to an encoded foreground feature.

    The foreground feature and the input feature map are each projected to
    ``out_channels`` with 1x1 convolutions.  Their channel-wise dot product,
    passed through a sigmoid, gives a single-channel relation map that
    rescales a second 1x1 projection of the input feature map.

    Args:
        in_channels: Channel count of ``feature``.
        out_channels: Channel count of the projections and of the output.
        foreground_channels: Channel count of ``foreground_feature``.
            Defaults to 2048, the value that used to be hard-coded.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 foreground_channels=2048,
                 ):
        super(ForegroundRelation, self).__init__()

        # Projects the foreground descriptor into the shared embedding space.
        self.foreground_encoder = nn.Sequential(
            nn.Conv2d(foreground_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 1),
        )
        # Projects the feature map into the same space for the dot product.
        self.content_encoder = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )
        # Independent projection of the feature map; this is what gets gated.
        self.feature_reencoders = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

        self.normalizer = nn.Sigmoid()

    def forward(self, foreground_feature, feature):
        """Return ``feature`` re-encoded and gated by the foreground relation.

        Original note: the scene feature is a globally pooled 1*1*c scene
        vector, which can be swapped for a class vector.

        Args:
            foreground_feature: Tensor with ``foreground_channels`` channels.
                After encoding it is multiplied element-wise with the encoded
                ``feature``, so it must broadcast against it -- presumably
                ``(B, foreground_channels, 1, 1)``; TODO confirm with caller.
            feature: Tensor of shape ``(B, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)``.
        """
        content_feats = self.content_encoder(feature)

        foreground_feat = self.foreground_encoder(foreground_feature)
        # Channel-wise dot product -> one relation score per position,
        # squashed into (0, 1) by the sigmoid.
        relations = self.normalizer(
            (foreground_feat * content_feats).sum(dim=1, keepdim=True)
        )

        p_feats = self.feature_reencoders(feature)

        # Broadcast the single-channel relation map over all output channels.
        refined_feats = relations * p_feats

        return refined_feats
|
|
|
|
|
|
|
|
|
|
|
|
# Ad-hoc smoke-test fixtures: a random feature map and a random two-channel
# score map sharing the same batch size and spatial size.
feats = torch.rand((1, 256, 64, 64))
probs = torch.rand((1, 2, 64, 64))

# Only the gather module is instantiated here; in the visible code it is not
# applied to the tensors above (the trailing "out=" line was left unfinished).
forground_feats = SpatialGather_Module()
# out=