# File: Tan_pytorch_segmentation/pytorch_segmentation/PV_FuseDisNet/xiugai.py
# (Repository web-view header: 71 lines, 2.4 KiB, Python;
#  last modified 2025-05-19 20:48:24 +08:00.)
import os
import os.path as osp
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# Dataset location and file naming.
data_root = 'data/LoveDA/Test/pv'
img_dir = 'images_png'
mask_dir = 'masks_png_convert'
mask_suffix = '.png'
# Extensions tried, in order, when looking up the image that belongs to a
# mask. The original script only tried '.jpg' even though the folder is named
# 'images_png', so '.png' is accepted as a fallback.
# NOTE(review): confirm which extension the images actually use.
img_suffixes = ('.jpg', '.png')


def binarize_mask_array(mask_array):
    """Clamp every label greater than 1 down to 1, in place.

    Args:
        mask_array: NumPy array of mask labels; modified in place.

    Returns:
        The number of elements that were changed (0 if no value exceeded 1).
    """
    over = mask_array > 1
    changed = int(np.count_nonzero(over))
    if changed:
        mask_array[over] = 1
    return changed


def match_image_filename(mask_filename, img_filenames,
                         mask_suffix=mask_suffix, img_suffixes=img_suffixes):
    """Return the image filename paired with *mask_filename*, or None.

    The pairing key is the mask filename with *mask_suffix* stripped from the
    end only (``str.replace`` would also rewrite earlier occurrences). Each
    extension in *img_suffixes* is tried in order.

    Args:
        mask_filename: Bare filename of the mask (no directory part).
        img_filenames: Container of available image filenames; pass a set
            for O(1) membership tests.
        mask_suffix: Suffix stripped from the end of the mask filename.
        img_suffixes: Candidate image extensions, tried in order.

    Returns:
        The first matching image filename, or None if none is present.
    """
    stem = mask_filename
    # Guard the empty suffix: stem[:-0] would yield '' instead of the name.
    if mask_suffix and stem.endswith(mask_suffix):
        stem = stem[:-len(mask_suffix)]
    for suffix in img_suffixes:
        candidate = stem + suffix
        if candidate in img_filenames:
            return candidate
    return None


def load_mask_array(mask_path):
    """Read the mask file at *mask_path* into a NumPy label array.

    Palette ('P') images are read as their palette indices. Every other mode
    is converted to 8-bit grayscale ('L') first, as the original script did.
    The file handle is closed before returning.
    """
    with Image.open(mask_path) as mask:
        if mask.mode == 'P':
            # In palette images the stored indices are the labels; convert('L')
            # would instead map each index through its palette colour's
            # luminance and scramble the labels.
            return np.array(mask)
        return np.array(mask.convert('L'))


def show_pair(img_path, mask_array):
    """Display the image at *img_path* beside *mask_array* using matplotlib."""
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(rgb)
    axes[0].set_title('Image')
    axes[0].axis('off')
    axes[1].imshow(mask_array, cmap='gray')
    axes[1].set_title('Mask')
    axes[1].axis('off')
    plt.show()
    plt.close(fig)


def main(data_root=data_root, img_dir=img_dir, mask_dir=mask_dir,
         mask_suffix=mask_suffix, img_suffixes=img_suffixes, show=False):
    """Rewrite every mask under *data_root*/*mask_dir* so no label exceeds 1.

    Every mask must have a matching image in *data_root*/*img_dir*. All pairs
    are validated before any file is modified, so a missing image no longer
    aborts the run after some masks have already been overwritten. Corrected
    masks are saved back to their original path.

    Args:
        data_root: Dataset root directory.
        img_dir: Image sub-directory, relative to *data_root*.
        mask_dir: Mask sub-directory, relative to *data_root*.
        mask_suffix: Filename suffix identifying mask files.
        img_suffixes: Candidate image extensions, tried in order.
        show: If True, display each image beside its (corrected) mask.

    Raises:
        ValueError: If the numbers of image and mask files differ.
        FileNotFoundError: If some mask has no matching image.
    """
    img_root = osp.join(data_root, img_dir)
    mask_root = osp.join(data_root, mask_dir)
    img_suffixes = tuple(img_suffixes)  # str.endswith requires a tuple

    # Only consider files with the expected extensions, so stray files are
    # neither counted nor opened; sort masks for a deterministic order.
    img_filenames = {f for f in os.listdir(img_root) if f.endswith(img_suffixes)}
    mask_filenames = sorted(
        f for f in os.listdir(mask_root) if f.endswith(mask_suffix))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(img_filenames) != len(mask_filenames):
        raise ValueError(
            f"Images and masks count do not match "
            f"({len(img_filenames)} images, {len(mask_filenames)} masks)")

    # Pair every mask with its image before touching any file.
    pairs = []
    missing = []
    for mask_filename in mask_filenames:
        img_filename = match_image_filename(
            mask_filename, img_filenames, mask_suffix, img_suffixes)
        if img_filename is None:
            missing.append(mask_filename)
        else:
            pairs.append((mask_filename, img_filename))
    if missing:
        raise FileNotFoundError(f"No matching image for masks: {missing}")

    for mask_filename, img_filename in pairs:
        mask_path = osp.join(mask_root, mask_filename)
        mask_array = load_mask_array(mask_path)
        if np.any(mask_array > 1):
            print(f"Found values > 1 in mask: {mask_filename}")
            print(f"Unique labels in mask: {np.unique(mask_array)}")
            binarize_mask_array(mask_array)
            print(f"Corrected mask values in {mask_filename}")
            # Overwrite the mask file in place with the corrected labels.
            Image.fromarray(mask_array).save(mask_path)
        if show:
            show_pair(osp.join(img_root, img_filename), mask_array)


if __name__ == '__main__':
    main()