#!/usr/bin/env python # -*- coding: utf-8 -*- """ @project: @File : xaiotu @Author : qiqq @create_time : 2023/6/29 8:57 """ import matplotlib.pyplot as plt import numpy as np import copy import cv2 as cv from tqdm import tqdm import os from torch.utils.data import DataLoader import argparse from torch.utils.data import Dataset from torchvision import transforms from taihuyuan_pv.dataloaders import custom_transforms as tr import torch import torch.nn.functional as F from PIL import Image '''voc数据集格式的''' class datasets_pvtaihuyuan(Dataset): def __init__(self, args, split='val', isAug=False ): super(datasets_pvtaihuyuan, self).__init__() self.args = args self.resize = args.resize self.crop_size = args.crop_size # 是单数比如256, self.flip_prob = args.flip_prob # 是0-1 self.isAug = isAug self._base_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/" self._image_dir = os.path.join(self._base_dir, 'images') self._cat_dir = os.path.join(self._base_dir, 'labels') if isinstance(split, str): self.split = [split] else: split.sort() self.split = split _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/" self.im_ids = [] self.images = [] self.categories = [] for splt in self.split: with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: lines = f.read().splitlines() for ii, line in enumerate(lines): _image = os.path.join(self._image_dir, line + ".png") # 注意这个地方有的图片可能是jpg有的可能是png自己看着改 _cat = os.path.join(self._cat_dir, line + ".png") # 注意格式 assert os.path.isfile(_image) assert os.path.isfile(_cat) self.im_ids.append(line) self.images.append(_image) self.categories.append(_cat) assert (len(self.images) == len(self.categories)) # Display stats print('Number of images in {}: {:d}'.format(split, len(self.images))) def __len__(self): return len(self.images) def __getitem__(self, index): _img, _target, _name = self._make_img_gt_point_pair(index) sample = {'image': _img, 'label': _target, "name": _name} for split in self.split: if split == 'val': return self.transform_val(sample) def _make_img_gt_point_pair(self, index): _img = Image.open(self.images[index]).convert('RGB') _target = Image.open(self.categories[index]) name = self.images[index].split(".")[0] return _img, _target, name def transform_val(self, sample): composed_transforms = transforms.Compose([ tr.Resize(self.args.resize), tr.Normalize_simple(), tr.ToTensor()]) return composed_transforms(sample) def __str__(self): return 'linan(split=' + str(self.split) + ')' # class predictandeval(): def __init__(self, model_name="xx"): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.predict_dir = model_name + "_" + "outputs/" self.fusion_path = model_name + "_" + "fusion/" self.model_name = model_name self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0] self.fusin = True def get_confusion_matrix(self, gt_label, pred_label, class_num): """ Calcute the confusion matrix by given label and pred :param gt_label: the ground truth label :param pred_label: the pred label :param class_num: the nunber of class :return: the confusion matrix """ index = (gt_label * class_num + pred_label).astype('int32') label_count = np.bincount(index) confusion_matrix = np.zeros((class_num, class_num)) for i_label in range(class_num): for i_pred_label in range(class_num): cur_index = i_label * class_num + i_pred_label if cur_index < len(label_count): confusion_matrix[i_label, i_pred_label] = label_count[cur_index] return confusion_matrix # 延用训练数据集格式 # def get_datasets(self, ): # parser = argparse.ArgumentParser() # args = parser.parse_args() # args.resize = (512, 512) # args.crop_size = 480 # 在验证的时候没有用 crop和filp 只是为了不报错 # args.flip_prob = 0.5 # datasets = datasets_pvtaihuyuan(args, split='val', isAug=False) # print(len(datasets)) # return datasets # 重写 voc格式的val或者test def get_images(self, ): basedir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/" split = 'val' _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/" imglist = [] with open(os.path.join(os.path.join(_splits_dir, split + '.txt')), "r") as f: lines = f.read().splitlines() for ii, line in enumerate(lines): name = line _imagepath = os.path.join(basedir, 'images', line + ".png") assert os.path.isfile(_imagepath) image = Image.open(_imagepath) orininal_h = image.size[1] orininal_w = image.size[0] item = {"name": name, "orininal_h": orininal_h, "orininal_w": orininal_w, "image": image} imglist.append(item) print("共监测到{}张原始图像和标签".format(len(imglist))) return imglist def get_labels(self, ): basedir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/" split = 'val' _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/" labellist = [] with open(os.path.join(os.path.join(_splits_dir, split + '.txt')), "r") as f: lines = f.read().splitlines() for ii, line in enumerate(lines): name = line _labelpath = os.path.join(basedir, 'labels', line + ".png") assert os.path.isfile(_labelpath) label = Image.open(_labelpath) item = {"name": name, "label": label} labellist.append(item) print("共监测到{}张标签".format(len(labellist))) return labellist # def get_result1(self,net): # # datase = self.get_datasets1() # dataloader = DataLoader(datase, batch_size=1, shuffle=False, num_workers=0) # for ii, sample in enumerate(dataloader): # image, mask_true = sample['image'], sample['label'] # name=sample["name"] # image = image.to(device=self.device, dtype=torch.float32) # orininal_h, orininal_w =image.shape[2:] # with torch.no_grad(): # out = net(image) # # if isinstance(out, list) or isinstance(out, tuple): # 可能有多个输出(这里把辅助解码头的也输出的所以是多个) # out = out[0] # 就取第一个 # pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy() # result = pr.argmax(axis=-1) # if not os.path.exists(self.predict_dir): # os.makedirs(self.predict_dir) # output_im = Image.fromarray(np.uint8(result)).convert('P') # # output_im.putpalette(self.palette) # output_im.save(self.predict_dir + name + '.png') # # # hunhe #没有fusion def get_result(self, net): imglist = self.get_images() assert len(imglist) != 0 for i in tqdm(imglist): image = i["image"] name = i["name"] orininal_w = i["orininal_w"] orininal_h = i["orininal_h"] old_img = copy.deepcopy(image) imaged = cv.resize(np.array(image), dsize=(512, 512), interpolation=cv.INTER_LINEAR) image_data = np.expand_dims( np.transpose(self.preprocess_input(np.array(imaged, np.float32), md=False), (2, 0, 1)), 0) with torch.no_grad(): images = torch.from_numpy(image_data) image = images.to(device="cuda", dtype=torch.float32) model = net.to(device="cuda") out = model(image) if isinstance(out, list) or isinstance(out, tuple): # 可能有多个输出(这里把辅助解码头的也输出的所以是多个) out = out[0] # 就取第一个 out = out[0] # 去掉batch pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy() result = pr.argmax(axis=-1) # 结果图 if not os.path.exists(self.predict_dir): os.makedirs(self.predict_dir) output_im = Image.fromarray(np.uint8(result)).convert('P') output_im.putpalette(self.palette) output_im.save(self.predict_dir + name + '.png') # hunhe if self.fusin: if not os.path.exists(self.fusion_path): os.makedirs(self.fusion_path) PALETTE = [(0, 0, 0), (255, 0, 0), (0, 255, 0)] seg_img0 = np.reshape(np.array(PALETTE, np.uint8)[np.reshape(result, [-1])], [orininal_h, orininal_w, -1]) image0 = Image.fromarray(np.uint8(seg_img0)) fusion = Image.blend(old_img, image0, 0.4) fusion.save(self.fusion_path + self.model_name + "_" + name + '.png') def preprocess_input(self, image, md=False): mean = (0.231, 0.217, 0.22) # 针对北京的数据集 std = (0.104, 0.086, 0.085) if md: image /= 255.0 image -= mean image /= std return image else: image /= 255.0 return image def compute_evalue(self, predict_dir, ): labellist = self.get_labels() # 存的额是没转换成nmpy的png predictresultlist = os.listdir(predict_dir) assert len(labellist) == len(predictresultlist) num_classes = 2 confusion_matrix = np.zeros((num_classes, num_classes)) for i in tqdm(labellist): name = i["name"] seg_gt = np.array(i["label"]) seg_pred = np.array(Image.open(os.path.join(predict_dir, name + ".png"))) # ignore_index = seg_gt != 255 # seg_gt = seg_gt[ignore_index] seg_pred = seg_pred[ignore_index] confusion_matrix += self.get_confusion_matrix(seg_gt, seg_pred, num_classes) pos = confusion_matrix.sum(1) # 得到的每个数都是每个类别真实的像素点数量 (相当于tp+fn) res = confusion_matrix.sum(0) # 得到的每个数都是被预测为这个类别的像素点数量 (tp+fp) tp = np.diag(confusion_matrix) IU_array = (tp / np.maximum(1.0, pos + res - tp)) mean_IU = IU_array.mean() pv_iou = IU_array[1:].mean() precion = tp / res recall = tp / pos f1 = (2 * precion * recall) / (precion + recall) mf1 = f1.mean() acc_global = tp.sum() / pos.sum() pv_f1 = f1[1:].mean() pv_precion = precion[1:].mean() pv_recall = recall[1:].mean() print("测试结果") print("acc_global:", round(acc_global, 4)) print("IU_array:", np.round(IU_array, 4)) print("precion:", np.round(precion, 4)) print("recall:", np.round(recall, 4)) print("f1:", np.round(f1, 4)) print("miou:", round(mean_IU, 4)) print("pv_iou:", round(pv_iou, 4)) print("pv_precion:", round(pv_precion, 4)) print("pv_recall:", round(pv_recall, 4)) print("pv_f1:", round(pv_f1, 4)) return pv_iou, pv_f1, pv_precion, pv_recall, acc_global def get_args(): parser = argparse.ArgumentParser(description='Train the UNet on images and target masks') parser.add_argument('--epochs', '-e', metavar='E', type=int, default=200, help='Number of epochs') parser.add_argument('--train_batch-size', '-tb', dest='train_batch_size', metavar='TB', type=int, default=8, help='Train_Batch size') parser.add_argument('--val_batch-size', '-vb', dest='val_batch_size', metavar='VB', type=int, default=1, help='Val_Batch size') parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-3, help='Learning rate', dest='lr') parser.add_argument('--load', '-f', type=str, default="/home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth", help='Load model from a .pth file') # 有没有预训练。。 parser.add_argument('--ignore_index', '-i', type=int, dest='ignore_index', default=255, help='ignore index defult 100') parser.add_argument('--origin_shape', action='store_true', default=(512,512), help='原始输入尺寸') parser.add_argument('--resume', '-r', type=str, default="", help='is use Resume') parser.add_argument('--useDice', '-ud', type=str, default=False, help='训练的时候是否使用dice') parser.add_argument('--valIndex', '-vI', type=str, default=["Valloss","IouMiouP"], help='评价指标要使用哪些,注意IouMiouP= acc_global, acc, iu,precion,recall,f1,miou,并且建议旨在2分类的时候用dice否则会出错') parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes') parser.add_argument('--num_classes', type=int, default=2, help='output channel of network') parser.add_argument('--img_size', type=int, default=512, help='input patch size of network input') parser.add_argument('--cfg', type=str, default="/home/qiqq/q3dl/code/rooftoprecognition/pv_recognition/compared_experiment/mySwinUnet/configs/swin_tiny_patch4_window7_224_lite.yaml", metavar="FILE", help='path to config file', ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+', ) return parser.parse_args() if __name__ == '__main__': from taihuyuan_pv.compared_experiment.pspnet.model.pspnet import PSPNet model = PSPNet(num_classes=2,downsample_factor=8,) model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/compared_experiment/pspnet/train_val_test/checkpoints/2023-08-09-10.24/best.pth" model_dict = torch.load(model_path) model.load_state_dict(model_dict['net']) print("权重加载") model.eval() model.cuda() testss = predictandeval(model_name="pspnetthypv") testss.get_result(net=model) predictoutpath = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/compared_experiment/pspnet/test/pspnetthypv_outputs/" pv_iou, pv_f1, pv_precion, pv_recall, acc_global = testss.compute_evalue(predictoutpath) resultsavepath = "./pspnetthypv.txt" with open(resultsavepath, "a") as f: f.write("测试结果\n") f.write(f"acc_global:{round(acc_global, 4)}\n") f.write(f"pv_iou:{round(pv_iou, 4)}\n") f.write(f"pv_f1:{round(pv_f1, 4)}\n") f.write(f"pv_precion:{round(pv_precion, 4)}\n") f.write(f"pv_recall:{round(pv_recall, 4)}\n") print("写入完成")