import matplotlib.pyplot as plt import numpy as np import torch def decode_seg_map_sequence(label_masks, dataset='pascal'): rgb_masks = [] for label_mask in label_masks: rgb_mask = decode_segmap(label_mask, dataset) rgb_masks.append(rgb_mask) rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) return rgb_masks def decode_segmap(label_mask, dataset, plot=False): """Decode segmentation class labelss into a color image Args: label_mask (np.ndarray): an (M,N) array of integer values denoting the class label at each spatial_module location. plot (bool, optional): whether to show the resulting color image in a figure. Returns: (np.ndarray, optional): the resulting decoded color image. """ if dataset == 'pascal' or dataset == 'coco': n_classes = 21 label_colours = get_pascal_labels() elif dataset =='pascal_customer': n_classes = 6 label_colours = get_pascal_customer_labels() elif dataset == 'cityscapes': n_classes = 19 label_colours = get_cityscapes_labels() elif dataset == 'Customer_rsipac': n_classes = 2 label_colours = get_customer_labels() else: raise NotImplementedError r = label_mask.copy() g = label_mask.copy() b = label_mask.copy() for ll in range(0, n_classes): r[label_mask == ll] = label_colours[ll, 0] g[label_mask == ll] = label_colours[ll, 1] b[label_mask == ll] = label_colours[ll, 2] rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) rgb[:, :, 0] = r / 255.0 rgb[:, :, 1] = g / 255.0 rgb[:, :, 2] = b / 255.0 if plot: plt.imshow(rgb) plt.show() else: return rgb def encode_segmap(mask): """Encode segmentation label images as pascal classes Args: mask (np.ndarray): raw segmentation label image of dimension (M, N, 3), in which the Pascal classes are encoded as colours. Returns: (np.ndarray): class map with dimensions (M,N), where the value at a given location is the integer denoting the class index. """ mask = mask.astype(int) label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) for ii, label in enumerate(get_pascal_labels()): label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii label_mask = label_mask.astype(int) return label_mask def get_cityscapes_labels(): return np.array([ [0,0,0], [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152], [0, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]) def get_pascal_labels(): """Load the mapping that associates pascal classes with label colors Returns: np.ndarray with dimensions (21, 3) """ return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]) def get_pascal_customer_labels(): """ '''如果有对应颜色就换成对应颜色,如果没有对应颜色且类别数小于21的话就默认''' Returns: np.ndarray with dimensions (21, 3) """ # return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], # [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], # [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], # [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], # [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], # [0, 64, 128]]) return np.asarray([[255 ,0, 0], [255 ,255, 0], [192 ,192, 0], [0 ,255, 0], [128,128,128], [0, 0 ,255]]) def get_customer_labels(): return np.asarray([[0, 0, 0], [255,255,255]])