#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : xaiotu
@Author : qiqq
@create_time : 2023/6/29 8:57
"""
import os

# Must be set before torch initializes CUDA, hence before the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import argparse
import copy

import cv2 as cv
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

from taihuyuan_pv.dataloaders import custom_transforms as tr


class datasets_pvtaihuyuan(Dataset):
    """VOC-format dataset."""

    def __init__(self, args, split='val', isAug=False):
        super(datasets_pvtaihuyuan, self).__init__()
        self.args = args
        self.resize = args.resize
        self.crop_size = args.crop_size  # a single number, e.g. 256
        self.flip_prob = args.flip_prob  # in [0, 1]
        self.isAug = isAug

        self._base_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        self._image_dir = os.path.join(self._base_dir, 'images')
        self._cat_dir = os.path.join(self._base_dir, 'labels')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):
                # Note: some images may be .jpg and some .png; adjust the suffix as needed.
                _image = os.path.join(self._image_dir, line + ".png")
                _cat = os.path.join(self._cat_dir, line + ".png")  # mind the format
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        _img, _target, _name = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target, "name": _name}

        # Only the 'val' transform pipeline is implemented here.
        for split in self.split:
            if split == 'val':
                return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])
        # Use the file stem as the sample name; splitting the full path on "."
        # would keep the directory prefix and break the later save paths.
        name = os.path.splitext(os.path.basename(self.images[index]))[0]
        return _img, _target, name

    def transform_val(self, sample):
        composed_transforms = transforms.Compose([
            tr.Resize(self.args.resize),
            tr.Normalize_simple(),
            tr.ToTensor()])
        return composed_transforms(sample)

    def __str__(self):
        return 'linan(split=' + str(self.split) + ')'


class predictandeval():
    def __init__(self, model_name="xx"):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.predict_dir = model_name + "_" + "outputs/"
        self.fusion_path = model_name + "_" + "fusion/"
        self.model_name = model_name
        # Flat PIL palette: class 0 -> black, classes 1 and 2 -> green.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        self.fusion = True

    def get_confusion_matrix(self, gt_label, pred_label, class_num):
        """
        Calculate the confusion matrix from the given labels and predictions.
        :param gt_label: the ground-truth label
        :param pred_label: the predicted label
        :param class_num: the number of classes
        :return: the confusion matrix
        """
        index = (gt_label * class_num + pred_label).astype('int32')
        label_count = np.bincount(index)
        confusion_matrix = np.zeros((class_num, class_num))

        for i_label in range(class_num):
            for i_pred_label in range(class_num):
                cur_index = i_label * class_num + i_pred_label
                if cur_index < len(label_count):
                    confusion_matrix[i_label, i_pred_label] = label_count[cur_index]

        return confusion_matrix
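    # Worked example (an added illustration, not called by the pipeline): with
    # class_num = 2, gt = [0, 0, 1, 1] and pred = [0, 1, 1, 1], the flattened
    # index gt * 2 + pred is [0, 1, 3, 3], np.bincount gives [1, 1, 0, 2], and
    # the matrix is [[1, 1], [0, 2]]: rows are ground truth, columns are
    # predictions, so class 1 has TP = 2, FN = 0, and one FP leaked from class 0.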
    # Reuse the training dataset format.
    # def get_datasets(self, ):
    #     parser = argparse.ArgumentParser()
    #     args = parser.parse_args()
    #     args.resize = (512, 512)
    #     args.crop_size = 480  # crop and flip are unused at validation time; set only to avoid errors
    #     args.flip_prob = 0.5
    #     datasets = datasets_pvtaihuyuan(args, split='val', isAug=False)
    #     print(len(datasets))
    #     return datasets

    # Rewritten: VOC-format val or test split.
    def get_images(self, ):
        basedir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        split = 'val'
        _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        imglist = []
        with open(os.path.join(_splits_dir, split + '.txt'), "r") as f:
            lines = f.read().splitlines()
        for ii, line in enumerate(lines):
            name = line
            _imagepath = os.path.join(basedir, 'images', line + ".png")
            assert os.path.isfile(_imagepath)
            image = Image.open(_imagepath)
            original_h = image.size[1]
            original_w = image.size[0]
            item = {"name": name, "original_h": original_h, "original_w": original_w, "image": image}
            imglist.append(item)
        print("Found {} source images".format(len(imglist)))
        return imglist

    def get_labels(self, ):
        basedir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        split = 'val'
        _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        labellist = []
        with open(os.path.join(_splits_dir, split + '.txt'), "r") as f:
            lines = f.read().splitlines()
        for ii, line in enumerate(lines):
            name = line
            _labelpath = os.path.join(basedir, 'labels', line + ".png")
            assert os.path.isfile(_labelpath)
            label = Image.open(_labelpath)
            item = {"name": name, "label": label}
            labellist.append(item)
        print("Found {} labels".format(len(labellist)))
        return labellist

    # Variant without fusion, driven by the Dataset class above.
    # def get_result1(self, net):
    #     datase = self.get_datasets1()
    #     dataloader = DataLoader(datase, batch_size=1, shuffle=False, num_workers=0)
    #     for ii, sample in enumerate(dataloader):
    #         image, mask_true = sample['image'], sample['label']
    #         name = sample["name"]
    #         image = image.to(device=self.device, dtype=torch.float32)
    #         original_h, original_w = image.shape[2:]
    #         with torch.no_grad():
    #             out = net(image)
    #             # The model may return several outputs (the auxiliary decode head is returned too),
    #             if isinstance(out, (list, tuple)):
    #                 out = out[0]  # so just take the first one.
    #             pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
    #             result = pr.argmax(axis=-1)
    #         if not os.path.exists(self.predict_dir):
    #             os.makedirs(self.predict_dir)
    #         output_im = Image.fromarray(np.uint8(result)).convert('P')
    #         output_im.putpalette(self.palette)
    #         output_im.save(self.predict_dir + name + '.png')

    # Inference: save prediction masks and, if enabled, blended (hunhe) overlays.
    def get_result(self, net):
        imglist = self.get_images()
        assert len(imglist) != 0
        for i in tqdm(imglist):
            image = i["image"]
            name = i["name"]
            original_w = i["original_w"]
            original_h = i["original_h"]
            old_img = copy.deepcopy(image)
            imaged = cv.resize(np.array(image), dsize=(512, 512), interpolation=cv.INTER_LINEAR)
            image_data = np.expand_dims(
                np.transpose(self.preprocess_input(np.array(imaged, np.float32), md=False), (2, 0, 1)), 0)

            with torch.no_grad():
                images = torch.from_numpy(image_data)
                image = images.to(device="cuda", dtype=torch.float32)
                model = net.to(device="cuda")
                out = model(image)
                # The model may return several outputs (the auxiliary decode head is returned too),
                if isinstance(out, (list, tuple)):
                    out = out[0]  # so just take the first one.
                out = out[0]  # drop the batch dimension
                pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
                # Resize the probabilities back to the original resolution before the
                # argmax, so the saved mask matches the label size (the fusion reshape
                # below assumes original_h x original_w); a no-op for 512x512 inputs.
                pr = cv.resize(pr, dsize=(original_w, original_h), interpolation=cv.INTER_LINEAR)
                result = pr.argmax(axis=-1)

            # Prediction mask
            if not os.path.exists(self.predict_dir):
                os.makedirs(self.predict_dir)
            output_im = Image.fromarray(np.uint8(result)).convert('P')
            output_im.putpalette(self.palette)
            output_im.save(self.predict_dir + name + '.png')

            # Blend (hunhe) the prediction with the source image
            if self.fusion:
                if not os.path.exists(self.fusion_path):
                    os.makedirs(self.fusion_path)
                PALETTE = [(0, 0, 0), (255, 0, 0), (0, 255, 0)]
                seg_img0 = np.reshape(np.array(PALETTE, np.uint8)[np.reshape(result, [-1])],
                                      [original_h, original_w, -1])
                image0 = Image.fromarray(np.uint8(seg_img0))
                fusion = Image.blend(old_img, image0, 0.4)
                fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')
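    # Hypothetical helper (an added sketch, not called anywhere in this file):
    # a minimal check of the palette-indexing trick used in get_result, where a
    # (H, W) class-id mask indexes an (N, 3) palette array to give an (H, W, 3)
    # RGB image via numpy fancy indexing.
    @staticmethod
    def _palette_demo():
        palette = np.array([(0, 0, 0), (255, 0, 0), (0, 255, 0)], np.uint8)
        mask = np.array([[0, 1],
                         [2, 0]])  # toy 2x2 class-id mask
        rgb = palette[mask]        # fancy indexing -> shape (2, 2, 3)
        assert rgb.shape == (2, 2, 3)
        assert (rgb[0, 1] == (255, 0, 0)).all()  # class 1 is rendered red
        return rgb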
    def preprocess_input(self, image, md=False):
        mean = (0.231, 0.217, 0.22)  # statistics for the Beijing dataset
        std = (0.104, 0.086, 0.085)
        if md:
            image /= 255.0
            image -= mean
            image /= std
            return image
        else:
            image /= 255.0
            return image

    def compute_evalue(self, predict_dir, ):
        labellist = self.get_labels()  # PNG labels, not yet converted to numpy
        predictresultlist = os.listdir(predict_dir)
        assert len(labellist) == len(predictresultlist)
        num_classes = 2
        confusion_matrix = np.zeros((num_classes, num_classes))
        for i in tqdm(labellist):
            name = i["name"]
            seg_gt = np.array(i["label"])
            seg_pred = np.array(Image.open(os.path.join(predict_dir, name + ".png")))
            # Mask out ignored pixels (value 255) before accumulating.
            ignore_index = seg_gt != 255
            seg_gt = seg_gt[ignore_index]
            seg_pred = seg_pred[ignore_index]
            confusion_matrix += self.get_confusion_matrix(seg_gt, seg_pred, num_classes)

        pos = confusion_matrix.sum(1)  # per class: ground-truth pixel count (TP + FN)
        res = confusion_matrix.sum(0)  # per class: predicted pixel count (TP + FP)
        tp = np.diag(confusion_matrix)
        IU_array = (tp / np.maximum(1.0, pos + res - tp))
        mean_IU = IU_array.mean()
        pv_iou = IU_array[1:].mean()
        precision = tp / res
        recall = tp / pos
        f1 = (2 * precision * recall) / (precision + recall)
        mf1 = f1.mean()
        acc_global = tp.sum() / pos.sum()
        pv_f1 = f1[1:].mean()
        pv_precision = precision[1:].mean()
        pv_recall = recall[1:].mean()

        print("Test results")
        print("acc_global:", round(acc_global, 4))
        print("IU_array:", np.round(IU_array, 4))
        print("precision:", np.round(precision, 4))
        print("recall:", np.round(recall, 4))
        print("f1:", np.round(f1, 4))
        print("miou:", round(mean_IU, 4))
        print("pv_iou:", round(pv_iou, 4))
        print("pv_precision:", round(pv_precision, 4))
        print("pv_recall:", round(pv_recall, 4))
        print("pv_f1:", round(pv_f1, 4))

        return pv_iou, pv_f1, pv_precision, pv_recall, acc_global
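# Hypothetical sanity check (an added sketch, not part of the original pipeline):
# derives the same metrics as compute_evalue from a toy 2x2 confusion matrix,
# i.e. IoU = TP / (TP + FP + FN), precision = TP / (TP + FP),
# recall = TP / (TP + FN), F1 = 2 * P * R / (P + R).
def _metrics_demo():
    cm = np.array([[90.0, 10.0],   # rows = ground truth, columns = prediction
                   [5.0, 95.0]])   # toy counts, not real data
    pos = cm.sum(1)                # TP + FN per class
    res = cm.sum(0)                # TP + FP per class
    tp = np.diag(cm)
    iou = tp / np.maximum(1.0, pos + res - tp)  # [90/105, 95/110]
    precision = tp / res
    recall = tp / pos
    f1 = 2 * precision * recall / (precision + recall)
    return iou, precision, recall, f1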
default="/home/qiqq/q3dl/code/rooftoprecognition/pv_recognition/compared_experiment/mySwinUnet/configs/swin_tiny_patch4_window7_224_lite.yaml", metavar="FILE", help='path to config file', ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+', ) return parser.parse_args() if __name__ == '__main__': from taihuyuan_pv.compared_experiment.imdeeplab3p.model.modeling import imdeeplabv3plus_resnet50 model = imdeeplabv3plus_resnet50(num_classes=2,output_stride=8,pretrained_backbone=True) model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/imdeeplab3p/train_val_test/checkpoints/2023-07-06-09.51//best.pth" model_dict = torch.load(model_path) model.load_state_dict(model_dict['net']) print("权重加载") model.eval() model.cuda() testss = predictandeval(model_name="d3pmlcamam") testss.get_result(net=model) predictoutpath = "./d3pmlcamam_outputs/" pv_iou, pv_f1, pv_precion, pv_recall, acc_global = testss.compute_evalue(predictoutpath) resultsavepath = "./d3pmlcamam.txt" with open(resultsavepath, "a") as f: f.write("测试结果\n") f.write("上一个论文的算法(多级上下文聚合于注意力机制)\n") f.write(f"acc_global:{round(acc_global, 4)}\n") f.write(f"pv_iou:{round(pv_iou, 4)}\n") f.write(f"pv_f1:{round(pv_f1, 4)}\n") f.write(f"pv_precion:{round(pv_precion, 4)}\n") f.write(f"pv_recall:{round(pv_recall, 4)}\n") print("写入完成")