''' Datasets This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets ''' import os import os.path import sys from PIL import Image import numpy as np from tqdm import tqdm, trange import torchvision.datasets as dset import torchvision.transforms as transforms from torchvision.datasets.utils import download_url, check_integrity import torch.utils.data as data from torch.utils.data import DataLoader IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] def is_image_file(filename): """Checks if a file is an image. Args: filename (string): path to a file Returns: bool: True if the filename ends with a known image extension """ filename_lower = filename.lower() return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) def find_classes(dir): classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] classes.sort() class_to_idx = {classes[i]: i for i in range(len(classes))} return classes, class_to_idx def make_dataset(dir, class_to_idx): images = [] dir = os.path.expanduser(dir) for target in tqdm(sorted(os.listdir(dir))): d = os.path.join(dir, target) if not os.path.isdir(d): continue for root, _, fnames in sorted(os.walk(d)): for fname in sorted(fnames): if is_image_file(fname): path = os.path.join(root, fname) item = (path, class_to_idx[target]) images.append(item) return images def pil_loader(path): # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, 'rb') as f: img = Image.open(f) return img.convert('RGB') def accimage_loader(path): import accimage try: return accimage.Image(path) except IOError: # Potentially a decoding problem, fall back to PIL.Image return pil_loader(path) def default_loader(path): from torchvision import get_image_backend if get_image_backend() == 'accimage': return accimage_loader(path) else: return pil_loader(path) class ImageFolder(data.Dataset): """A generic data loader where the images are arranged in this way: :: root/dogball/xxx.png root/dogball/xxy.png root/dogball/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png Args: root (string): Root directory path. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. Attributes: classes (list): List of the class names. class_to_idx (dict): Dict with items (class_name, class_index). imgs (list): List of (image path, class_index) tuples """ def __init__(self, root, transform=None, target_transform=None, loader=default_loader, load_in_mem=False, index_filename='imagenet_imgs.npz', **kwargs): classes, class_to_idx = find_classes(root) # Load pre-computed image directory walk if os.path.exists(index_filename): print('Loading pre-saved Index file %s...' % index_filename) imgs = np.load(index_filename)['imgs'] # If first time, walk the folder directory and save the # results to a pre-computed file. else: print('Generating Index file %s...' % index_filename) imgs = make_dataset(root, class_to_idx) np.savez_compressed(index_filename, **{'imgs' : imgs}) if len(imgs) == 0: raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs self.classes = classes self.class_to_idx = class_to_idx self.transform = transform self.target_transform = target_transform self.loader = loader self.load_in_mem = load_in_mem if self.load_in_mem: print('Loading all images into memory...') self.data, self.labels = [], [] for index in tqdm(range(len(self.imgs))): path, target = imgs[index][0], imgs[index][1] self.data.append(self.transform(self.loader(path))) self.labels.append(target) def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is class_index of the target class. """ if self.load_in_mem: img = self.data[index] target = self.labels[index] else: path, target = self.imgs[index] img = self.loader(str(path)) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) # print(img.size(), target) return img, int(target) def __len__(self): return len(self.imgs) def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) fmt_str += ' Root Location: {}\n'.format(self.root) tmp = ' Transforms (if any): ' fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) tmp = ' Target Transforms (if any): ' fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) return fmt_str ''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid having to load individual images all the time. ''' import h5py as h5 import torch class ILSVRC_HDF5(data.Dataset): def __init__(self, root, transform=None, target_transform=None, load_in_mem=False, train=True,download=False, validate_seed=0, val_split=0, **kwargs): # last four are dummies self.root = root self.num_imgs = len(h5.File(root, 'r')['labels']) # self.transform = transform self.target_transform = target_transform # Set the transform here self.transform = transform # load the entire dataset into memory? self.load_in_mem = load_in_mem # If loading into memory, do so now if self.load_in_mem: print('Loading %s into memory...' % root) with h5.File(root,'r') as f: self.data = f['imgs'][:] self.labels = f['labels'][:] def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is class_index of the target class. """ # If loaded the entire dataset in RAM, get image from memory if self.load_in_mem: img = self.data[index] target = self.labels[index] # Else load it from disk else: with h5.File(self.root,'r') as f: img = f['imgs'][index] target = f['labels'][index] # if self.transform is not None: # img = self.transform(img) # Apply my own transform img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2 if self.target_transform is not None: target = self.target_transform(target) return img, int(target) def __len__(self): return self.num_imgs # return len(self.f['imgs']) import pickle class CIFAR10(dset.CIFAR10): def __init__(self, root, train=True, transform=None, target_transform=None, download=True, validate_seed=0, val_split=0, load_in_mem=True, **kwargs): self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.train = train # training set or test set self.val_split = val_split if download: self.download() if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') # now load the picked numpy arrays self.data = [] self.labels= [] for fentry in self.train_list: f = fentry[0] file = os.path.join(self.root, self.base_folder, f) fo = open(file, 'rb') if sys.version_info[0] == 2: entry = pickle.load(fo) else: entry = pickle.load(fo, encoding='latin1') self.data.append(entry['data']) if 'labels' in entry: self.labels += entry['labels'] else: self.labels += entry['fine_labels'] fo.close() self.data = np.concatenate(self.data) # Randomly select indices for validation if self.val_split > 0: label_indices = [[] for _ in range(max(self.labels)+1)] for i,l in enumerate(self.labels): label_indices[l] += [i] label_indices = np.asarray(label_indices) # randomly grab 500 elements of each class np.random.seed(validate_seed) self.val_indices = [] for l_i in label_indices: self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1) ,replace=False)]) if self.train=='validate': self.data = self.data[self.val_indices] self.labels = list(np.asarray(self.labels)[self.val_indices]) self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32)) self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC elif self.train: print(np.shape(self.data)) if self.val_split > 0: self.data = np.delete(self.data,self.val_indices,axis=0) self.labels = list(np.delete(np.asarray(self.labels),self.val_indices,axis=0)) self.data = self.data.reshape((int(50e3 * (1.-self.val_split)), 3, 32, 32)) self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC else: f = self.test_list[0][0] file = os.path.join(self.root, self.base_folder, f) fo = open(file, 'rb') if sys.version_info[0] == 2: entry = pickle.load(fo) else: entry = pickle.load(fo, encoding='latin1') self.data = entry['data'] if 'labels' in entry: self.labels = entry['labels'] else: self.labels = entry['fine_labels'] fo.close() self.data = self.data.reshape((10000, 3, 32, 32)) self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, target = self.data[index], self.labels[index] # doing this so that it is consistent with all other datasets # to return a PIL Image img = Image.fromarray(img) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): return len(self.data) class CIFAR100(CIFAR10): base_folder = 'cifar-100-python' url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" filename = "cifar-100-python.tar.gz" tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' train_list = [ ['train', '16019d7e3df5f24257cddd939b257f8d'], ] test_list = [ ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], ]