#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : parts
@Author : qiqq
@create_time : 2023/5/30 14:18
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, cls_num=2, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        # Example with a 512-channel input: feats is B x 512 x 64 x 64, probs is B x num_cls x 64 x 64.
        batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
        probs = probs.view(batch_size, c, -1)               # B x num_cls x HW
        feats = feats.view(batch_size, feats.size(1), -1)   # B x C x HW
        feats = feats.permute(0, 2, 1)                      # B x HW x C
        probs = F.softmax(self.scale * probs, dim=2)        # normalise the scores into spatial probabilities: B x num_cls x HW
        ocr_context = torch.matmul(probs, feats) \
            .permute(0, 2, 1).unsqueeze(3)                  # (B x num_cls x HW) @ (B x HW x C) -> B x C x num_cls x 1, one feature vector per class
        pv_context = ocr_context[:, :, 1, :].unsqueeze(dim=-1)  # keep only the PV (foreground) class vector
        return pv_context                                   # B x C x 1 x 1


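# Illustrative shape check for SpatialGather_Module (a sketch only; the 8/512/64
# sizes are assumed example values, not anything the module requires):
#
#   gather = SpatialGather_Module(cls_num=2)
#   feats = torch.rand(8, 512, 64, 64)   # B x C x H x W features
#   probs = torch.rand(8, 2, 64, 64)     # B x num_cls x H x W coarse class scores
#   pv_vec = gather(feats, probs)        # -> torch.Size([8, 512, 1, 1])
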
# Commented-out earlier variant of SFAttention that computed the auxiliary
# prediction and the PV class vector internally, kept for reference:
#
# class SFAttention(nn.Module):
#     def __init__(self, feats_channels=256, numclass=2):
#         super(SFAttention, self).__init__()
#         self.project1 = nn.Sequential(
#             nn.Conv2d(256, feats_channels, 1, bias=False),
#             nn.BatchNorm2d(feats_channels),
#             nn.ReLU(inplace=True))
#         self.auxproject = nn.Conv2d(feats_channels, numclass, 1, bias=False)
#         self.getpvf = SpatialGather_Module()
#         self.normalizer = nn.Sigmoid()
#
#     def forward(self, x):
#         '''Idea:
#         First obtain a vector that represents the photovoltaic (PV) class,
#         then take the inner product of this vector with every position of the
#         original feature map and pass it through a sigmoid to get an attention map.
#         Multiply this attention map with the original features.
#         The PV vector is obtained following the OCR approach.
#         '''
#         inputs = self.project1(x)                # project to feats_channels
#         preds = self.auxproject(inputs)          # coarse prediction; it could also be produced elsewhere
#         pv_feature = self.getpvf(inputs, preds)  # PV class vector
#
#         # Similarity between every pixel and pv_feature, squashed into a 0-1 weight per position.
#         relations = self.normalizer((pv_feature * inputs).sum(dim=1, keepdim=True))
#
#         refined_feats = relations * inputs       # reweight every pixel of the feature map
#
#         return refined_feats, preds


# model = SFAttention()
# inpoo = torch.rand(8, 256, 64, 64)
# # pre = torch.rand(8, 2, 64, 64)
# out = model(inpoo)
# print(out.shape)


class SFAttention(nn.Module):
    def __init__(self, feats_channels=256, numclass=2):
        super(SFAttention, self).__init__()
        self.project1 = nn.Sequential(
            nn.Conv2d(256, feats_channels, 1, bias=False),
            nn.BatchNorm2d(feats_channels),
            nn.ReLU(inplace=True))

        self.normalizer = nn.Sigmoid()

    def forward(self, x, pv_feature):
        '''Idea:
        First obtain a vector that represents the photovoltaic (PV) class,
        then take the inner product of this vector with every position of the
        feature map and pass it through a sigmoid to get an attention map.

        Multiply this attention map with the original features.
        The PV vector is obtained following the OCR approach
        (see SpatialGather_Module).
        '''
        inputs = self.project1(x)  # project to feats_channels (default 256)

        # Similarity between every pixel and pv_feature, squashed into a 0-1 weight per position.
        relations = self.normalizer((pv_feature * inputs).sum(dim=1, keepdim=True))

        refined_feats = relations * inputs  # reweight every pixel of the feature map

        return refined_feats
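
if __name__ == "__main__":
    # Minimal smoke test wiring the two modules together. This is an illustrative
    # sketch: the 1x1-conv aux head and the 8/256/64 tensor sizes below are
    # assumptions for the example, not part of the original training pipeline.
    feats = torch.rand(8, 256, 64, 64)           # projected backbone features, B x C x H x W
    aux_head = nn.Conv2d(256, 2, 1, bias=False)  # hypothetical coarse classifier head
    probs = aux_head(feats)                      # B x num_cls x H x W class scores

    gather = SpatialGather_Module(cls_num=2)
    pv_feature = gather(feats, probs)            # B x 256 x 1 x 1 PV class vector

    attention = SFAttention(feats_channels=256, numclass=2)
    refined = attention(feats, pv_feature)       # B x 256 x H x W re-weighted features
    print(refined.shape)                         # expected: torch.Size([8, 256, 64, 64])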