71 lines
2.4 KiB
Python
71 lines
2.4 KiB
Python
import os
|
||
import os.path as osp
|
||
import numpy as np
|
||
from PIL import Image
|
||
import matplotlib.pyplot as plt
|
||
|
||
# Dataset location and file-naming conventions.
data_root = 'data/LoveDA/Test/pv'
img_dir = 'images_png'
mask_dir = 'masks_png_convert'
mask_suffix = '.png'
# Extension of the image that pairs with each mask. The original script
# hard-coded '.jpg' even though the folder is named 'images_png' --
# NOTE(review): confirm which extension this dataset's images actually use.
img_suffix = '.jpg'


def image_name_for_mask(mask_filename, mask_suffix=mask_suffix, img_suffix=img_suffix):
    """Return the image filename that pairs with *mask_filename*.

    Only a trailing *mask_suffix* is swapped for *img_suffix*; ``str.replace``
    would also rewrite a matching substring in the middle of the name.
    """
    stem = mask_filename
    if mask_suffix and stem.endswith(mask_suffix):
        stem = stem[:-len(mask_suffix)]
    return stem + img_suffix


def binarize_mask_array(mask_array):
    """Clamp every mask value greater than 1 down to 1.

    Args:
        mask_array: array-like of mask pixel values.

    Returns:
        ``(corrected, changed)``: ``corrected`` keeps the input dtype and
        ``changed`` is True if any value exceeded 1. The input array is never
        modified in place.
    """
    mask_array = np.asarray(mask_array)
    too_large = mask_array > 1
    if not too_large.any():
        return mask_array, False
    corrected = mask_array.copy()
    corrected[too_large] = 1
    return corrected, True


def show_pair(img, mask):
    """Display an image and its mask side by side (debugging aid)."""
    _, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(img)
    axes[0].set_title('Image')
    axes[0].axis('off')
    axes[1].imshow(mask, cmap='gray')
    axes[1].set_title('Mask')
    axes[1].axis('off')
    plt.show()


def fix_masks(data_root=data_root, img_dir=img_dir, mask_dir=mask_dir,
              mask_suffix=mask_suffix, img_suffix=img_suffix, show=False):
    """Rewrite every mask under ``data_root/mask_dir`` so no value exceeds 1.

    Each mask is converted to 8-bit grayscale ('L'); if any pixel is > 1 those
    pixels are set to 1 and the file is overwritten in place. Masks that need
    no correction are left untouched on disk.

    Args:
        data_root: dataset root directory.
        img_dir: image sub-directory of *data_root*.
        mask_dir: mask sub-directory of *data_root*.
        mask_suffix: filename suffix identifying mask files.
        img_suffix: filename suffix of the paired image files.
        show: if True, display each image/mask pair with matplotlib.

    Returns:
        List of mask filenames that were rewritten, in sorted order.

    Raises:
        ValueError: if the numbers of images and masks differ.
        FileNotFoundError: if some mask has no paired image.
        Both checks run before any file is modified.
    """
    img_root = osp.join(data_root, img_dir)
    mask_root = osp.join(data_root, mask_dir)

    # Only consider files with the expected suffixes so stray entries
    # (e.g. .DS_Store, Thumbs.db) neither break Image.open nor skew the count.
    # Sorting makes the processing order deterministic (listdir order is arbitrary).
    mask_filenames = sorted(f for f in os.listdir(mask_root) if f.endswith(mask_suffix))
    img_filenames = {f for f in os.listdir(img_root) if f.endswith(img_suffix)}

    # Validate up front with real exceptions: `assert` disappears under
    # `python -O`, and a missing image must not be discovered only after some
    # masks have already been overwritten.
    if len(img_filenames) != len(mask_filenames):
        raise ValueError("Images and masks count do not match")
    missing = [m for m in mask_filenames
               if image_name_for_mask(m, mask_suffix, img_suffix) not in img_filenames]
    if missing:
        raise FileNotFoundError(f"No image found for masks: {missing[:5]}")

    corrected_files = []
    for mask_filename in mask_filenames:
        mask_path = osp.join(mask_root, mask_filename)
        # Close the file before it is (possibly) overwritten below.
        with Image.open(mask_path) as mask:
            mask_array = np.array(mask.convert('L'))

        corrected, changed = binarize_mask_array(mask_array)
        if changed:
            print(f"Found values > 1 in mask: {mask_filename}")
            # Labels present before correction.
            print(f"Unique labels in mask: {np.unique(mask_array)}")
            print(f"Corrected mask values in {mask_filename}")
            Image.fromarray(corrected).save(mask_path)
            corrected_files.append(mask_filename)

        if show:
            img_path = osp.join(img_root, image_name_for_mask(mask_filename, mask_suffix, img_suffix))
            with Image.open(img_path) as img:
                show_pair(img.convert('RGB'), corrected)

    return corrected_files


if __name__ == '__main__':
    fix_masks()