66 lines
2.8 KiB
Python
66 lines
2.8 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
@project:
|
|||
|
@File : edge_loss
|
|||
|
@Author : qiqq
|
|||
|
@create_time : 2023/1/5 22:06
|
|||
|
"""
|
|||
|
|
|||
|
import logging
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
import torch.nn as nn
|
|||
|
import torch.nn.functional as F
|
|||
|
|
|||
|
|
|||
|
class EdgeSegLoss(nn.Module):
    """Loss helpers for edge-aware semantic segmentation.

    Only the class-balanced binary edge loss (:meth:`bce2d`) is implemented
    in this class. The remaining constructor arguments are stored as
    attributes and are not read by any method defined here.

    NOTE(review): no ``forward`` is defined, so calling the module instance
    directly will fail; callers presumably use ``bce2d`` (and the stored
    weights) explicitly — confirm.
    """

    def __init__(self, classes, ignore_index=255,
                 edge_weight=1, seg_weight=1, seg_body_weight=1, att_weight=1):
        """Store the loss configuration.

        Args:
            classes: number of segmentation classes, kept as ``num_classes``.
            ignore_index: label value meant to be ignored (default 255).
                ``bce2d`` does not read it; it zero-weights every target
                value other than 0 and 1 instead.
            edge_weight, seg_weight, seg_body_weight, att_weight: scalar
                loss weights, stored only (presumably combined by a caller —
                confirm).
        """
        super().__init__()
        self.num_classes = classes
        self.ignore_index = ignore_index
        self.edge_weight = edge_weight
        self.seg_weight = seg_weight
        self.att_weight = att_weight
        self.seg_body_weight = seg_body_weight

    def bce2d(self, input, target):
        """Class-balanced binary cross-entropy on raw edge logits.

        Args:
            input: network output *before* sigmoid (logits), e.g. shape
                (N, 1, H, W) for a single edge channel.
            target: labels with the same shape as ``input``. 1 marks an edge
                pixel, 0 a background pixel; any other value (e.g. 255) gets
                zero weight and so contributes nothing to the loss.

        Returns:
            Scalar tensor on ``input.device``: the weighted BCE-with-logits
            averaged over *all* elements (ignored ones are still counted in
            the denominator). Edge pixels are weighted by the background
            fraction and background pixels by the edge fraction, so the rarer
            class receives the larger weight. If no label is 0 or 1 the loss
            is 0.

        Raises:
            ValueError: if ``input`` and ``target`` have different numbers of
                elements (raised by ``binary_cross_entropy_with_logits``).
        """
        # Pair each logit with its label by flat position. The mean reduction
        # is order-independent, so a plain reshape gives the same result as
        # permuting both tensors to NHWC first.
        logits = input.reshape(1, -1)
        flat_target = target.reshape(1, -1)

        # Build masks from the raw labels (before any dtype cast) and move them
        # to the logits' device so that all indexing happens on one device.
        pos_mask = (flat_target == 1).to(logits.device)
        neg_mask = (flat_target == 0).to(logits.device)
        pos_num = int(pos_mask.sum())
        neg_num = int(neg_mask.sum())
        valid_num = pos_num + neg_num

        # Weight defaults to zero, so labels other than 0/1 stay ignored.
        weight = torch.zeros_like(logits)
        if valid_num > 0:  # avoid 0/0 when every pixel is ignored
            weight[pos_mask] = neg_num / valid_num
            weight[neg_mask] = pos_num / valid_num

        # BCE-with-logits needs a floating target matching the logits' dtype.
        labels = flat_target.to(device=logits.device, dtype=logits.dtype)
        return F.binary_cross_entropy_with_logits(
            logits, labels, weight=weight, reduction='mean')
|