ai-station-code/dimaoshibie/utils/utils_metrics.py

178 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import csv
import os
from os.path import join
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def f_score(inputs, target, beta=1, smooth=1e-5, threhold=0.5):
    """Return the mean thresholded F-beta score (Dice when beta=1) over classes.

    Args:
        inputs: logits of shape (n, c, h, w). They are resized to the target's
            spatial size when either dimension differs.
        target: tensor of shape (n, ht, wt, ct). The last channel is dropped
            (``[..., :-1]``) before scoring, so ``ct`` must equal ``c + 1``.
            Presumably one-hot labels plus a trailing ignore channel -- confirm
            against the dataloader.
        beta: weight of recall relative to precision in the F-beta score.
        smooth: additive constant that keeps numerator/denominator non-zero.
        threhold: probability threshold used to binarize the softmax output.
            The name is misspelled but kept for backward compatibility with
            keyword callers.

    Returns:
        0-d tensor: the per-class score averaged over the ``c`` classes.
    """
    n, c, h, w = inputs.size()
    _, ht, wt, ct = target.size()
    # Resize when EITHER spatial dimension differs. The previous `and` skipped
    # the resize for a one-dimension mismatch, which later broke the
    # element-wise product between predictions and targets.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    # (n, c, H, W) -> (n, H*W, c), then softmax over the class axis.
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).contiguous().view(n, -1, c), -1)
    temp_target = target.view(n, -1, ct)

    # Compute the Dice/F-beta coefficient on hard (thresholded) predictions.
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    class_target = temp_target[..., :-1]
    tp = torch.sum(class_target * temp_inputs, dim=(0, 1))
    fp = torch.sum(temp_inputs, dim=(0, 1)) - tp
    fn = torch.sum(class_target, dim=(0, 1)) - tp

    beta2 = beta ** 2
    score = ((1 + beta2) * tp + smooth) / ((1 + beta2) * tp + beta2 * fn + fp + smooth)
    return torch.mean(score)
# Label maps are W pixels wide and H pixels tall; both inputs arrive flattened to (H*W,).
def fast_hist(a, b, n):
    """Accumulate an n x n confusion matrix from flattened label/prediction maps.

    Args:
        a: flattened ground-truth labels, shape (H*W,). Entries outside [0, n)
            (e.g. an ignore value) are excluded from the count.
        b: flattened predictions, shape (H*W,), aligned element-wise with `a`.
            Values are not range-checked here; presumably they lie in [0, n).
        n: number of classes.

    Returns:
        (n, n) integer array whose entry [i, j] counts pixels labelled i and
        predicted j; the diagonal therefore holds correctly classified pixels.
    """
    valid = (a >= 0) & (a < n)
    # Encode each (label, prediction) pair as the single integer label * n + pred,
    # so one bincount over 0 .. n*n-1 tallies every cell of the matrix.
    codes = a[valid].astype(int) * n + b[valid]
    counts = np.bincount(codes, minlength=n * n)
    return counts.reshape(n, n)
def per_class_iu(hist):
    """Return per-class IoU = TP / (TP + FP + FN) from a confusion matrix.

    `hist` is expected to have ground truth on rows and predictions on columns
    (the layout fast_hist produces). A class with an empty union gets 0 rather
    than NaN because the denominator is clamped to at least 1.
    """
    true_pos = np.diag(hist)
    union = hist.sum(axis=1) + hist.sum(axis=0) - true_pos
    return true_pos / np.maximum(union, 1)
def per_class_PA_Recall(hist):
    """Return per-class pixel accuracy / recall = TP / (TP + FN).

    The denominator is the row sum (all ground-truth pixels of the class),
    clamped to at least 1, so a class absent from the ground truth scores 0.
    """
    gt_pixels = hist.sum(axis=1)
    return np.diag(hist) / np.maximum(gt_pixels, 1)
