ai-station-code/wudingpv/taihuyuan_roof/schedulers/edge_loss.py

66 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : edge_loss
@Author : qiqq
@create_time : 2023/1/5 22:06
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class EdgeSegLoss(nn.Module):
    """Loss container for joint edge / segmentation training.

    Only the class-balanced edge term (:meth:`bce2d`) is implemented here.
    The remaining constructor arguments are stored as attributes for callers
    to combine the individual loss terms.
    """

    def __init__(self, classes, ignore_index=255,
                 edge_weight=1, seg_weight=1, seg_body_weight=1, att_weight=1):
        super(EdgeSegLoss, self).__init__()
        self.num_classes = classes  # number of segmentation classes
        # Stored for callers. bce2d does not read it: it ignores every target
        # value other than exactly 0 or 1.
        self.ignore_index = ignore_index
        # Relative weights of the individual loss terms (stored, not used here).
        self.edge_weight = edge_weight
        self.seg_weight = seg_weight
        self.att_weight = att_weight
        self.seg_body_weight = seg_body_weight

    def bce2d(self, input, target):
        """Class-balanced binary cross-entropy on raw edge logits.

        Args:
            input: Raw (pre-sigmoid) logits of shape (N, C, H, W); for an
                edge map C is normally 1.
            target: Labels with the same layout as ``input``; an (N, H, W)
                tensor is also accepted when C == 1. A value of 1 marks an
                edge pixel and 0 a non-edge pixel. Any other value (e.g. 255)
                is ignored: it gets zero weight.

        Returns:
            A 0-dim tensor holding the weighted BCE averaged over all
            elements. Ignored elements count in the denominator but
            contribute zero. Its dtype and device match ``input``.

        Raises:
            ValueError: if ``input`` is not 4-D.

        Each edge pixel is weighted by the fraction of non-edge pixels among
        the valid pixels, and each non-edge pixel by the fraction of edge
        pixels. The rarer class therefore receives the larger weight.
        """
        if input.dim() != 4:
            raise ValueError(
                "bce2d expects input of shape (N, C, H, W), got %s"
                % (tuple(input.shape),))
        if target.dim() == 3 and input.size(1) == 1:
            # Single-channel edge maps are often stored without a channel axis.
            target = target.unsqueeze(1)

        # Flatten both tensors in NHWC order into one (1, N*H*W*C) row.
        # Applying the same permutation to both keeps logits and labels
        # element-aligned.
        log_p = input.permute(0, 2, 3, 1).reshape(1, -1)
        target_t = target.to(log_p.device).permute(0, 2, 3, 1).reshape(1, -1)
        dtype = log_p.dtype

        pos_index = target_t == 1  # edge pixels
        neg_index = target_t == 0  # non-edge pixels; everything else is ignored

        pos_num = pos_index.sum()
        neg_num = neg_index.sum()
        # clamp avoids 0/0 when every pixel is ignored; all weights are then 0.
        sum_num = (pos_num + neg_num).clamp(min=1)

        # Cross-class frequencies: the minority class gets the larger weight.
        # Computed in float64, then cast to the logits' dtype.
        pos_weight = (neg_num.double() / sum_num).to(dtype)
        neg_weight = (pos_num.double() / sum_num).to(dtype)

        pos_mask = pos_index.to(dtype)
        # Pure elementwise arithmetic stays on the logits' device: no host
        # round-trip. Ignored pixels are neither pos nor neg, so weight == 0.
        weight = pos_mask * pos_weight + neg_index.to(dtype) * neg_weight

        # Use a clean 0/1 float target. Raw ignore labels (e.g. 255) could
        # overflow in x*y and turn 0-weighted entries into NaN.
        # NOTE(review): "mean" divides by all elements, including ignored
        # ones. This matches the previous size_average=True behaviour.
        loss = F.binary_cross_entropy_with_logits(
            log_p, pos_mask, weight, reduction="mean")
        return loss