ai-station-code/wudingpv/taihuyuan_roof/dataloaders/utils.py

128 lines
4.5 KiB
Python
Raw Normal View History

2025-05-06 11:18:48 +08:00
import matplotlib.pyplot as plt
import numpy as np
import torch
def decode_seg_map_sequence(label_masks, dataset='pascal'):
    """Decode a sequence of label masks into a single tensor of colour images.

    Args:
        label_masks: iterable of label masks; each one is handed to
            ``decode_segmap`` unchanged.
        dataset (str): palette name forwarded to ``decode_segmap``.
    Returns:
        torch.Tensor: the decoded images stacked along a new leading axis,
            with the colour axis moved in front of the two spatial axes
            (axis order ``[0, 3, 1, 2]`` of the stacked array).
    """
    decoded = [decode_segmap(mask, dataset) for mask in label_masks]
    stacked = np.array(decoded)
    # Stacked layout is (index, rows, cols, colour); move colour to axis 1.
    channels_first = stacked.transpose([0, 3, 1, 2])
    return torch.from_numpy(channels_first)
def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image.

    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
            the class label at each spatial location.
        dataset (str): name of the palette to use; one of 'pascal', 'coco',
            'pascal_customer', 'cityscapes' or 'Customer_rsipac'.
        plot (bool, optional): whether to show the resulting color image
            in a figure instead of returning it.
    Returns:
        (np.ndarray, optional): the decoded (M,N,3) float image with values
            scaled by 1/255. Nothing is returned when ``plot`` is true.
    Raises:
        NotImplementedError: if ``dataset`` is not one of the known names.
    """
    if dataset == 'pascal' or dataset == 'coco':
        label_colours = get_pascal_labels()
    elif dataset == 'pascal_customer':
        label_colours = get_pascal_customer_labels()
    elif dataset == 'cityscapes':
        label_colours = get_cityscapes_labels()
    elif dataset == 'Customer_rsipac':
        label_colours = get_customer_labels()
    else:
        raise NotImplementedError(
            "Unsupported dataset: {!r}".format(dataset))

    label_mask = np.asarray(label_mask)

    # Paint into a float buffer rather than copies of the mask, so that a
    # narrow mask dtype (int8, bool, ...) cannot wrap or clip colour values
    # such as 255.
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    # Seed every channel with the raw label: pixels whose label has no
    # palette entry keep that value in all three channels.
    rgb[:, :, :] = label_mask[:, :, np.newaxis]

    # Iterate over the palette itself so the number of decoded classes can
    # never drift from the number of palette rows.
    for class_index, colour in enumerate(label_colours):
        rgb[label_mask == class_index] = colour

    rgb /= 255.0

    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb
def encode_segmap(mask):
    """Convert a colour-coded label image into a map of Pascal class indices.

    Args:
        mask (np.ndarray): raw segmentation label image of dimension
            (M, N, 3), in which the Pascal classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M,N), where the value at
            a given location is the integer denoting the class index.
            Pixels matching no palette colour are left at 0.
    """
    mask = mask.astype(int)
    n_rows, n_cols = mask.shape[0], mask.shape[1]
    class_map = np.zeros((n_rows, n_cols), dtype=np.int16)
    for class_index, colour in enumerate(get_pascal_labels()):
        # True wherever all colour components equal this palette entry.
        matches = np.all(mask == colour, axis=-1)
        hits = np.nonzero(matches)[:2]
        class_map[hits] = class_index
    return class_map.astype(int)
def get_cityscapes_labels():
    """Return the colour palette used for the 'cityscapes' dataset.

    Returns:
        np.ndarray with dimensions (20, 3); row ``i`` is the RGB colour
        assigned to class index ``i``.
    """
    palette = [
        [0, 0, 0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32],
    ]
    return np.array(palette)
def get_pascal_labels():
    """Return the table mapping Pascal class indices to label colours.

    Returns:
        np.ndarray with dimensions (21, 3); row ``i`` is the RGB colour
        of class index ``i``.
    """
    palette = (
        (0, 0, 0),
        (128, 0, 0),
        (0, 128, 0),
        (128, 128, 0),
        (0, 0, 128),
        (128, 0, 128),
        (0, 128, 128),
        (128, 128, 128),
        (64, 0, 0),
        (192, 0, 0),
        (64, 128, 0),
        (192, 128, 0),
        (64, 0, 128),
        (192, 0, 128),
        (64, 128, 128),
        (192, 128, 128),
        (0, 64, 0),
        (128, 64, 0),
        (0, 192, 0),
        (128, 192, 0),
        (0, 64, 128),
    )
    return np.asarray(palette)
def get_pascal_customer_labels():
    """Return the colour palette for the 'pascal_customer' dataset.

    Original author's note (translated): use the matching colour for a
    class when one is defined; if there is none and the number of classes
    is below 21, fall back to the default palette.

    Returns:
        np.ndarray with dimensions (6, 3); row ``i`` is the RGB colour of
        class index ``i``.
    """
    palette = (
        (255, 0, 0),
        (255, 255, 0),
        (192, 192, 0),
        (0, 255, 0),
        (128, 128, 128),
        (0, 0, 255),
    )
    return np.asarray(palette)
def get_customer_labels():
    """Return the two-colour palette: black for index 0, white for index 1.

    Returns:
        np.ndarray with dimensions (2, 3)
    """
    black = [0, 0, 0]
    white = [255, 255, 255]
    return np.asarray([black, white])