ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/parts.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : parts
@Author : qiqq
@create_time : 2023/5/30 14:18
"""
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, cls_num=2, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        # Example with an original 512x512 input: feats is B x C x 64 x 64, probs is B x num_class x 64 x 64.
        batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
        probs = probs.view(batch_size, c, -1)              # B x num_class x hw
        feats = feats.view(batch_size, feats.size(1), -1)  # B x C x hw
        feats = feats.permute(0, 2, 1)                     # B x hw x C
        probs = F.softmax(self.scale * probs, dim=2)       # B x num_class x hw: normalize the logits into spatial weights
        ocr_context = torch.matmul(probs, feats) \
            .permute(0, 2, 1).unsqueeze(3)                 # B x C x num_class x 1: one context vector per class
        pv_context = ocr_context[:, :, 1, :].unsqueeze(dim=-1)  # keep only the PV-class vector: B x C x 1 x 1
        return pv_context                                  # B x C x 1 x 1
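
# Note (added for clarity, not original code): the matmul above is a per-class,
# probability-weighted average of the pixel features, i.e.
#     ocr_context[b, :, k, 0] = sum_i probs[b, k, i] * feats[b, i, :]
# which is equivalent to torch.einsum("bki,bic->bck", probs, feats).
# Only the context vector of class index 1 (assumed to be the PV class) is returned.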

# Earlier, self-contained variant of SFAttention (kept commented out for reference):
# it produced the auxiliary prediction and the PV context vector internally instead
# of receiving pv_feature as an argument.
#
# class SFAttention(nn.Module):
#     def __init__(self, feats_channels=256, numclass=2):
#         super(SFAttention, self).__init__()
#         self.project1 = nn.Sequential(
#             nn.Conv2d(256, feats_channels, 1, bias=False),
#             nn.BatchNorm2d(feats_channels),
#             nn.ReLU(inplace=True))
#         self.auxproject = nn.Conv2d(feats_channels, numclass, 1, bias=False)
#         self.getpvf = SpatialGather_Module()
#         self.normalizer = nn.Sigmoid()
#
#     def forward(self, x):
#         '''Idea:
#         First build a vector that represents the photovoltaic (PV) class,
#         take its inner product with every position of the original feature map,
#         and pass the result through a sigmoid to get an attention map.
#         This attention map is then multiplied back onto the original features.
#         The PV vector itself is obtained following the OCR approach.
#         '''
#         inputs = self.project1(x)        # project to feats_channels (256)
#         preds = self.auxproject(inputs)  # the auxiliary prediction need not come from inputs; it could come from elsewhere
#         pv_feature = self.getpvf(inputs, preds)
#
#         # similarity between every pixel and pv_feature, squashed to a 0-1 weight
#         relations = self.normalizer((pv_feature * inputs).sum(dim=1, keepdim=True))
#
#         refined_feats = relations * inputs  # re-weight every pixel
#
#         return refined_feats, preds
#
# model = SFAttention()
# inpoo = torch.rand(8, 256, 64, 64)
# # pre = torch.rand(8, 2, 64, 64)
# out, preds = model(inpoo)
# print(out.shape)


class SFAttention(nn.Module):
    def __init__(self, feats_channels=256, numclass=2):
        super(SFAttention, self).__init__()
        self.project1 = nn.Sequential(
            nn.Conv2d(256, feats_channels, 1, bias=False),
            nn.BatchNorm2d(feats_channels),
            nn.ReLU(inplace=True))
        self.normalizer = nn.Sigmoid()

    def forward(self, x, pv_feature):
        '''Idea:
        First build a vector that represents the photovoltaic (PV) class,
        take its inner product with every position of the original feature map,
        and pass the result through a sigmoid to get an attention map.
        This attention map is then multiplied back onto the original features.
        The PV vector itself is obtained following the OCR approach.
        '''
        inputs = self.project1(x)  # project to feats_channels (256)
        # similarity between every pixel and pv_feature, squashed to a 0-1 weight
        relations = self.normalizer((pv_feature * inputs).sum(dim=1, keepdim=True))
        refined_feats = relations * inputs  # re-weight every pixel
        return refined_feats
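

if __name__ == "__main__":
    # Minimal smoke test, added for illustration only (not in the original file).
    # The sizes (batch 8, 256 channels, 64x64 spatial, 2 classes) are assumptions
    # matching the commented-out test above; it wires the two modules together the
    # way their forward signatures suggest.
    feats = torch.rand(8, 256, 64, 64)     # backbone / decoder features
    aux_logits = torch.rand(8, 2, 64, 64)  # coarse background-vs-PV prediction

    gather = SpatialGather_Module(cls_num=2)
    pv_vector = gather(feats, aux_logits)  # B x 256 x 1 x 1 PV context vector

    attention = SFAttention(feats_channels=256, numclass=2)
    refined = attention(feats, pv_vector)  # B x 256 x 64 x 64 re-weighted features

    print(pv_vector.shape, refined.shape)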