ai-station-code/wudingpv/taihuyuan_roof/utils/dice_coefficient_loss.py

57 lines
2.0 KiB
Python
Raw Normal View History

2025-05-06 11:18:48 +08:00
import torch
import torch.nn as nn
def build_target(target: torch.Tensor, num_classes: int = 2, ignore_index: int = -100):
    """Build the one-hot target used by the Dice coefficient.

    Args:
        target: Class-index map of shape ``[N, H, W]``. Any integer dtype is
            accepted; it is converted to ``int64`` internally because
            ``one_hot`` only accepts int64 index tensors. The input tensor is
            never modified.
        num_classes: Number of classes ``C`` (size of the one-hot axis).
        ignore_index: Label marking pixels to be ignored. It is only honoured
            when it is non-negative; with a negative value every pixel is
            one-hot encoded as-is.

    Returns:
        A float tensor of shape ``[N, C, H, W]``. When ``ignore_index`` is
        non-negative, pixels carrying that label hold ``ignore_index`` in
        every channel instead of a one-hot vector.
    """
    # One copy does both jobs: it protects the caller's tensor from the
    # in-place write below and gives one_hot the int64 dtype it requires.
    dice_target = target.to(dtype=torch.long, copy=True)

    ignore_mask = None
    if ignore_index >= 0:
        ignore_mask = torch.eq(dice_target, ignore_index)
        # Park ignored pixels on class 0 so one_hot sees a valid class id.
        dice_target[ignore_mask] = 0

    # [N, H, W] -> [N, H, W, C]
    dice_target = nn.functional.one_hot(dice_target, num_classes).float()

    if ignore_mask is not None:
        # Re-mark ignored pixels (across all channels) so that downstream
        # code can recognise and exclude them.
        dice_target[ignore_mask] = ignore_index

    # [N, H, W, C] -> [N, C, H, W]
    return dice_target.permute(0, 3, 1, 2)
def dice_coeff(x: torch.Tensor, target: torch.Tensor, ignore_index: int = -100, epsilon=1e-6):
    """Return the Dice coefficient averaged over the batch dimension.

    Each sample (index along dim 0) is flattened, its Dice coefficient
    ``(2 * |x . t| + epsilon) / (sum(x) + sum(t) + epsilon)`` is computed, and
    the per-sample values are averaged. When called on a single channel this
    is the Dice coefficient of one class over all images of the batch.

    Args:
        x: Predictions; dim 0 is the batch dimension.
        target: Ground truth with the same number of elements per sample as
            ``x``. Its dtype may differ from ``x``; both are promoted to a
            common dtype before the arithmetic.
        ignore_index: Only honoured when non-negative: positions where
            ``target == ignore_index`` are excluded from every sum.
        epsilon: Smoothing term added to numerator and denominator.

    Returns:
        A 0-dim tensor holding the mean Dice coefficient of the batch.

    Raises:
        ZeroDivisionError: If the batch is empty (same exception type the
            division by the batch size has always produced here).
    """
    batch_size = x.shape[0]
    if batch_size == 0:
        raise ZeroDivisionError("dice_coeff: cannot average over an empty batch")

    # Explicit sizes (rather than -1) so zero-element samples still reshape.
    x_flat = x.reshape(batch_size, x.numel() // batch_size)
    t_flat = target.reshape(batch_size, target.numel() // batch_size)

    # Locate the region of the mask that is not ignore_index. This must be
    # done on the raw target, before any dtype conversion.
    roi_mask = torch.ne(t_flat, ignore_index) if ignore_index >= 0 else None

    # Mixed dtypes (e.g. integer target, half-precision prediction) are
    # promoted once up front so the element-wise maths below never fails.
    common_dtype = torch.promote_types(x_flat.dtype, t_flat.dtype)
    x_flat = x_flat.to(common_dtype)
    t_flat = t_flat.to(common_dtype)

    if roi_mask is not None:
        # Zeroing ignored positions removes them from the intersection and
        # from both sums, which is equivalent to dropping them per sample.
        x_flat = x_flat.masked_fill(~roi_mask, 0)
        t_flat = t_flat.masked_fill(~roi_mask, 0)

    # Whole-batch reductions: no Python loop and no per-sample host sync.
    inter = (x_flat * t_flat).sum(dim=1)
    sets_sum = x_flat.sum(dim=1) + t_flat.sum(dim=1)
    # Samples whose prediction and target sums cancel to zero use 2 * inter
    # as the denominator, so their coefficient evaluates to 1.
    sets_sum = torch.where(sets_sum == 0, 2 * inter, sets_sum)

    dice = (2 * inter + epsilon) / (sets_sum + epsilon)
    return dice.mean()
def multiclass_dice_coeff(x: torch.Tensor, target: torch.Tensor, ignore_index: int = -100, epsilon=1e-6):
    """Return the mean of the per-class Dice coefficients.

    Dim 1 of ``x`` and ``target`` is treated as the class axis: each class
    slice is scored independently with ``dice_coeff`` and the scores are
    averaged over the number of classes in ``x``.
    """
    num_channels = x.shape[1]
    per_class_scores = (
        dice_coeff(x[:, idx, ...], target[:, idx, ...], ignore_index, epsilon)
        for idx in range(num_channels)
    )
    # Float start value keeps the accumulator a float even with no channels.
    return sum(per_class_scores, 0.) / num_channels
def dice_loss(x: torch.Tensor, target: torch.Tensor, multiclass: bool = False, ignore_index: int = -100):
    """Return the Dice loss (an objective to minimise) for raw scores ``x``.

    ``x`` is turned into probabilities with a softmax over dim 1, the Dice
    coefficient is computed against ``target``, and ``1 - coefficient`` is
    returned. With ``multiclass=True`` the coefficient is averaged per class
    via ``multiclass_dice_coeff``; otherwise ``dice_coeff`` is applied to the
    whole tensor at once.
    """
    probabilities = torch.softmax(x, dim=1)
    if multiclass:
        coefficient = multiclass_dice_coeff(probabilities, target, ignore_index=ignore_index)
    else:
        coefficient = dice_coeff(probabilities, target, ignore_index=ignore_index)
    return 1 - coefficient