ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/parts.py

115 lines
3.9 KiB
Python
Raw Normal View History

2025-05-06 11:18:48 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : parts
@Author : qiqq
@create_time : 2023/5/30 14:18
"""
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class SpatialGather_Module(nn.Module):
    """Aggregate a class context vector from soft class predictions (OCR-style).

    The class score map is turned into spatial attention weights by a softmax
    over all pixels. The features are then pooled with those weights, which
    yields one context vector per class. Only the vector of ``target_class`` is
    returned.

    Args:
        cls_num: Number of classes expected in ``probs``. It is stored on the
            module but is not used by ``forward``.
        scale: Multiplier applied to ``probs`` before the spatial softmax. It
            acts as an inverse temperature.
        target_class: Channel of ``probs`` whose context vector is returned.
            The default of 1 matches the previously hard-coded index.
    """

    def __init__(self, cls_num=2, scale=1, target_class=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale
        self.target_class = target_class

    def forward(self, feats, probs):
        """Return the softmax-weighted feature context of ``target_class``.

        Args:
            feats: Feature map of shape (B, C, H, W).
            probs: Class scores of shape (B, K, H, W), with the same spatial
                size as ``feats``.

        Returns:
            Tensor of shape (B, C, 1, 1): the context vector of
            ``target_class``.
        """
        batch_size, num_classes = probs.size(0), probs.size(1)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        probs = probs.reshape(batch_size, num_classes, -1)        # B x K x HW
        feats = feats.reshape(batch_size, feats.size(1), -1)      # B x C x HW
        feats = feats.permute(0, 2, 1)                            # B x HW x C
        # Softmax over pixels: each class row becomes a spatial weight map.
        weights = F.softmax(self.scale * probs, dim=2)            # B x K x HW
        ocr_context = torch.matmul(weights, feats)                # B x K x C
        ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3)   # B x C x K x 1
        # Keep only the requested class: B x C x 1 x 1.
        return ocr_context[:, :, self.target_class, :].unsqueeze(-1)
# class SFAttention(nn.Module):
# def __init__(self,feats_channels=256,numclass=2):
# super(SFAttention, self).__init__()
# self.project1 = nn.Sequential(
# nn.Conv2d(256, feats_channels, 1, bias=False),
# nn.BatchNorm2d(feats_channels),
# nn.ReLU(inplace=True))
# self.auxproject = nn.Conv2d(feats_channels, numclass, 1, bias=False)
# self.getpvf=SpatialGather_Module()
# self.normalizer = nn.Sigmoid()
#
#
# def forward(self,x,):
# '''思路:
# 先弄出一个代表光伏的向量,
# 然后这个向量和原始的特征图的每个点做内积然后经过sigmode形成一个注意力图
#
# 这个注意力图与原特征再相乘
# 关于这个pv的向量怎来
# 参考ocr
# '''
# inputs = self.project1(x) # --256
# preds = self.auxproject(inputs) #这个玩意不一定由这个input生成可以由别的
# pv_feature=self.getpvf(inputs,preds)
#
# #
# relations = self.normalizer((pv_feature * inputs).sum(dim=1, keepdim=True)) #每个像素点和pvpv_feature算一个相似度然后形成一个o-1的权重
#
# refined_feats = relations * inputs #相当于给每个像素点加权
#
# return refined_feats,preds
# model = SFAttention()
# #
# inpoo=torch.rand(8,256,64,64)
# # pre=torch.rand(8,2,64,64)
# out=model(inpoo)
# print(out.shape)
class SFAttention(nn.Module):
    """Re-weight a feature map by its similarity to a PV (photovoltaic) vector.

    The input is projected by a 1x1 conv + BN + ReLU. Each pixel of the
    projection is dotted with ``pv_feature``, and the score is passed through
    a sigmoid to form a (B, 1, H, W) attention map in (0, 1). That map then
    scales the projected features.

    Args:
        feats_channels: Output channels of the projection. This must equal the
            channel count of ``pv_feature``.
        numclass: Accepted for compatibility; it is not used by this module.
        in_channels: Channels of the input ``x``. The default of 256 matches
            the previously hard-coded value.
    """

    def __init__(self, feats_channels=256, numclass=2, in_channels=256):
        super(SFAttention, self).__init__()
        self.project1 = nn.Sequential(
            nn.Conv2d(in_channels, feats_channels, 1, bias=False),
            nn.BatchNorm2d(feats_channels),
            nn.ReLU(inplace=True))
        self.normalizer = nn.Sigmoid()

    def forward(self, x, pv_feature):
        """Return the projected features weighted by PV similarity.

        Idea (following OCR): obtain a vector that represents the PV class.
        Take its inner product with the feature at every pixel, and apply a
        sigmoid to get an attention map. Multiply that map back onto the
        features.

        Args:
            x: Input feature map, (B, in_channels, H, W).
            pv_feature: PV context vector that broadcasts against the
                projection. It is presumably (B, feats_channels, 1, 1), as
                produced by a spatial gather module; verify against the
                caller.

        Returns:
            Tensor of shape (B, feats_channels, H, W).
        """
        inputs = self.project1(x)
        # Per-pixel similarity to the PV vector, squashed to (0, 1) weights.
        relations = self.normalizer(
            (pv_feature * inputs).sum(dim=1, keepdim=True))
        # Weight every pixel by its attention value.
        return relations * inputs