def per_class_Precision(hist):
    """Return per-class precision = TP / (TP + FP).

    The denominator is the column sum (all pixels predicted as the class),
    clamped to at least 1, so a never-predicted class scores 0.
    """
    predicted_pixels = hist.sum(axis=0)
    return np.diag(hist) / np.maximum(predicted_pixels, 1)
def per_Accuracy(hist):
    """Return overall pixel accuracy: correctly classified pixels / all pixels.

    Returns 0 for an all-zero matrix because the total is clamped to at least 1.
    """
    correct = np.diag(hist).sum()
    total = np.sum(hist)
    return correct / np.maximum(total, 1)
def per_class_F1(hist):
    """Return per-class F1 = 2*TP / (2*TP + FP + FN) from a confusion matrix.

    Algebraically equal to 2*P*R / (P + R), but computed on raw counts so the
    zero-division guard cannot distort the score. The row sum plus column sum
    equals 2*TP + FN + FP. A class with no ground-truth and no predicted pixels
    gets 0, as in the other per_class_* helpers.
    """
    return 2 * np.diag(hist) / np.maximum(hist.sum(1) + hist.sum(0), 1)


def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
    """Accumulate a confusion matrix over PNG label/prediction pairs and report metrics.

    Args:
        gt_dir: directory holding ground-truth maps named ``<name>.png``.
        pred_dir: directory holding prediction maps named ``<name>.png``.
        png_name_list: base names (without extension) to evaluate.
        num_classes: number of classes; label values outside [0, num_classes)
            are ignored by fast_hist.
        name_classes: optional class names. When given, progress is printed
            every 10 images and a per-class report is printed at the end.

    Returns:
        (hist as int64, IoUs, PA_Recall, Precision, OA, Recall, F1_scores, mf1).
        OA equals the overall pixel accuracy and Recall equals PA_Recall.

    Pairs whose pixel counts differ are reported and skipped.
    """
    print('Num classes', num_classes)
    hist = np.zeros((num_classes, num_classes))
    gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
    pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]
    for ind, (gt_path, pred_path) in enumerate(zip(gt_imgs, pred_imgs)):
        # Close the image files deterministically instead of leaving handles
        # to the garbage collector.
        with Image.open(pred_path) as pred_img:
            pred = np.array(pred_img)
        with Image.open(gt_path) as gt_img:
            label = np.array(gt_img)
        if label.size != pred.size:
            print('Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                label.size, pred.size, gt_path, pred_path))
            continue
        hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
        if name_classes is not None and ind > 0 and ind % 10 == 0:
            print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
                ind,
                len(gt_imgs),
                100 * np.nanmean(per_class_iu(hist)),
                100 * np.nanmean(per_class_PA_Recall(hist)),
                100 * per_Accuracy(hist)
            ))
    IoUs = per_class_iu(hist)
    PA_Recall = per_class_PA_Recall(hist)
    Precision = per_class_Precision(hist)
    Accuracy = per_Accuracy(hist)
    # OA and Recall are the same quantities as Accuracy and PA_Recall. Reuse the
    # guarded helpers instead of re-deriving them: the old inline OA divided by
    # an unguarded total and produced NaN when every pair was skipped.
    OA = Accuracy
    Recall = per_class_PA_Recall(hist)
    # The old formula 2PR / max(P + R, 1) under-reported F1 whenever P + R < 1.
    F1_scores = per_class_F1(hist)
    mf1 = np.nanmean(F1_scores)
    # Per-class report.
    if name_classes is not None:
        for ind_class in range(num_classes):
            print(f'===> {name_classes[ind_class]}: '
                  f'Iou-{round(IoUs[ind_class] * 100, 2)}%; '
                  f'PA-{round(PA_Recall[ind_class] * 100, 2)}%; '
                  f'Precision-{round(Precision[ind_class] * 100, 2)}%; '
                  f'Recall-{round(Recall[ind_class] * 100, 2)}%; '
                  f'F1-{round(F1_scores[ind_class] * 100, 2)}%')
    # Overall report.
    print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(
        round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(Accuracy * 100, 2)) +
          '; OA: ' + str(round(OA * 100, 2)) + '; mF1: ' + str(round(mf1 * 100, 2)))
    return np.array(hist, np.int64), IoUs, PA_Recall, Precision, OA, Recall, F1_scores, mf1
