#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File   : parts
@Author : qiqq
@create_time : 2023/4/14 22:10
"""
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F


class SpatialGather_Module(nn.Module):
    """Generate per-class foreground descriptors from a feature map.

    For every class channel of ``probs`` a softmax over all spatial positions
    is computed and used as weights to average ``feats``, giving one feature
    vector per class (original note: for the photovoltaic task this yields a
    length-n vector that represents the photovoltaic class).
    """

    def __init__(self, cls_num=2, scale=1):
        """
        Args:
            cls_num: Number of classes. Only stored on the instance; the
                forward pass takes the class count from ``probs`` itself.
            scale: Factor applied to ``probs`` before the spatial softmax.
        """
        super().__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
        """Aggregate ``feats`` into one vector per class.

        Original example: for a 64*64*512 input, ``feats`` is 512*64*64 and
        ``probs`` is 2*64*64 (per sample).

        Args:
            feats: Feature map of shape (batch, channels, H, W).
            probs: Class score map of shape (batch, num_class, H, W); it must
                have the same number of spatial positions as ``feats``.

        Returns:
            Tensor of shape (batch, channels, num_class, 1).
        """
        # flatten() (unlike view()) also accepts inputs whose spatial strides
        # cannot be merged, e.g. cropped or transposed feature maps.
        probs = probs.flatten(2)                  # (batch, num_class, H*W)
        feats = feats.flatten(2).transpose(1, 2)  # (batch, H*W, channels)

        # Normalise every class map over the spatial positions.
        probs = F.softmax(self.scale * probs, dim=2)

        # (batch, num_class, channels) -> (batch, channels, num_class, 1)
        ocr_context = torch.matmul(probs, feats).transpose(1, 2).unsqueeze(3)
        return ocr_context


class ForegroundRelation(nn.Module):
    """Re-weight a feature map by its similarity to a foreground descriptor.

    The foreground descriptor and the feature map are projected to
    ``out_channels`` channels; their channel-wise product, summed over the
    channel axis and passed through a sigmoid, gives a single-channel
    relation map that scales a re-encoded copy of the feature map.
    """

    def __init__(self, in_channels, out_channels, foreground_channels=2048):
        """
        Args:
            in_channels: Channel count of the ``feature`` input.
            out_channels: Channel count of the refined output.
            foreground_channels: Channel count of the ``foreground_feature``
                input. Defaults to 2048, the value that used to be
                hard-coded, so existing callers and checkpoints are
                unaffected.
        """
        super().__init__()

        self.foreground_encoder = nn.Sequential(
            nn.Conv2d(foreground_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 1),
        )

        self.content_encoder = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

        self.feature_reencoders = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

        self.normalizer = nn.Sigmoid()

    def forward(self, foreground_feature: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        """Refine ``feature`` with the foreground relation map.

        Original note: the scene feature is a globally pooled 1*1*c scene
        vector and may be replaced by a class descriptor.

        Args:
            foreground_feature: Tensor with ``foreground_channels`` channels.
                After encoding it is multiplied element-wise with the encoded
                ``feature``, so its spatial size must broadcast against that
                of ``feature`` (presumably (batch, C, 1, 1) -- confirm with
                the caller).
            feature: Tensor of shape (batch, in_channels, H, W).

        Returns:
            Tensor of shape (batch, out_channels, H, W).
        """
        content_feats = self.content_encoder(feature)
        foreground_feat = self.foreground_encoder(foreground_feature)

        # Channel-wise dot product -> one relation score per position in (0, 1).
        similarity = (foreground_feat * content_feats).sum(dim=1, keepdim=True)
        relations = self.normalizer(similarity)

        p_feats = self.feature_reencoders(feature)
        refined_feats = relations * p_feats
        return refined_feats


# Leftover smoke-test scaffolding; kept at module level so the names the
# module has always exposed stay available.
feats = torch.rand(1, 256, 64, 64)
probs = torch.rand(1, 2, 64, 64)
forground_feats = SpatialGather_Module()
# out=