"""Sanity-check that the mask images in a directory only contain labels 0 and 1."""

import os

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# File-name suffixes treated as images (matched case-sensitively).
_IMAGE_EXTENSIONS = ('.png', '.jpg')


class CustomDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs read from one image directory.

    Both items of a pair come from the *same* file: ``label`` is the raw
    grayscale pixel data as an integer tensor, ``image`` is the grayscale image
    after the optional ``transform``.

    Args:
        image_dir: Directory containing the ``.png`` / ``.jpg`` files.
        transform: Optional callable applied to the grayscale PIL image. When
            ``None`` the PIL image itself is returned as ``image``.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and platforms.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith(_IMAGE_EXTENSIONS)
        )

    def __len__(self) -> int:
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the file at position ``idx``.

        Returns:
            A tuple ``(image, label)`` where ``label`` is a ``torch.long``
            tensor built from the untransformed grayscale pixels and ``image``
            is the (optionally transformed) grayscale image.
        """
        img_name = os.path.join(self.image_dir, self.image_files[idx])

        # Open inside a context manager so the underlying file handle is closed
        # right away; convert() returns a fully loaded, independent image.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert("L")  # Assuming the mask is a grayscale image

        # Take the label from the raw pixels *before* any transform is applied,
        # so value-rescaling transforms cannot alter the class ids.
        label = np.array(image)

        # Compare against None explicitly: a callable transform must not be
        # skipped just because it happens to evaluate as falsy.
        if self.transform is not None:
            image = self.transform(image)

        label = torch.tensor(label, dtype=torch.long)
        return image, label


# Example dataset and data loader.
transform = transforms.Compose([
    transforms.ToTensor(),
    # Other preprocessing steps go here.
])

dataset = CustomDataset(image_dir=r'data/LoveDA/Val/pv/masks_png_convert', transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in dataloader:
    # Raise explicitly rather than via `assert`: assert statements are removed
    # under `python -O`, which would silently disable this validation.
    if labels.min().item() < 0 or labels.max().item() > 1:
        raise AssertionError("标签值超出范围")
    print("标签值正常")