ai-station-code/wudingpv/taihuyuan_pv/dataloaders/utils.py

128 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import matplotlib.pyplot as plt
import numpy as np
import torch
def decode_seg_map_sequence(label_masks, dataset='pascal'):
    """Colourise a batch of label masks and return them as a channels-first tensor.

    Args:
        label_masks: iterable of label masks, each accepted by ``decode_segmap``.
        dataset (str): palette name forwarded unchanged to ``decode_segmap``.

    Returns:
        torch.Tensor: the decoded images stacked along a new leading axis, with
        the channel axis moved to position 1, i.e. (N, 3, H, W).
    """
    decoded = [decode_segmap(mask, dataset) for mask in label_masks]
    # decode_segmap yields (H, W, 3) images; reorder the stacked batch to NCHW.
    batch = np.array(decoded).transpose((0, 3, 1, 2))
    return torch.from_numpy(batch)
def decode_segmap(label_mask, dataset, plot=False):
    """Decode a map of segmentation class labels into an RGB colour image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer values denoting
            the class label at each spatial location.
        dataset (str): palette to use: 'pascal' or 'coco' (get_pascal_labels),
            'pascal_customer' (get_pascal_customer_labels), 'cityscapes'
            (get_cityscapes_labels) or 'Customer_rsipac' (get_customer_labels).
        plot (bool, optional): if True, show the resulting colour image in a
            matplotlib figure instead of returning it.

    Returns:
        np.ndarray or None: the (M, N, 3) float64 decoded image with colour
        values divided by 255 when ``plot`` is False, otherwise None.
        Pixels whose label has no palette entry keep their raw label value
        (also divided by 255) in all three channels.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the names above.
    """
    palettes = {
        'pascal': get_pascal_labels,
        'coco': get_pascal_labels,
        'pascal_customer': get_pascal_customer_labels,
        'cityscapes': get_cityscapes_labels,
        'Customer_rsipac': get_customer_labels,
    }
    if dataset not in palettes:
        raise NotImplementedError
    label_colours = palettes[dataset]()
    # Derive the class count from the palette itself so the two cannot drift
    # apart (the cityscapes palette has 20 rows; a hard-coded 19 skipped the last).
    n_classes = len(label_colours)

    # Seed every channel with the raw label so pixels without a palette entry
    # (e.g. an ignore label) keep their original value, as before.
    rgb = np.repeat(label_mask[:, :, np.newaxis], 3, axis=2).astype(np.float64)
    for class_idx in range(n_classes):
        rgb[label_mask == class_idx] = label_colours[class_idx]
    rgb /= 255.0

    if plot:
        plt.imshow(rgb)
        plt.show()
        return None
    return rgb
def encode_segmap(mask, label_colours=None):
    """Encode an RGB segmentation label image as a map of class indices.

    Args:
        mask (np.ndarray): raw segmentation label image of shape (M, N, 3) in
            which each class is painted with its palette colour.
        label_colours (array-like, optional): (K, 3) palette where row i is the
            RGB colour of class i. Defaults to the Pascal palette returned by
            ``get_pascal_labels()``.

    Returns:
        np.ndarray: integer class map of shape (M, N), where the value at a
        given location is the index of the matching palette colour. Pixels
        whose colour matches no palette row are assigned class 0. If several
        rows share a colour, the highest index wins.
    """
    if label_colours is None:
        label_colours = get_pascal_labels()
    mask = mask.astype(int)
    label_mask = np.zeros(mask.shape[:2], dtype=int)
    for class_idx, colour in enumerate(label_colours):
        # A pixel matches only when all three channels equal the palette colour.
        label_mask[np.all(mask == colour, axis=-1)] = class_idx
    return label_mask
def get_cityscapes_labels():
    """Return the colour palette used for the 'cityscapes' dataset option.

    Row i is the RGB colour for class index i. The table has 20 rows, and
    row 0 is black ([0, 0, 0]).

    Returns:
        np.ndarray with dimensions (20, 3)
    """
    palette = (
        (0, 0, 0),
        (128, 64, 128),
        (244, 35, 232),
        (70, 70, 70),
        (102, 102, 156),
        (190, 153, 153),
        (153, 153, 153),
        (250, 170, 30),
        (220, 220, 0),
        (107, 142, 35),
        (152, 251, 152),
        (0, 130, 180),
        (220, 20, 60),
        (255, 0, 0),
        (0, 0, 142),
        (0, 0, 70),
        (0, 60, 100),
        (0, 80, 100),
        (0, 0, 230),
        (119, 11, 32),
    )
    return np.array(palette)
def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors.

    The table is generated with the PASCAL VOC colour-map scheme. The bits of
    each class index are dealt out three at a time to the red, green and blue
    channels, filling each channel from its most significant bit downwards.
    The resulting values are identical to the classic hard-coded 21-entry
    table ([0, 0, 0], [128, 0, 0], [0, 128, 0], ..., [0, 64, 128]).

    Returns:
        np.ndarray with dimensions (21, 3)
    """
    num_classes = 21
    palette = []
    for class_idx in range(num_classes):
        red = green = blue = 0
        bits = class_idx
        for shift in range(7, -1, -1):
            red |= (bits & 1) << shift
            green |= ((bits >> 1) & 1) << shift
            blue |= ((bits >> 2) & 1) << shift
            bits >>= 3
        palette.append([red, green, blue])
    return np.array(palette)
def get_pascal_customer_labels():
    """Return the colour palette for the 'pascal_customer' dataset option.

    Row i is the RGB colour used for class index i.

    Original author's note (translated from Chinese): if a class has a
    dedicated colour, use it; if there is no dedicated colour and there are
    fewer than 21 classes, fall back to the default palette (presumably the
    Pascal palette from ``get_pascal_labels()`` — confirm).

    Returns:
        np.ndarray with dimensions (6, 3)
    """
    return np.asarray([
        [255, 0, 0],
        [255, 255, 0],
        [192, 192, 0],
        [0, 255, 0],
        [128, 128, 128],
        [0, 0, 255],
    ])
def get_customer_labels():
    """Return the two-colour palette used for the 'Customer_rsipac' dataset option.

    Row 0 (black) is the colour for class 0 and row 1 (white) for class 1.

    Returns:
        np.ndarray with dimensions (2, 3)
    """
    black = (0, 0, 0)
    white = (255, 255, 255)
    return np.array([black, white])