#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : losses
@Author : qiqq
@create_time : 2023/1/3 14:37
"""

import math
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
#### Loss from FarSeg that targets class imbalance

# In essence this just multiplies the ordinary per-pixel cross-entropy by a weighting factor.
def softmax_focalloss(y_pred, y_true, ignore_index=255, gamma=2.0, alpha=0.25, eps=1e-8, normalize=False):
    """
    Args:
        y_pred: [N, #class, H, W]
        y_true: [N, H, W], values in 0..#class-1 (plus ignore_index)
        gamma: scalar
        Original focal loss formula: -alpha * (1 - pt) ** gamma * log(pt)
    Returns:
        scalar loss
    """
    losses = F.cross_entropy(y_pred, y_true, ignore_index=ignore_index, reduction='none')
    with torch.no_grad():
        # Why torch.no_grad() here? Presumably the modulating factor is meant to act as a
        # constant per-pixel weight (no gradient flows through it), which also saves memory.
        p = y_pred.softmax(dim=1)
        modulating_factor = ((1 + eps - p).pow(gamma)) * alpha  # per-class weighting factor
        valid_mask = ~ y_true.eq(ignore_index)
        # torch.where(condition, a, b): picks a where the condition holds, b otherwise.
        # Here it keeps the labels of valid pixels and maps ignore_index (255) pixels to 0.
        # That looks like it turns them into background, but cross_entropy above already set
        # their loss to 0 via ignore_index, and valid_mask keeps them out of the normalizer.
        masked_y_true = torch.where(valid_mask, y_true, torch.zeros_like(y_true))
        # torch.gather(input, dim, index): index selects along `dim`; the output has the shape
        # of index. Example: input [[1,2,3],[4,5,6]], index [[0,1],[2,0]], dim=1 -> [[1,2],[6,4]].
        # Here dim=1 is the class/channel dimension of [N, C, H, W], so gather picks, for every
        # pixel, the modulating factor of its ground-truth class.
        modulating_factor = torch.gather(modulating_factor, dim=1,
                                         index=masked_y_true.unsqueeze(dim=1)).squeeze_(dim=1)
        scale = 1.
        if normalize:
            scale = losses.sum() / (losses * modulating_factor).sum()
    # focal-loss style up-weighting of hard examples
    losses = scale * (losses * modulating_factor).sum() / (valid_mask.sum() + p.size(0))
    # Not obvious why the denominator also adds the batch size: valid_mask.sum() already counts
    # every valid pixel across the whole batch (this follows the original FarSeg code).
    return losses
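
# A minimal sketch (not part of the original file) illustrating the gather step above: for a
# [N, C, H, W] probability map, picking each pixel's ground-truth-class entry with torch.gather
# is the same as indexing class channels pixel by pixel. The function name
# `_demo_gather_per_pixel` and the tiny shapes are made up for illustration only.
def _demo_gather_per_pixel():
    n, c, h, w = 1, 3, 2, 2
    p = torch.rand(n, c, h, w).softmax(dim=1)          # fake per-class probabilities
    y = torch.randint(0, c, (n, h, w))                 # fake labels in 0..c-1
    picked = torch.gather(p, dim=1, index=y.unsqueeze(1)).squeeze(1)  # [N, H, W]
    # the same result via a plain loop, to show what gather does
    manual = torch.empty(n, h, w)
    for i in range(h):
        for j in range(w):
            manual[0, i, j] = p[0, y[0, i, j], i, j]
    assert torch.allclose(picked, manual)
    return picked
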
# The same focal loss, implemented once more as an nn.Module
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, eps=1e-8, ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps
        self.ignore_index = ignore_index

    def forward(self, y_pred, y_true):
        # Same steps as softmax_focalloss above (without torch.no_grad() and without the
        # optional normalize branch); see the comments there for torch.where / torch.gather.
        losses = F.cross_entropy(y_pred, y_true, ignore_index=self.ignore_index, reduction='none')
        p = y_pred.softmax(dim=1)
        modulating_factor = ((1 + self.eps - p).pow(self.gamma)) * self.alpha  # weighting factor
        valid_mask = ~ y_true.eq(self.ignore_index)
        masked_y_true = torch.where(valid_mask, y_true, torch.zeros_like(y_true))
        modulating_factor = torch.gather(modulating_factor, dim=1,
                                         index=masked_y_true.unsqueeze(dim=1)).squeeze_(dim=1)
        # focal-loss style up-weighting of hard examples; as above, the added batch size in the
        # denominator follows the original code even though valid_mask.sum() already spans the batch.
        losses = (losses * modulating_factor).sum() / (valid_mask.sum() + p.size(0))
        return losses


def cosine_annealing(lower_bound, upper_bound, _t, _t_max):
    '''
    Cosine schedule from lower_bound (at _t = 0) to upper_bound (at _t = _t_max).
    '''
    return upper_bound + 0.5 * (lower_bound - upper_bound) * (math.cos(math.pi * _t / _t_max) + 1)


def poly_annealing(lower_bound, upper_bound, _t, _t_max):
    # Polynomial (power 0.9) schedule from lower_bound to upper_bound.
    factor = (1 - _t / _t_max) ** 0.9
    return upper_bound + factor * (lower_bound - upper_bound)


def linear_annealing(lower_bound, upper_bound, _t, _t_max):
    # Linear schedule from lower_bound to upper_bound.
    factor = 1 - _t / _t_max
    return upper_bound + factor * (lower_bound - upper_bound)

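
# A minimal sketch (not part of the original file) of how the three schedules move a weight from
# lower_bound=1.0 to upper_bound=2.0 over _t_max=100 steps. The function name `_demo_annealing`
# and the chosen bounds are made up for illustration only.
def _demo_annealing():
    for t in (0, 25, 50, 75, 100):
        print(t,
              round(cosine_annealing(1.0, 2.0, t, 100), 3),
              round(poly_annealing(1.0, 2.0, t, 100), 3),
              round(linear_annealing(1.0, 2.0, t, 100), 3))
    # cosine rises roughly as 1.0, 1.146, 1.5, 1.854, 2.0; poly as 1.0, ~1.228, ~1.464, ~1.713, 2.0;
    # linear evenly from 1.0 to 2.0.
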
def annealing_softmax_focalloss(y_pred, y_true, t, t_max, ignore_index=255, gamma=2.0,
                                annealing_function=cosine_annealing):
    losses = F.cross_entropy(y_pred, y_true, ignore_index=ignore_index, reduction='none')
    with torch.no_grad():
        p = y_pred.softmax(dim=1)
        modulating_factor = (1 - p).pow(gamma)
        valid_mask = ~ y_true.eq(ignore_index)  # can be skipped if there is no ignore label
        masked_y_true = torch.where(valid_mask, y_true, torch.zeros_like(y_true))
        modulating_factor = torch.gather(modulating_factor, dim=1,
                                         index=masked_y_true.unsqueeze(dim=1)).squeeze_(dim=1)
        # Up to here this is essentially the same as the focal loss above.
        # Roughly the z from the paper; replacing this division by h*w would give a mean loss.
        normalizer = losses.sum() / (losses * modulating_factor).sum()
        # modulating_factor corresponds to (1 - p_i)^gamma in the paper,
        # so scales corresponds to (1/z) * (1 - p_i)^gamma.
        scales = modulating_factor * normalizer
    if t > t_max:
        # Rough idea: if the whole training runs, say, 1000 epochs and t_max = 500, the factor
        # follows the cosine schedule during the first 500 epochs and then stays fixed for
        # epochs 500-1000.
        scale = scales
    else:
        # Unlike plain focal loss, which uses modulating_factor from start to finish, the weight
        # here is annealed dynamically from 1 up to modulating_factor.
        scale = annealing_function(1, scales, t, t_max)
    losses = (losses * scale).sum() / (valid_mask.sum() + p.size(0))
    return losses

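
# A minimal sketch (not part of the original file) of how t / t_max might be driven during
# training: t is the current epoch index and t_max the epoch after which the focal weighting
# stops changing. The function name `_demo_annealed_focal_schedule`, shapes, and epoch counts
# are made up for illustration only.
def _demo_annealed_focal_schedule():
    num_epochs, t_max = 10, 5
    logits = torch.randn(1, 2, 4, 4)                   # fake predictions [N, C, H, W]
    labels = torch.randint(0, 2, (1, 4, 4))            # fake labels [N, H, W]
    for epoch in range(num_epochs):
        loss = annealing_softmax_focalloss(logits, labels, t=epoch, t_max=t_max)
        print(epoch, float(loss))                      # weighting anneals until epoch t_max
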
####### OHEM: online hard example mining based on per-pixel loss magnitude
# Third-party re-implementation, version 1: https://blog.csdn.net/m0_61139217/article/details/127084869
class OhemCELoss(nn.Module):
    """
    Online hard example mining cross-entropy loss.
    At least the n_min pixels with the largest losses are considered:
    if loss[self.n_min] > self.thresh, i.e. even the smallest of the top-n_min losses is still
    above the threshold, then every pixel whose loss exceeds the threshold is used:
    loss = loss[loss > thresh].
    Otherwise, only the top n_min losses are used: loss = loss[:self.n_min].
    """

    def __init__(self, thresh=0.7, n_min=10, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        # self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()  # convert the input probability to a loss threshold
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))  # convert the input probability to a loss threshold
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')  # cross-entropy

    def forward(self, logits, labels):
        N, C, H, W = logits.size()
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)  # sort losses from largest to smallest
        if loss[self.n_min] > self.thresh:
            # More than n_min pixels have a loss above the threshold (the probability converted
            # to a loss threshold), so keep all losses above the threshold.
            loss = loss[loss > self.thresh]
        else:
            # Some of the top n_min losses are below the threshold, so keep the top n_min.
            loss = loss[:self.n_min]
        return torch.mean(loss)

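
# A minimal sketch (not part of the original file): the probability threshold 0.7 becomes a loss
# threshold of -log(0.7) ~= 0.357, so "hard" pixels are those whose cross-entropy exceeds that
# value. In BiSeNet-style training n_min is often tied to the image size, e.g. something like
# batch_size * H * W // 16, but that choice is an assumption for illustration, not something
# this file fixes.
def _demo_ohem_threshold():
    loss_thresh = -torch.log(torch.tensor(0.7))
    print(float(loss_thresh))                          # ~0.3567
    n, h, w = 4, 512, 512
    print(n * h * w // 16)                             # one common choice for n_min
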
# Third-party re-implementation, version 2: https://blog.csdn.net/qq_40035462/article/details/123448323
# Originally written for Paddle; ported by hand to PyTorch here.
class OhemCrossEntropyLoss(nn.Module):
    """
    Implements the OHEM cross entropy loss function.
    Args:
        thresh (float, optional): The threshold of OHEM. Default: 0.7.
        min_kept (int, optional): The minimum number of pixels kept in the loss computation.
            Default here: 10 (the original implementation uses 10000).
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
    """

    def __init__(self, thresh=0.7, min_kept=10, ignore_index=255):
        super(OhemCrossEntropyLoss, self).__init__()
        self.thresh = thresh  # probability threshold: pixels whose true-class probability falls below it count as hard examples
        self.min_kept = min_kept  # minimum number of pixels used in the loss computation
        self.ignore_index = ignore_index  # label value excluded from the loss
        self.EPS = 1e-5  # guards against numerical issues

    def forward(self, logit, label):
        """
        Forward computation.
        Args:
            logit (Tensor): Logit tensor, the data type is float32, float64. Shape is
                (N, C), where C is number of classes, and if shape is more than 2D, this
                is (N, C, D1, D2,..., Dk), k >= 1.
            label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
                value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
                (N, D1, D2,..., Dk), k >= 1.
        """
        if len(label.shape) != len(logit.shape):
            label = torch.unsqueeze(label, 1)

        # get the label after OHEM
        n, c, h, w = logit.shape
        label = label.reshape((-1,))
        valid_mask = (label != self.ignore_index).to(torch.int64)
        num_valid = valid_mask.sum()
        label = label * valid_mask  # map the ignore_index (255) positions to label 0

        prob = F.softmax(logit, dim=1)  # predicted probabilities
        prob = prob.transpose(0, 1).reshape((c, -1))  # in PyTorch, transpose swaps exactly two dimensions

        if self.min_kept < num_valid and num_valid > 0:  # only mine if there are more valid pixels than min_kept
            # push the ignored (255) positions above 1 so they are never selected as hard
            prob = prob + (1 - valid_mask)

            # get the probability of the ground-truth label
            label_onehot = F.one_hot(label, c)  # shape (batch*h*w, c): one row per pixel
            label_onehot = label_onehot.transpose(1, 0)  # -> (c, batch*h*w)
            prob = prob * label_onehot  # keep only the probability of the true class
            prob = torch.sum(prob, dim=0)  # per-pixel probability of the ground-truth class

            threshold = self.thresh
            if self.min_kept > 0:
                index = prob.argsort()  # argsort sorts ascending and returns the element indices
                # index of the min_kept-th smallest true-class probability (bounded by the number
                # of pixels); together with argsort this marks the min_kept hardest pixels
                threshold_index = index[min(len(index), self.min_kept) - 1]
                threshold_index = int(threshold_index.cpu().numpy())
                if prob[threshold_index] > self.thresh:  # raise the threshold if the min_kept-th hardest pixel already exceeds it
                    threshold = prob[threshold_index]
                kept_mask = (prob < threshold).to(torch.int64)  # select the pixels that take part in the loss
                label = label * kept_mask
                valid_mask = valid_mask * kept_mask

        # make the invalid region ignored again (the dropped positions get 255 back)
        label = label + (1 - valid_mask) * self.ignore_index

        # label = label.reshape((n, 1, h, w))  # original (Paddle) shapes
        # valid_mask = valid_mask.reshape((n, 1, h, w)).to(torch.float32)
        label = label.reshape((n, h, w))
        valid_mask = valid_mask.reshape((n, h, w)).to(torch.float32)
        loss = F.cross_entropy(
            logit, label, ignore_index=self.ignore_index, reduction='none')
        loss = loss * valid_mask
        avg_loss = torch.mean(loss) / (torch.mean(valid_mask) + self.EPS)

        # Paddle leftovers: stop_gradient is Paddle's API; on PyTorch tensors this only attaches
        # an attribute and has no effect (the integer labels carry no gradient anyway).
        label.stop_gradient = True
        valid_mask.stop_gradient = True
        return avg_loss

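
# A minimal sketch (not part of the original file) of the "prob * one_hot, then sum" trick used
# above: multiplying the (C, N_pixels) probability matrix by a one-hot label matrix and summing
# over the class axis picks each pixel's ground-truth-class probability, the same thing
# torch.gather does in the focal losses earlier in this file. The name `_demo_onehot_select`
# is made up for illustration only.
def _demo_onehot_select():
    c, n_pix = 3, 5
    prob = torch.rand(c, n_pix).softmax(dim=0)         # fake per-class probabilities per pixel
    label = torch.randint(0, c, (n_pix,))              # fake labels
    via_onehot = (prob * F.one_hot(label, c).transpose(1, 0)).sum(dim=0)
    via_gather = torch.gather(prob, dim=0, index=label.unsqueeze(0)).squeeze(0)
    assert torch.allclose(via_onehot, via_gather)
    return via_onehot
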
def setup_seed(seed=0):
    import torch
    import os
    import numpy as np
    import random
    torch.manual_seed(seed)  # seed the CPU RNG
    np.random.seed(seed)  # NumPy module
    random.seed(seed)  # Python random module
    # Note: CUDA RNGs are not seeded here (torch.cuda.manual_seed_all would be needed for that).


############################################################################
# 2022.12.8: I think this part is obsolete.


if __name__ == '__main__':
    setup_seed()
    # input = torch.randn(2, 3, 4, 4, dtype=torch.float32)
    input = torch.randn(1, 2, 4, 4, dtype=torch.float32)
    # target = torch.randint(3, (1, 4, 4), dtype=torch.int64)
    # target = torch.randint(3, (2, 4, 4), dtype=torch.int64)
    target = torch.tensor(([[[1, 1, 1, 0],
                             [1, 0, 0, 1],
                             [0, 1, 0, 0],
                             [1, 1, 0, 0]]]))

    # target = torch.tensor(([[[1, 1, 2, 1],
    #                          [2, 0, 1, 0],
    #                          [2, 0, 255, 2],
    #                          [2, 2, 0, 1]],
    #                         [[2, 2, 1, 0],
    #                          [1, 0, 0, 0],
    #                          [0, 255, 0, 1],
    #                          [1, 2, 1, 1]]]))

    print("---------pred----------")
    print(input)
    print("---------target----------")
    print(target)
    # print("__________softmax___________")
    # soft = F.softmax(input, dim=1)
    # print(soft)
    # print("__________log___________")
    # log = torch.log(soft)
    # print(log)
    # print("++++++++++++++++++++++++++++++++++++++++++++++")

    # print("++++++++++++ohem1++++++++++++++++++++++++++++++++++")
    # ohem1 = OhemCELoss()
    # ohemloss1 = ohem1(input, target)
    #
    # print(ohemloss1)

    # print("++++++++++++ohem2++++++++++++++++++++++++++++++++++")
    # ohem2 = OhemCrossEntropyLoss()
    # ohemloss2 = ohem2(input, target)
    #
    # print(ohemloss2)
    # # Huh, these two produce the same result?? So either one would work??

    # lossmodel = FocalLoss()
    # backloss = lossmodel(input, target)
    backloss = softmax_focalloss(input, target)
    print(backloss)