#!/usr/bin/env python # -*- coding: utf-8 -*- """ @project: @File : edge_loss @Author : qiqq @create_time : 2023/1/5 22:06 """ import logging import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class EdgeSegLoss(nn.Module): def __init__(self, classes, ignore_index=255, edge_weight=1, seg_weight=1, seg_body_weight=1, att_weight=1): super(EdgeSegLoss, self).__init__() self.num_classes = classes self.ignore_index = ignore_index self.edge_weight = edge_weight self.seg_weight = seg_weight self.att_weight = att_weight self.seg_body_weight = seg_body_weight def bce2d(self, input, target): '''input是网络的输出,没有经过sigmode的。target的标签''' n, c, h, w = input.size() #c=1 代表是边缘 log_p = input.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1) #.view(1, -1)把原来的tensor reshape成两个维度1是第一个维度是1,-1是第二给维度由估计得来,也就是把原来的tensor拉成一个一行b*c*h*w的行向量 #以(2,1,512,512)---》(1,524288) target_t = target.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1) target_trans = target_t.clone() pos_index = (target_t == 1) #边缘的 (1,524288) 是boole类型 neg_index = (target_t == 0) #非边缘的 (1,524288) 是boole类型 ignore_index = (target_t > 1) #那些既不是边缘也不是背景的的 要被和忽略的 target_trans[pos_index] = 1 #(1,524288)去掉了ignor的并且把boole类型转成0 1类型的 target_trans[neg_index] = 0 pos_index = pos_index.data.cpu().numpy().astype(bool) neg_index = neg_index.data.cpu().numpy().astype(bool) ignore_index = ignore_index.data.cpu().numpy().astype(bool) #以上三个搞到cpu numpy上 weight = torch.Tensor(log_p.size()).fill_(0) #一个全为0的 (1,524288)目的是好像是为每个piex生成一个权重 weight = weight.numpy() pos_num = pos_index.sum() #所有边缘的pix数量 (在)[0,0,:,:]这张图中共61754 neg_num = neg_index.sum() #非边缘的像素数量 sum_num = pos_num + neg_num #有效的像素数量 #462534 weight[pos_index] = neg_num * 1.0 / sum_num #以边缘和背景的频率作为weight (注意是数量多的权重大,数量少的权重小) weight[neg_index] = pos_num * 1.0 / sum_num weight[ignore_index] = 0 weight = torch.from_numpy(weight).cuda() log_p = log_p.cuda() #原始的被seshape成了(1,bchw)的行向量 target_t = target_t.cuda() loss = F.binary_cross_entropy_with_logits(log_p, target_t, weight, size_average=True) return loss