ai-station-code/wudingpv/taihuyuan_roof/compared_experiment/imdeeplab3p/test/xaiotu.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : xaiotu
@Author : qiqq
@create_time : 2023/6/29 8:57
"""
import numpy as np
import copy
import cv2 as cv
from tqdm import tqdm
import os
import argparse
import torch
import torch.nn.functional as F
from PIL import Image
'''For datasets in VOC format.'''
class predictandeval():
def __init__(self, model_name="xx"):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.predict_dir = model_name + "_" + "outputs/"
self.fusion_path = model_name + "_" + "fusion/"
self.model_name = model_name
self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
self.fusin = True
def get_confusion_matrix(self, gt_label, pred_label, class_num):
"""
Calcute the confusion matrix by given label and pred
:param gt_label: the ground truth label
:param pred_label: the pred label
:param class_num: the nunber of class
:return: the confusion matrix
"""
index = (gt_label * class_num + pred_label).astype('int32')
label_count = np.bincount(index)
confusion_matrix = np.zeros((class_num, class_num))
for i_label in range(class_num):
for i_pred_label in range(class_num):
cur_index = i_label * class_num + i_pred_label
if cur_index < len(label_count):
confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
return confusion_matrix
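# Worked example of the index encoding above (illustrative only): with class_num = 2,
# a pixel whose ground truth is class 1 but is predicted as class 0 maps to
# index 1 * 2 + 0 = 2, so np.bincount counts it at position 2, which the loop
# writes into confusion_matrix[1, 0] (a missed class-1 pixel).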
# Reuses the training dataset format (kept for reference)
# def get_datasets(self, ):
# parser = argparse.ArgumentParser()
# args = parser.parse_args()
# args.resize = (512, 512)
# args.crop_size = 480 # crop and flip are not used during validation; set only to avoid errors
# args.flip_prob = 0.5
# datasets = datasets_pvtaihuyuan(args, split='val', isAug=False)
# print(len(datasets))
# return datasets
# Rewritten below: val/test loading for VOC-format data
def get_images(self, ):
basedir = "/home/qiqq/q3dl/datalinan/taihuyuan1/"
split = 'val'
_splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan1/"
imglist = []
with open(os.path.join(_splits_dir, split + '.txt'), "r") as f:
lines = f.read().splitlines()
for ii, line in enumerate(lines):
name = line
_imagepath = os.path.join(basedir, 'images1', line + ".png")
assert os.path.isfile(_imagepath)
image = Image.open(_imagepath)
orininal_h = image.size[1]
orininal_w = image.size[0]
item = {"name": name, "orininal_h": orininal_h, "orininal_w": orininal_w, "image": image}
imglist.append(item)
print("共监测到{}张原始图像和标签".format(len(imglist)))
return imglist
def get_labels(self, ):
basedir = "/home/qiqq/q3dl/datalinan/taihuyuan1/"
split = 'val'
_splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan1/"
labellist = []
with open(os.path.join(_splits_dir, split + '.txt'), "r") as f:
lines = f.read().splitlines()
for ii, line in enumerate(lines):
name = line
_labelpath = os.path.join(basedir, 'labels1', line + ".png")
assert os.path.isfile(_labelpath)
label = Image.open(_labelpath)
item = {"name": name, "label": label}
labellist.append(item)
print("共监测到{}张标签".format(len(labellist)))
return labellist
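# Assumed directory layout, inferred from the hard-coded paths above (adjust to your data):
# taihuyuan1/
#   val.txt            one sample name per line, without extension
#   images1/<name>.png
#   labels1/<name>.png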
def get_result(self, net):
imglist = self.get_images()
assert len(imglist) != 0
for i in tqdm(imglist):
image = i["image"]
name = i["name"]
orininal_w = i["orininal_w"]
orininal_h = i["orininal_h"]
old_img = copy.deepcopy(image)
imaged = cv.resize(np.array(image), dsize=(512, 512), interpolation=cv.INTER_LINEAR)
image_data = np.expand_dims(
np.transpose(self.preprocess_input(np.array(imaged, np.float32), md=False), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
image = images.to(self.device, dtype=torch.float32)
model = net.to(self.device)
out = model(image)
if isinstance(out, (list, tuple)):  # the model may return several outputs (the auxiliary decoder head is also returned here)
out = out[0]  # keep only the main output
out = out[0]  # drop the batch dimension
pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
result = pr.argmax(axis=-1)
# resize the prediction back to the original image size so it lines up with the label and the source image
result = cv.resize(np.uint8(result), dsize=(orininal_w, orininal_h), interpolation=cv.INTER_NEAREST)
# prediction map
if not os.path.exists(self.predict_dir):
os.makedirs(self.predict_dir)
output_im = Image.fromarray(np.uint8(result)).convert('P')
output_im.putpalette(self.palette)
output_im.save(self.predict_dir + name + '.png')
# blend the prediction with the original image
if self.fusin:
if not os.path.exists(self.fusion_path):
os.makedirs(self.fusion_path)
PALETTE = [(0, 0, 0), (255, 0, 0), (0, 255, 0)]
seg_img0 = np.reshape(np.array(PALETTE, np.uint8)[np.reshape(result, [-1])],
[orininal_h, orininal_w, -1])
image0 = Image.fromarray(np.uint8(seg_img0))
fusion = Image.blend(old_img, image0, 0.4)
fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')
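# Note: Image.blend(a, b, alpha) computes a * (1 - alpha) + b * alpha per pixel,
# so alpha = 0.4 keeps the source image dominant with the colored mask overlaid.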
def preprocess_input(self, image, md=False):
mean = (0.231, 0.217, 0.22)  # statistics computed for the Beijing dataset
std = (0.104, 0.086, 0.085)
if md:
image /= 255.0
image -= mean
image /= std
return image
else:
image /= 255.0
return image
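# When md=True this is the usual (x / 255 - mean) / std normalization; a rough
# torchvision equivalent (illustrative sketch only, not used by this script) would be
# transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]).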
def compute_evalue(self, predict_dir, ):
labellist = self.get_labels()  # holds the label PNGs before conversion to numpy
predictresultlist = os.listdir(predict_dir)
assert len(labellist) == len(predictresultlist)
num_classes = 2
confusion_matrix = np.zeros((num_classes, num_classes))
for i in tqdm(labellist):
name = i["name"]
seg_gt = np.array(i["label"])
seg_pred = np.array(Image.open(os.path.join(predict_dir, name + ".png")))
# mask out ignored pixels (255 is the ignore label)
ignore_index = seg_gt != 255
seg_gt = seg_gt[ignore_index]
seg_pred = seg_pred[ignore_index]
confusion_matrix += self.get_confusion_matrix(seg_gt, seg_pred, num_classes)
pos = confusion_matrix.sum(1)  # per-class count of ground-truth pixels (tp + fn)
res = confusion_matrix.sum(0)  # per-class count of pixels predicted as that class (tp + fp)
tp = np.diag(confusion_matrix)
IU_array = (tp / np.maximum(1.0, pos + res - tp))
mean_IU = IU_array.mean()
roof_iou = IU_array[1:].mean()
precion = tp / res
recall = tp / pos
f1 = (2 * precion * recall) / (precion + recall)
mf1 = f1.mean()
acc_global = tp.sum() / pos.sum()
roof_f1 = f1[1:].mean()
roof_precison = precion[1:].mean()
roof_recall = recall[1:].mean()
print("测试结果")
print("acc_global:", round(acc_global, 4))
print("IU_array:", np.round(IU_array, 4))
print("precion:", np.round(precion, 4))
print("recall:", np.round(recall, 4))
print("f1:", np.round(f1, 4))
print("miou:", round(mean_IU, 4))
print("roof_iou:", round(roof_iou, 4))
print("roof_precison:", round(roof_precison, 4))
print("roof_recall:", round(roof_recall, 4))
print("roof_f1:", round(roof_f1, 4))
return roof_iou, roof_f1, roof_precison, roof_recall, acc_global
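# For reference, the metrics above follow the standard confusion-matrix definitions:
#   IoU_c       = tp_c / (tp_c + fp_c + fn_c)
#   precision_c = tp_c / (tp_c + fp_c)
#   recall_c    = tp_c / (tp_c + fn_c)
#   F1_c        = 2 * precision_c * recall_c / (precision_c + recall_c)
# The roof_* values average these over the non-background classes (index 1 and up).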
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
parser.add_argument('--epochs', '-e', metavar='E', type=int, default=200, help='Number of epochs')
parser.add_argument('--train_batch-size', '-tb', dest='train_batch_size', metavar='TB', type=int, default=8,
help='Train_Batch size')
parser.add_argument('--val_batch-size', '-vb', dest='val_batch_size', metavar='VB', type=int, default=1,
help='Val_Batch size')
parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-3,
help='Learning rate', dest='lr')
parser.add_argument('--load', '-f', type=str,
default="/home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth",
help='Load model from a .pth file')  # optional pretrained weights
parser.add_argument('--ignore_index', '-i', type=int, dest='ignore_index', default=255,
help='ignore index, default 255')
parser.add_argument('--origin_shape', action='store_true', default=(512, 512), help='original input size')
parser.add_argument('--resume', '-r', type=str, default="", help='is use Resume')
parser.add_argument('--useDice', '-ud', type=str, default=False, help='whether to use Dice loss during training')
parser.add_argument('--valIndex', '-vI', type=str, default=["Valloss", "IouMiouP"], help='which evaluation metrics to use; note IouMiouP = acc_global, acc, iu, precision, recall, f1, miou; Dice is only recommended for binary segmentation, otherwise it may fail')
parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
parser.add_argument('--num_classes', type=int,
default=2, help='output channel of network')
parser.add_argument('--img_size', type=int,
default=512, help='input patch size of network input')
parser.add_argument('--cfg', type=str,
default="/home/qiqq/q3dl/code/rooftoprecognition/pv_recognition/compared_experiment/mySwinUnet/configs/swin_tiny_patch4_window7_224_lite.yaml",
metavar="FILE", help='path to config file', )
parser.add_argument(
"--opts",
help="Modify config options by adding 'KEY VALUE' pairs. ",
default=None,
nargs='+',
)
return parser.parse_args()
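# Note: get_args() is not called in the __main__ block below; it appears to be kept
# for parity with the training scripts, so most of these options have no effect here.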
if __name__ == '__main__':
from taihuyuan_roof.compared_experiment.imdeeplab3p.model.modeling import imdeeplabv3plus_resnet50
model = imdeeplabv3plus_resnet50(num_classes=2,output_stride=8,pretrained_backbone=True)
model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_roof/imdeeplab3p/train_val_test/checkpoints/2023-06-28-19.43/best.pth"
model_dict = torch.load(model_path)
model.load_state_dict(model_dict['net'])
print("权重加载")
model.eval()
model.cuda()
testss = predictandeval(model_name="d3pmlcamamthyroof")
testss.get_result(net=model)
predictoutpath = "./d3pmlcamamthyroof_outputs"
roof_iou, roof_f1, roof_precison, roof_recall, acc_global = testss.compute_evalue(predictoutpath)
resultsavepath = "./d3pmlcamamthyroof.txt"
with open(resultsavepath, "a") as f:
f.write("测试结果\n")
f.write("上一个论文的算法(多级上下文聚合于注意力机制)\n")
f.write(f"acc_global:{round(acc_global, 4)}\n")
f.write(f"roof_iou:{round(roof_iou, 4)}\n")
f.write(f"roof_f1:{round(roof_f1, 4)}\n")
f.write(f"roof_precison:{round(roof_precison, 4)}\n")
f.write(f"roof_recall:{round(roof_recall, 4)}\n")
print("写入完成")