40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from PIL import Image
|
|
import os
|
|
import numpy as np
|
|
import torchvision.transforms as transforms
|
|
|
|
class CustomDataset(Dataset):
    """Dataset that loads every matching image file in a directory as a grayscale mask.

    Each item is ``(image, label)``, both derived from the same file:
    ``image`` is the grayscale PIL image passed through ``transform`` (if one
    is given) and ``label`` is the *untransformed* pixel array as a
    ``torch.long`` tensor.

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
        transform: Optional callable applied to the PIL image only. It is NOT
            applied to ``label``, so a spatial transform would make the two
            disagree.
        extensions: File suffixes to include, matched case-insensitively.
            Defaults to ``('.png', '.jpg')``. A single string is also accepted.
    """

    def __init__(self, image_dir, transform=None, extensions=('.png', '.jpg')):
        self.image_dir = image_dir
        self.transform = transform
        if isinstance(extensions, str):
            # A bare string would otherwise be iterated character by character.
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir order is filesystem-dependent; sort so that a given index
        # refers to the same file across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.lower().endswith(suffixes)
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load file number ``idx`` and return ``(image, label)``."""
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        # convert() loads the pixel data into a new image, so the file handle
        # can be closed immediately instead of waiting for garbage collection.
        with Image.open(img_name) as img:
            image = img.convert("L")  # assumes the mask is stored as grayscale
        label = np.array(image)

        if self.transform is not None:
            image = self.transform(image)

        label = torch.tensor(label, dtype=torch.long)
        return image, label
|
|
|
|
# Sanity check: load every mask and verify that all label values are in range.
def check_label_range(loader, *, low=0, high=1):
    """Check that every label tensor yielded by *loader* lies within [low, high].

    Args:
        loader: Any iterable of ``(images, labels)`` pairs, e.g. a DataLoader;
            ``labels`` must be a torch tensor.
        low: Smallest allowed label value (inclusive).
        high: Largest allowed label value (inclusive).

    Returns:
        The number of non-empty label batches that were checked.

    Raises:
        ValueError: On the first batch containing a value outside [low, high].
            An explicit exception is used instead of ``assert`` so the check
            still runs under ``python -O``.
    """
    checked = 0
    for _, labels in loader:
        if labels.numel() == 0:
            # min()/max() raise on empty tensors; there is nothing to validate.
            continue
        lo = labels.min().item()
        hi = labels.max().item()
        if lo < low or hi > high:
            raise ValueError(
                f"标签值超出范围: min={lo}, max={hi}, expected [{low}, {high}]"
            )
        checked += 1
    return checked


if __name__ == "__main__":
    # Example dataset and data loader. Guarded so that importing this module
    # (e.g. to reuse CustomDataset) does not scan the directory.
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Other preprocessing steps go here.
    ])

    dataset = CustomDataset(image_dir=r'data/LoveDA/Val/pv/masks_png_convert', transform=transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Report success once, after every batch has been checked.
    if check_label_range(dataloader):
        print("标签值正常")
    else:
        print("No labels found to check")
|