Tan_pytorch_segmentation/pytorch_segmentation/PV_FuseDisNet/yuchuli.py

40 lines
1.3 KiB
Python

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
import torchvision.transforms as transforms
class CustomDataset(Dataset):
    """Dataset over the ``.png`` / ``.jpg`` files found directly in one directory.

    Each item is read as a single-channel ("L" mode) image and returned twice:
    once passed through ``transform`` (when one is given) and once as the raw
    pixel values in a ``torch.long`` tensor, which serves as the label.

    Args:
        image_dir: Directory scanned (non-recursively) for ``.png``/``.jpg`` files.
        transform: Optional callable applied to the PIL image only. The label is
            built from the untransformed pixels, so a geometric transform
            (resize/crop/flip) would make image and label disagree.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir() yields entries in arbitrary order; sort so a given index
        # always maps to the same file across runs and platforms.
        # NOTE(review): .jpg is lossy and can introduce spurious values into a
        # label mask -- confirm .jpg masks are really intended here.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg'))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file.

        ``image`` is the grayscale PIL image, or ``transform(image)`` when a
        transform was supplied. ``label`` is the untransformed pixel array as a
        ``torch.long`` tensor.
        """
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        # Close the file deterministically; convert() returns a new, fully
        # loaded image that remains valid after the `with` block exits.
        with Image.open(img_name) as img:
            # Force single-channel: the file is treated as a grayscale mask.
            image = img.convert("L")
        label = torch.tensor(np.array(image), dtype=torch.long)
        if self.transform:
            image = self.transform(image)
        return image, label
# Example dataset / data loader: iterate over the PV mask directory and verify
# that every label value lies in the binary range [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor (uint8 scaled to [0, 1])
    # Other preprocessing steps go here.
])
dataset = CustomDataset(image_dir=r'data/LoveDA/Val/pv/masks_png_convert', transform=transform)
# NOTE(review): default collation stacks samples into one tensor, so all masks
# are assumed to share the same spatial size -- confirm for this dataset.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
for images, labels in dataloader:
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let every batch report "normal" unchecked.
    lo, hi = labels.min().item(), labels.max().item()
    if lo < 0 or hi > 1:
        raise ValueError(f"标签值超出范围 (min={lo}, max={hi})")
    print("标签值正常")