ai-station-code/wudingpv/taihuyuan_roof/compared_experiment/unet/test/xaiotu.py

266 lines
9.4 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : xaiotu
@Author : qiqq
@create_time : 2023/6/29 8:57
"""
import matplotlib.pyplot as plt
import numpy as np
import copy
import cv2 as cv
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
import argparse
from torch.utils.data import Dataset
from torchvision import transforms
from taihuyuan_roof.dataloaders import custom_transforms as tr
import torch
import torch.nn.functional as F
from PIL import Image
class predictandeval():
    """Run a segmentation network over a VOC-style split and score its predictions.

    Dataset layout (relative to ``data_root``):
      * ``<split>.txt``            -- one sample name per line
      * ``images1/<name>.png``     -- input images
      * ``labels1/<name>.png``     -- single-channel class-id labels (255 = ignore)

    :meth:`get_result` writes palette PNG predictions to ``<model_name>_outputs/``
    and, when ``self.fusin`` is true, image/prediction overlays to
    ``<model_name>_fusion/``. :meth:`compute_evalue` compares the saved
    predictions against the labels.
    """

    def __init__(self, model_name="xx",
                 data_root="/home/qiqq/q3dl/datalinan/taihuyuan_roof/",
                 split="val", input_size=(512, 512), num_classes=2):
        """
        :param model_name: prefix for the output directories and overlay file names
        :param data_root: directory holding the split file, ``images1/`` and ``labels1/``
        :param split: split file stem (``<split>.txt``), e.g. ``"val"`` or ``"test"``
        :param input_size: network input size, passed to ``cv.resize`` as (width, height)
        :param num_classes: number of classes used when building the confusion matrix
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.predict_dir = model_name + "_" + "outputs/"
        self.fusion_path = model_name + "_" + "fusion/"
        self.model_name = model_name
        # Palette of the saved prediction PNGs: index 0 -> black, indices 1 and 2 -> green.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        # Whether get_result also writes overlay images (attribute name kept for compatibility).
        self.fusin = True
        self.data_root = data_root
        self.split = split
        self.input_size = tuple(input_size)
        self.num_classes = num_classes

    def get_confusion_matrix(self, gt_label, pred_label, class_num):
        """
        Calculate the confusion matrix for the given labels and predictions.

        :param gt_label: ground-truth class ids (array-like, same shape as ``pred_label``)
        :param pred_label: predicted class ids
        :param class_num: the number of classes
        :return: ``(class_num, class_num)`` float array; entry ``[i, j]`` counts the
            elements whose ground truth is ``i`` and prediction is ``j``. Pairs in which
            either id lies outside ``[0, class_num)`` are ignored.
        """
        # Widen before any arithmetic: with uint8 inputs ``gt * class_num`` wraps
        # around and would silently land in the wrong cell.
        gt = np.asarray(gt_label).astype(np.int64).ravel()
        pred = np.asarray(pred_label).astype(np.int64).ravel()
        # Without this mask an out-of-range prediction (e.g. gt=0, pred=class_num)
        # would collide with a valid flat index of the next row.
        valid = (gt >= 0) & (gt < class_num) & (pred >= 0) & (pred < class_num)
        flat_index = gt[valid] * class_num + pred[valid]
        counts = np.bincount(flat_index, minlength=class_num * class_num)
        return counts.reshape(class_num, class_num).astype(np.float64)

    # An earlier variant reused the training dataset class; the loaders below read
    # the VOC-style split file directly instead.
    def _read_split_names(self):
        """Return the non-blank sample names listed in ``<data_root>/<split>.txt``."""
        split_file = os.path.join(self.data_root, self.split + '.txt')
        with open(split_file, "r", encoding="utf-8") as f:
            return [line for line in f.read().splitlines() if line.strip()]

    @staticmethod
    def _load_png(path):
        """Read ``path`` fully into memory and close the file before returning.

        :raises FileNotFoundError: if ``path`` is not an existing file
        """
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        # copy() forces a full load, so no file handle outlives this call
        # (lazy Image.open handles can exhaust file descriptors on large splits).
        with Image.open(path) as img:
            return img.copy()

    def get_images(self):
        """Load every input image of the split.

        :return: list of dicts with keys ``name``, ``orininal_h``, ``orininal_w``
            (original height/width in pixels) and ``image`` (PIL image)
        """
        imglist = []
        for name in self._read_split_names():
            image = self._load_png(os.path.join(self.data_root, 'images1', name + ".png"))
            width, height = image.size
            imglist.append({"name": name, "orininal_h": height, "orininal_w": width, "image": image})
        print("共监测到{}张原始图像和标签".format(len(imglist)))
        return imglist

    def get_labels(self):
        """Load every label image of the split.

        :return: list of dicts with keys ``name`` and ``label`` (PIL image, not yet numpy)
        """
        labellist = [
            {"name": name,
             "label": self._load_png(os.path.join(self.data_root, 'labels1', name + ".png"))}
            for name in self._read_split_names()
        ]
        print("共监测到{}张标签".format(len(labellist)))
        return labellist

    def get_result(self, net):
        """Predict every image of the split and save the results to disk.

        Each image is converted to RGB, resized to ``self.input_size``, scaled to
        [0, 1] and fed to ``net``. The softmax probabilities are resized back to the
        image's original resolution before the argmax, so the saved PNGs line up
        pixel-for-pixel with the labels used by :meth:`compute_evalue`.

        :param net: module returning ``(1, C, H, W)`` logits, or a list/tuple whose
            first element is such a tensor (auxiliary heads are ignored). It is
            moved to ``self.device``; it is expected to be in eval mode already.
        :raises RuntimeError: if the split lists no images
        """
        imglist = self.get_images()
        if not imglist:
            raise RuntimeError("no images listed in split '{}'".format(self.split))

        os.makedirs(self.predict_dir, exist_ok=True)
        if self.fusin:
            os.makedirs(self.fusion_path, exist_ok=True)
        # Overlay colours per class id: 0 black, 1 red, 2 green.
        fusion_palette = np.array([(0, 0, 0), (255, 0, 0), (0, 255, 0)], np.uint8)
        # Move the model once instead of on every iteration.
        model = net.to(self.device)

        with torch.no_grad():
            for item in tqdm(imglist):
                name = item["name"]
                height, width = item["orininal_h"], item["orininal_w"]
                # The network takes 3 channels; PNGs may be RGBA, L or P.
                rgb = item["image"].convert("RGB")

                resized = cv.resize(np.array(rgb), dsize=self.input_size, interpolation=cv.INTER_LINEAR)
                chw = np.transpose(self.preprocess_input(np.array(resized, np.float32), md=False), (2, 0, 1))
                batch = torch.from_numpy(np.expand_dims(chw, 0)).to(device=self.device, dtype=torch.float32)

                out = model(batch)
                if isinstance(out, (list, tuple)):
                    out = out[0]  # several heads (e.g. auxiliary decoder): keep the main one
                prob = F.softmax(out, dim=1)  # softmax over the class axis
                if prob.shape[-2:] != (height, width):
                    prob = F.interpolate(prob, size=(height, width), mode="bilinear", align_corners=False)
                result = prob[0].argmax(dim=0).cpu().numpy().astype(np.uint8)

                # Class-id map saved as a palette PNG.
                output_im = Image.fromarray(result).convert('P')
                output_im.putpalette(self.palette)
                output_im.save(os.path.join(self.predict_dir, name + '.png'))

                # Blend the coloured prediction over the input image.
                if self.fusin:
                    overlay = Image.fromarray(fusion_palette[result])
                    fusion = Image.blend(rgb, overlay, 0.4)
                    fusion.save(os.path.join(self.fusion_path, self.model_name + "_" + name + '.png'))

    def preprocess_input(self, image, md=False):
        """Scale ``image`` to [0, 1] IN PLACE and optionally normalise it.

        :param image: float array; when ``md`` is true its last axis must have length 3
        :param md: if true, also subtract the per-channel ``mean`` and divide by ``std``
        :return: the same (modified) array object
        """
        # Per-channel statistics; the original note says they were computed for the Beijing dataset.
        mean = (0.231, 0.217, 0.22)
        std = (0.104, 0.086, 0.085)
        image /= 255.0
        if md:
            image -= mean
            image /= std
        return image

    @staticmethod
    def _metrics_from_confusion(confusion_matrix):
        """Derive segmentation metrics from a confusion matrix.

        Undefined ratios (zero denominators, e.g. a class that is never predicted)
        are reported as 0 instead of NaN. Class 0 is the background; the ``roof_*``
        values average over classes 1..N-1.

        :return: dict with per-class arrays ``iou``, ``precision``, ``recall``, ``f1``
            and scalars ``miou``, ``mf1``, ``acc_global``, ``roof_iou``,
            ``roof_precision``, ``roof_recall``, ``roof_f1``
        """
        def safe_div(num, den):
            num = np.asarray(num, dtype=np.float64)
            den = np.asarray(den, dtype=np.float64)
            return np.divide(num, den, out=np.zeros_like(num), where=den != 0)

        cm = np.asarray(confusion_matrix, dtype=np.float64)
        pos = cm.sum(1)  # ground-truth pixels per class (tp + fn)
        res = cm.sum(0)  # predicted pixels per class (tp + fp)
        tp = np.diag(cm)
        iou = tp / np.maximum(1.0, pos + res - tp)
        precision = safe_div(tp, res)
        recall = safe_div(tp, pos)
        f1 = safe_div(2 * precision * recall, precision + recall)
        total = pos.sum()
        acc_global = tp.sum() / total if total > 0 else 0.0
        return {
            "iou": iou,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "miou": iou.mean(),
            "mf1": f1.mean(),
            "acc_global": acc_global,
            "roof_iou": iou[1:].mean(),
            "roof_precision": precision[1:].mean(),
            "roof_recall": recall[1:].mean(),
            "roof_f1": f1[1:].mean(),
        }

    def compute_evalue(self, predict_dir):
        """Score the prediction PNGs in ``predict_dir`` against the split's labels.

        Label pixels equal to 255 are ignored. Metrics are printed and the
        roof-class scores are returned.

        :param predict_dir: directory containing one ``<name>.png`` class-id map per label
        :return: ``(roof_iou, roof_f1, roof_precison, roof_recall, acc_global)``
        :raises ValueError: if the number of files or an image/label size does not match
        """
        labellist = self.get_labels()  # PIL images, converted to numpy below
        predictresultlist = os.listdir(predict_dir)
        if len(labellist) != len(predictresultlist):
            raise ValueError("{} labels but {} files in {}".format(
                len(labellist), len(predictresultlist), predict_dir))

        num_classes = self.num_classes
        confusion_matrix = np.zeros((num_classes, num_classes))
        for item in tqdm(labellist):
            name = item["name"]
            seg_gt = np.array(item["label"])
            with Image.open(os.path.join(predict_dir, name + ".png")) as pred_im:
                seg_pred = np.array(pred_im)
            if seg_gt.shape != seg_pred.shape:
                raise ValueError("shape mismatch for {}: label {} vs prediction {}".format(
                    name, seg_gt.shape, seg_pred.shape))
            valid = seg_gt != 255  # 255 marks ignored pixels
            confusion_matrix += self.get_confusion_matrix(seg_gt[valid], seg_pred[valid], num_classes)

        m = self._metrics_from_confusion(confusion_matrix)
        print("测试结果")
        print("acc_global:", round(m["acc_global"], 4))
        print("IU_array:", np.round(m["iou"], 4))
        print("precion:", np.round(m["precision"], 4))
        print("recall:", np.round(m["recall"], 4))
        print("f1:", np.round(m["f1"], 4))
        print("miou:", round(m["miou"], 4))
        print("roof_iou:", round(m["roof_iou"], 4))
        print("roof_precison:", round(m["roof_precision"], 4))
        print("roof_recall:", round(m["roof_recall"], 4))
        print("roof_f1:", round(m["roof_f1"], 4))
        return m["roof_iou"], m["roof_f1"], m["roof_precision"], m["roof_recall"], m["acc_global"]