def adjust_axes(r, t, fig, axes):
    """Stretch the upper x-limit so the text artist `t` fits beside the plot.

    The rendered width of `t` (measured with renderer `r`) is divided by
    ``fig.dpi`` to get inches. The right x-limit is then scaled by
    (figure width + text width) / figure width.
    """
    text_inches = t.get_window_extent(renderer=r).width / fig.dpi
    fig_inches = fig.get_figwidth()
    scale = (fig_inches + text_inches) / fig_inches
    left, right = axes.get_xlim()
    axes.set_xlim([left, right * scale])
def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size=12, plt_show=True):
    """Draw `values` as a horizontal bar chart on the current figure and save it.

    There is one royal-blue bar per value, and the y tick labels come from
    `name_classes`. Each bar's value is written just past its end, with two
    decimals when the value is below 1.0. After the last label is placed, the
    x-axis is widened via adjust_axes, presumably so that label is not clipped.
    The figure is saved to `output_path`, shown if `plt_show`, then closed.
    """
    fig = plt.gcf()
    axes = plt.gca()
    positions = range(len(values))
    plt.barh(positions, values, color='royalblue')
    plt.title(plot_title, fontsize=tick_font_size + 2)
    plt.xlabel(x_label, fontsize=tick_font_size)
    plt.yticks(positions, name_classes, fontsize=tick_font_size)
    renderer = fig.canvas.get_renderer()
    last_index = len(values) - 1
    for idx, val in enumerate(values):
        label = " {0:.2f}".format(val) if val < 1.0 else " " + str(val)
        text_artist = plt.text(val, idx, label, color='royalblue', va='center', fontweight='bold')
        if idx == last_index:
            # Only the final label is measured; the x-limits are widened once.
            adjust_axes(renderer, text_artist, fig, axes)
    fig.tight_layout()
    fig.savefig(output_path)
    if plt_show:
        plt.show()
    plt.close()
def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size=12):
    """Save metric bar charts and the confusion matrix into `miou_out_path`.

    Writes mIoU.png, mPA.png, Recall.png and Precision.png via draw_plot_func.
    Only the first chart is shown interactively; the Recall chart reuses
    PA_Recall. Then writes confusion_matrix.csv with a header row of class
    names followed by one row of counts per class. Each saved path is printed.
    """
    # (values, title template, x label, file name, tag used in the log line, show?)
    chart_specs = [
        (IoUs, "mIoU = {0:.2f}%", "Intersection over Union", "mIoU.png", "mIoU", True),
        (PA_Recall, "mPA = {0:.2f}%", "Pixel Accuracy", "mPA.png", "mPA", False),
        (PA_Recall, "mRecall = {0:.2f}%", "Recall", "Recall.png", "Recall", False),
        (Precision, "mPrecision = {0:.2f}%", "Precision", "Precision.png", "Precision", False),
    ]
    for values, title_fmt, axis_label, file_name, tag, show in chart_specs:
        chart_path = os.path.join(miou_out_path, file_name)
        draw_plot_func(values, name_classes, title_fmt.format(np.nanmean(values) * 100), axis_label,
                       chart_path, tick_font_size=tick_font_size, plt_show=show)
        print("Save " + tag + " out to " + chart_path)

    csv_path = os.path.join(miou_out_path, "confusion_matrix.csv")
    with open(csv_path, 'w', newline='') as f:
        table = [[' '] + [str(c) for c in name_classes]]
        for i, counts in enumerate(hist):
            table.append([name_classes[i]] + [str(x) for x in counts])
        csv.writer(f).writerows(table)
    print("Save confusion_matrix out to " + csv_path)