"""Clamp segmentation-mask labels to ``0..max_label``, rewriting masks in place.

Every mask ``<data_root>/<mask_dir>/*<mask_suffix>`` is loaded as a greyscale
array. Any pixel value greater than ``max_label`` (default 1) is set to
``max_label``, and the mask is saved back over the original file. Each mask
must have a paired image with the same stem and ``img_suffix`` extension in
``<data_root>/<img_dir>``. Pairing is checked before any file is modified.

WARNING: corrected masks overwrite the original mask files.
"""
import os
import os.path as osp

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Dataset location and file suffixes.
data_root = 'data/LoveDA/Test/pv'
img_dir = 'images_png'
mask_dir = 'masks_png_convert'
mask_suffix = '.png'
# NOTE(review): the image directory is called ``images_png``, yet the original
# script looked images up with a ``.jpg`` extension; kept as-is — confirm.
img_suffix = '.jpg'


def _image_name_for(mask_filename, mask_suffix, img_suffix):
    """Return the image filename paired with *mask_filename*.

    Only the trailing *mask_suffix* is swapped for *img_suffix*. A plain
    ``str.replace`` would also rewrite the suffix text anywhere else in the
    name. The caller guarantees *mask_filename* ends with *mask_suffix*.
    """
    stem = mask_filename[:len(mask_filename) - len(mask_suffix)]
    return stem + img_suffix


def _show_pair(img, mask):
    """Display *img* and *mask* side by side; blocks until the window closes."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(img)
    axes[0].set_title('Image')
    axes[0].axis('off')
    axes[1].imshow(mask, cmap='gray')
    axes[1].set_title('Mask')
    axes[1].axis('off')
    plt.show()


def clip_masks(data_root=data_root, img_dir=img_dir, mask_dir=mask_dir,
               mask_suffix=mask_suffix, img_suffix=img_suffix,
               max_label=1, show=False):
    """Clamp mask pixel values above *max_label* to *max_label*, in place.

    Args:
        data_root: Dataset root containing *img_dir* and *mask_dir*.
        img_dir: Sub-directory holding the images.
        mask_dir: Sub-directory holding the masks (rewritten in place).
        mask_suffix: Extension identifying mask files.
        img_suffix: Extension of the image paired with each mask.
        max_label: Largest label to keep. Must fit in uint8, because masks
            are read as 8-bit greyscale.
        show: If true, display every image/mask pair after processing.

    Returns:
        Sorted list of mask filenames that were corrected and re-saved.

    Raises:
        FileNotFoundError: Some mask has no paired image. Raised before any
            mask is modified.
        ValueError: The number of images and masks differs.
    """
    img_root = osp.join(data_root, img_dir)
    mask_root = osp.join(data_root, mask_dir)

    # Only regular files with the expected suffix count. os.listdir also
    # returns stray files and sub-directories. Sort for a deterministic order.
    mask_filenames = sorted(
        name for name in os.listdir(mask_root)
        if name.endswith(mask_suffix) and osp.isfile(osp.join(mask_root, name))
    )
    img_filenames = {
        name for name in os.listdir(img_root)
        if name.endswith(img_suffix) and osp.isfile(osp.join(img_root, name))
    }

    # Validate pairing up front, so a missing image cannot abort the run
    # after some masks have already been overwritten.
    missing = [
        name for name in mask_filenames
        if _image_name_for(name, mask_suffix, img_suffix) not in img_filenames
    ]
    if missing:
        raise FileNotFoundError(f"No matching image for masks: {missing}")
    if len(img_filenames) != len(mask_filenames):
        raise ValueError(
            f"Images and masks count do not match "
            f"({len(img_filenames)} images, {len(mask_filenames)} masks)"
        )

    corrected = []
    for mask_filename in mask_filenames:
        mask_path = osp.join(mask_root, mask_filename)
        # Close the file before it may be overwritten below.
        # NOTE(review): for palette ('P') masks, convert('L') maps palette
        # colours to grey levels instead of keeping raw label indices —
        # confirm the masks are stored as greyscale.
        with Image.open(mask_path) as mask:
            mask_array = np.array(mask.convert('L'))

        if np.any(mask_array > max_label):
            print(f"Found values > {max_label} in mask: {mask_filename}")
            # All labels present before correction, for inspection.
            print(f"Unique labels in mask: {np.unique(mask_array)}")
            mask_array[mask_array > max_label] = max_label
            Image.fromarray(mask_array).save(mask_path)
            print(f"Corrected mask values in {mask_filename}")
            corrected.append(mask_filename)

        if show:
            img_name = _image_name_for(mask_filename, mask_suffix, img_suffix)
            with Image.open(osp.join(img_root, img_name)) as img:
                _show_pair(img.convert('RGB'), mask_array)

    return corrected


if __name__ == '__main__':
    clip_masks()