def main():
    """Evaluate the trained UNet checkpoint on the roof split and append the scores to a text file."""
    from taihuyuan_roof.compared_experiment.unet.model.unet_model import UNet

    model_path = "/home/qiqq/q3dl/codeR/applicationprojectR/taihuyuan_roof/compared_experiment/unet/train_val_test/checkpoints/2024-01-27-20.48/best.pth"
    resultsavepath = "./unetthyroof.txt"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = UNet(n_channels=3, n_classes=2)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(model_dict['net'])
    print("权重加载")
    model.eval()
    model.to(device)

    testss = predictandeval(model_name="unetthyroof")
    testss.get_result(net=model)
    # Evaluate exactly the directory get_result wrote to.
    roof_iou, roof_f1, roof_precison, roof_recall, acc_global = testss.compute_evalue(testss.predict_dir)

    # Explicit UTF-8: the report contains Chinese text.
    with open(resultsavepath, "a", encoding="utf-8") as f:
        f.write("测试结果\n")
        f.write(f"acc_global:{round(acc_global, 4)}\n")
        f.write(f"roof_iou:{round(roof_iou, 4)}\n")
        f.write(f"roof_f1:{round(roof_f1, 4)}\n")
        f.write(f"roof_precison:{round(roof_precison, 4)}\n")
        f.write(f"roof_recall:{round(roof_recall, 4)}\n")
    print("写入完成")


if __name__ == '__main__':
    main()