ai-station-code/wudingpv/taihuyuan_pv/schedulers/edge_loss.py

66 lines
2.8 KiB
Python
Raw Normal View History

2025-05-06 11:18:48 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : edge_loss
@Author : qiqq
@create_time : 2023/1/5 22:06
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class EdgeSegLoss(nn.Module):
def __init__(self, classes, ignore_index=255,
edge_weight=1, seg_weight=1, seg_body_weight=1, att_weight=1):
super(EdgeSegLoss, self).__init__()
self.num_classes = classes
self.ignore_index = ignore_index
self.edge_weight = edge_weight
self.seg_weight = seg_weight
self.att_weight = att_weight
self.seg_body_weight = seg_body_weight
def bce2d(self, input, target):
    """Class-balanced binary cross-entropy between edge logits and labels.

    Labels equal to 1 are treated as edge pixels and labels equal to 0 as
    non-edge pixels. Every other label value gets weight zero and so does
    not contribute to the loss or its gradient.

    Each edge pixel is weighted by the fraction of valid pixels that are
    non-edge, and each non-edge pixel by the fraction that are edge, so
    the rarer class receives the larger weight. The weighted per-pixel
    losses are averaged over ALL elements, ignored ones included.

    The computation runs on ``input``'s device; ``target`` is moved there
    if necessary. No CUDA device is required.

    Args:
        input: Raw (pre-sigmoid) logits, a 4-D tensor.
        target: Labels, a 4-D tensor with the same number of elements as
            ``input``. Any dtype is accepted.

    Returns:
        A scalar tensor with the weighted mean loss. It is zero when no
        pixel is labelled 0 or 1.
    """
    # Flatten both tensors channels-last with the same permutation so that
    # element i of the logits still pairs with element i of the labels.
    logits = input.permute(0, 2, 3, 1).reshape(1, -1)
    labels = target.to(input.device).permute(0, 2, 3, 1).reshape(1, -1)

    edge_mask = labels == 1
    background_mask = labels == 0

    edge_count = edge_mask.sum().float()
    background_count = background_mask.sum().float()
    # With no valid pixel both masks are empty, so the weights are all zero
    # whatever the divisor is; clamping just avoids a 0/0 division.
    valid_count = (edge_count + background_count).clamp(min=1.0)

    # Inverse-frequency weighting, computed in float32 and cast once so
    # large pixel counts cannot overflow a half-precision input dtype.
    weight = (
        edge_mask.float() * (background_count / valid_count)
        + background_mask.float() * (edge_count / valid_count)
    ).to(logits.dtype)

    # Ignored pixels carry weight 0, so encoding them as 0 in the target is
    # numerically equivalent to passing their raw label and also lets
    # integer-typed label tensors through.
    binary_target = edge_mask.to(logits.dtype)

    return F.binary_cross_entropy_with_logits(
        logits, binary_target, weight, reduction="mean"
    )