import csv import os from os.path import join import matplotlib.pyplot as plt import numpy as np import torch import torch.nn.functional as F from PIL import Image def f_score(inputs, target, beta=1, smooth=1e-5, threhold=0.5): n, c, h, w = inputs.size() nt, ht, wt, ct = target.size() if h != ht and w != wt: inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c), -1) temp_target = target.view(n, -1, ct) # --------------------------------------------# # 计算dice系数 # --------------------------------------------# temp_inputs = torch.gt(temp_inputs, threhold).float() tp = torch.sum(temp_target[..., :-1] * temp_inputs, axis=[0, 1]) fp = torch.sum(temp_inputs, axis=[0, 1]) - tp fn = torch.sum(temp_target[..., :-1], axis=[0, 1]) - tp score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) score = torch.mean(score) return score # 设标签宽W,长H def fast_hist(a, b, n): # --------------------------------------------------------------------------------# # a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,) # --------------------------------------------------------------------------------# k = (a >= 0) & (a < n) # --------------------------------------------------------------------------------# # np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n) # 返回中,写对角线上的为分类正确的像素点 # --------------------------------------------------------------------------------# return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) def per_class_iu(hist): return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) def per_class_PA_Recall(hist): return np.diag(hist) / np.maximum(hist.sum(1), 1) def per_class_Precision(hist): return np.diag(hist) / np.maximum(hist.sum(0), 1) def per_Accuracy(hist): return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1) def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None): print('Num classes', num_classes) hist = np.zeros((num_classes, num_classes)) gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list] pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list] for ind in range(len(gt_imgs)): pred = np.array(Image.open(pred_imgs[ind])) label = np.array(Image.open(gt_imgs[ind])) if len(label.flatten()) != len(pred.flatten()): print('Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(len(label.flatten()), len(pred.flatten()), gt_imgs[ind], pred_imgs[ind])) continue hist += fast_hist(label.flatten(), pred.flatten(), num_classes) if name_classes is not None and ind > 0 and ind % 10 == 0: print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format( ind, len(gt_imgs), 100 * np.nanmean(per_class_iu(hist)), 100 * np.nanmean(per_class_PA_Recall(hist)), 100 * per_Accuracy(hist) )) IoUs = per_class_iu(hist) PA_Recall = per_class_PA_Recall(hist) Precision = per_class_Precision(hist) Accuracy = per_Accuracy(hist) # Calculate OA, Recall, and mf1 (mean F1 score) for each class OA = np.sum(np.diag(hist)) / np.sum(hist) # Overall Accuracy Recall = np.diag(hist) / np.maximum(hist.sum(1), 1) # Recall for each class Precision_for_f1 = np.diag(hist) / np.maximum(hist.sum(0), 1) # Precision for each class F1_scores = 2 * (Precision_for_f1 * Recall) / np.maximum((Precision_for_f1 + Recall), 1) # F1 for each class mf1 = np.nanmean(F1_scores) # mean F1 score # Print per-class results including OA, Recall, and mf1 if name_classes is not None: for ind_class in range(num_classes): print(f'===> {name_classes[ind_class]}: ' f'Iou-{round(IoUs[ind_class] * 100, 2)}%; ' f'PA-{round(PA_Recall[ind_class] * 100, 2)}%; ' f'Precision-{round(Precision[ind_class] * 100, 2)}%; ' f'Recall-{round(Recall[ind_class] * 100, 2)}%; ' f'F1-{round(F1_scores[ind_class] * 100, 2)}%') # Print overall results print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str( round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(Accuracy * 100, 2)) + '; OA: ' + str(round(OA * 100, 2)) + '; mF1: ' + str(round(mf1 * 100, 2))) return np.array(hist, np.int64), IoUs, PA_Recall, Precision, OA, Recall, F1_scores, mf1 def adjust_axes(r, t, fig, axes): bb = t.get_window_extent(renderer=r) text_width_inches = bb.width / fig.dpi current_fig_width = fig.get_figwidth() new_fig_width = current_fig_width + text_width_inches propotion = new_fig_width / current_fig_width x_lim = axes.get_xlim() axes.set_xlim([x_lim[0], x_lim[1] * propotion]) def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size=12, plt_show=True): fig = plt.gcf() axes = plt.gca() plt.barh(range(len(values)), values, color='royalblue') plt.title(plot_title, fontsize=tick_font_size + 2) plt.xlabel(x_label, fontsize=tick_font_size) plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size) r = fig.canvas.get_renderer() for i, val in enumerate(values): str_val = " " + str(val) if val < 1.0: str_val = " {0:.2f}".format(val) t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold') if i == (len(values) - 1): adjust_axes(r, t, fig, axes) fig.tight_layout() fig.savefig(output_path) if plt_show: plt.show() plt.close() def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size=12): draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs) * 100), "Intersection over Union", \ os.path.join(miou_out_path, "mIoU.png"), tick_font_size=tick_font_size, plt_show=True) print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png")) draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Pixel Accuracy", \ os.path.join(miou_out_path, "mPA.png"), tick_font_size=tick_font_size, plt_show=False) print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png")) draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Recall", \ os.path.join(miou_out_path, "Recall.png"), tick_font_size=tick_font_size, plt_show=False) print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision) * 100), "Precision", \ os.path.join(miou_out_path, "Precision.png"), tick_font_size=tick_font_size, plt_show=False) print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: writer = csv.writer(f) writer_list = [] writer_list.append([' '] + [str(c) for c in name_classes]) for i in range(len(hist)): writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) writer.writerows(writer_list) print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))