#!/usr/bin/env python
# -*- coding: utf-8 -*-

''' Utilities file
This file contains utility functions for bookkeeping, logging, and data loading.
Methods which directly affect training should either go in layers, the model,
or train_fns.py.
'''
from __future__ import print_function
import sys
import os
import numpy as np
import time
import datetime
import json
import pickle
from argparse import ArgumentParser
import animal_hash
import datasets as dset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def prepare_parser():
  usage = 'Parser for all scripts.'
  parser = ArgumentParser(description=usage)

  ### Dataset/Dataloader stuff ###
  parser.add_argument(
    '--dataset', type=str, default='I128_hdf5',
    help='Which Dataset to train on, out of I128, I256, C10, C100; '
         'Append "_hdf5" to use the hdf5 version for ILSVRC '
         '(default: %(default)s)')
  parser.add_argument(
    '--augment', action='store_true', default=False,
    help='Augment with random crops and flips (default: %(default)s)')
  parser.add_argument(
    '--num_workers', type=int, default=8,
    help='Number of dataloader workers; consider using less for HDF5 '
         '(default: %(default)s)')
  parser.add_argument(
    '--no_pin_memory', action='store_false', dest='pin_memory', default=True,
    help='Pin data into memory through dataloader? (default: %(default)s)')
  parser.add_argument(
    '--shuffle', action='store_true', default=False,
    help='Shuffle the data (strongly recommended)? (default: %(default)s)')
  parser.add_argument(
    '--load_in_mem', action='store_true', default=False,
    help='Load all data into memory? (default: %(default)s)')
  parser.add_argument(
    '--use_multiepoch_sampler', action='store_true', default=False,
    help='Use the multi-epoch sampler for dataloader? (default: %(default)s)')

  ### Model stuff ###
  parser.add_argument(
    '--model', type=str, default='BigGAN',
    help='Name of the model module (default: %(default)s)')
  parser.add_argument(
    '--G_param', type=str, default='SN',
    help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)'
         ' or None (default: %(default)s)')
  parser.add_argument(
    '--D_param', type=str, default='SN',
    help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)'
         ' or None (default: %(default)s)')
  parser.add_argument(
    '--G_ch', type=int, default=64,
    help='Channel multiplier for G (default: %(default)s)')
  parser.add_argument(
    '--D_ch', type=int, default=64,
    help='Channel multiplier for D (default: %(default)s)')
  parser.add_argument(
    '--G_depth', type=int, default=1,
    help='Number of resblocks per stage in G? (default: %(default)s)')
  parser.add_argument(
    '--D_depth', type=int, default=1,
    help='Number of resblocks per stage in D? (default: %(default)s)')
  parser.add_argument(
    '--D_thin', action='store_false', dest='D_wide', default=True,
    help='Use the SN-GAN channel pattern for D? (default: %(default)s)')
  parser.add_argument(
    '--G_shared', action='store_true', default=False,
    help='Use shared embeddings in G? (default: %(default)s)')
  parser.add_argument(
    '--shared_dim', type=int, default=0,
    help="G's shared embedding dimensionality; if 0, will be equal to dim_z. "
         '(default: %(default)s)')
  parser.add_argument(
    '--dim_z', type=int, default=128,
    help='Noise dimensionality (default: %(default)s)')
  parser.add_argument(
    '--z_var', type=float, default=1.0,
    help='Noise variance (default: %(default)s)')
  parser.add_argument(
    '--hier', action='store_true', default=False,
    help='Use hierarchical z in G? (default: %(default)s)')
  parser.add_argument(
    '--cross_replica', action='store_true', default=False,
    help='Cross_replica batchnorm in G? (default: %(default)s)')
  parser.add_argument(
    '--mybn', action='store_true', default=False,
    help='Use my batchnorm (which supports standing stats)? '
         '(default: %(default)s)')
  parser.add_argument(
    '--G_nl', type=str, default='relu',
    help='Activation function for G (default: %(default)s)')
  parser.add_argument(
    '--D_nl', type=str, default='relu',
    help='Activation function for D (default: %(default)s)')
  parser.add_argument(
    '--G_attn', type=str, default='64',
    help='What resolutions to use attention on for G (underscore separated) '
         '(default: %(default)s)')
  parser.add_argument(
    '--D_attn', type=str, default='64',
    help='What resolutions to use attention on for D (underscore separated) '
         '(default: %(default)s)')
  parser.add_argument(
    '--norm_style', type=str, default='bn',
    help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], '
         'ln [layernorm], gn [groupnorm] (default: %(default)s)')

  ### Model init stuff ###
  parser.add_argument(
    '--seed', type=int, default=0,
    help='Random seed to use; affects both initialization and '
         'dataloading. (default: %(default)s)')
  parser.add_argument(
    '--G_init', type=str, default='ortho',
    help='Init style to use for G (default: %(default)s)')
  parser.add_argument(
    '--D_init', type=str, default='ortho',
    help='Init style to use for D (default: %(default)s)')
  parser.add_argument(
    '--skip_init', action='store_true', default=False,
    help='Skip initialization, ideal for testing when ortho init was used '
         '(default: %(default)s)')

  ### Optimizer stuff ###
  parser.add_argument(
    '--G_lr', type=float, default=5e-5,
    help='Learning rate to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_lr', type=float, default=2e-4,
    help='Learning rate to use for Discriminator (default: %(default)s)')
  parser.add_argument(
    '--G_B1', type=float, default=0.0,
    help='Beta1 to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_B1', type=float, default=0.0,
    help='Beta1 to use for Discriminator (default: %(default)s)')
  parser.add_argument(
    '--G_B2', type=float, default=0.999,
    help='Beta2 to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_B2', type=float, default=0.999,
    help='Beta2 to use for Discriminator (default: %(default)s)')

  ### Batch size, parallel, and precision stuff ###
  parser.add_argument(
    '--batch_size', type=int, default=64,
    help='Default overall batchsize (default: %(default)s)')
  parser.add_argument(
    '--G_batch_size', type=int, default=0,
    help='Batch size to use for G; if 0, same as D (default: %(default)s)')
  parser.add_argument(
    '--num_G_accumulations', type=int, default=1,
    help="Number of passes to accumulate G's gradients over "
         '(default: %(default)s)')
  parser.add_argument(
    '--num_D_steps', type=int, default=2,
    help='Number of D steps per G step (default: %(default)s)')
  parser.add_argument(
    '--num_D_accumulations', type=int, default=1,
    help="Number of passes to accumulate D's gradients over "
         '(default: %(default)s)')
  parser.add_argument(
    '--split_D', action='store_true', default=False,
    help='Run D twice rather than concatenating inputs? (default: %(default)s)')
  parser.add_argument(
    '--num_epochs', type=int, default=100,
    help='Number of epochs to train for (default: %(default)s)')
  parser.add_argument(
    '--parallel', action='store_true', default=False,
    help='Train with multiple GPUs (default: %(default)s)')
  parser.add_argument(
    '--G_fp16', action='store_true', default=False,
    help='Train with half-precision in G? (default: %(default)s)')
  parser.add_argument(
    '--D_fp16', action='store_true', default=False,
    help='Train with half-precision in D? (default: %(default)s)')
  parser.add_argument(
    '--D_mixed_precision', action='store_true', default=False,
    help='Train with half-precision activations but fp32 params in D? '
         '(default: %(default)s)')
  parser.add_argument(
    '--G_mixed_precision', action='store_true', default=False,
    help='Train with half-precision activations but fp32 params in G? '
         '(default: %(default)s)')
  parser.add_argument(
    '--accumulate_stats', action='store_true', default=False,
    help='Accumulate "standing" batchnorm stats? (default: %(default)s)')
  parser.add_argument(
    '--num_standing_accumulations', type=int, default=16,
    help='Number of forward passes to use in accumulating standing stats? '
         '(default: %(default)s)')

  ### Bookkeeping stuff ###
  parser.add_argument(
    '--G_eval_mode', action='store_true', default=False,
    help='Run G in eval mode (running/standing stats?) at sample/test time? '
         '(default: %(default)s)')
  parser.add_argument(
    '--save_every', type=int, default=2000,
    help='Save every X iterations (default: %(default)s)')
  parser.add_argument(
    '--num_save_copies', type=int, default=2,
    help='How many copies to save (default: %(default)s)')
  parser.add_argument(
    '--num_best_copies', type=int, default=2,
    help='How many previous best checkpoints to save (default: %(default)s)')
  parser.add_argument(
    '--which_best', type=str, default='IS',
    help='Which metric to use to determine when to save new "best" '
         'checkpoints, one of IS or FID (default: %(default)s)')
  parser.add_argument(
    '--no_fid', action='store_true', default=False,
    help='Calculate IS only, not FID? (default: %(default)s)')
  parser.add_argument(
    '--test_every', type=int, default=5000,
    help='Test every X iterations (default: %(default)s)')
  parser.add_argument(
    '--num_inception_images', type=int, default=50000,
    help='Number of samples to compute inception metrics with '
         '(default: %(default)s)')
  parser.add_argument(
    '--hashname', action='store_true', default=False,
    help='Use a hash of the experiment name instead of the full config '
         '(default: %(default)s)')
  parser.add_argument(
    '--base_root', type=str, default='',
    help='Default location to store all weights, samples, data, and logs '
         '(default: %(default)s)')
  parser.add_argument(
    '--data_root', type=str, default='data',
    help='Default location where data is stored (default: %(default)s)')
  parser.add_argument(
    '--weights_root', type=str, default='weights',
    help='Default location to store weights (default: %(default)s)')
  parser.add_argument(
    '--logs_root', type=str, default='logs',
    help='Default location to store logs (default: %(default)s)')
  parser.add_argument(
    '--samples_root', type=str, default='samples',
    help='Default location to store samples (default: %(default)s)')
  parser.add_argument(
    '--pbar', type=str, default='mine',
    help='Type of progressbar to use; one of "mine" or "tqdm" '
         '(default: %(default)s)')
  parser.add_argument(
    '--name_suffix', type=str, default='',
    help='Suffix for experiment name for loading weights for sampling '
         '(consider "best0") (default: %(default)s)')
  parser.add_argument(
    '--experiment_name', type=str, default='',
    help='Optionally override the automatic experiment naming with this arg. '
         '(default: %(default)s)')
  parser.add_argument(
    '--config_from_name', action='store_true', default=False,
    help='Use a hash of the experiment name instead of the full config '
         '(default: %(default)s)')

  ### EMA Stuff ###
  parser.add_argument(
    '--ema', action='store_true', default=False,
    help="Keep an ema of G's weights? (default: %(default)s)")
  parser.add_argument(
    '--ema_decay', type=float, default=0.9999,
    help='EMA decay rate (default: %(default)s)')
  parser.add_argument(
    '--use_ema', action='store_true', default=False,
    help='Use the EMA parameters of G for evaluation? (default: %(default)s)')
  parser.add_argument(
    '--ema_start', type=int, default=0,
    help='When to start updating the EMA weights (default: %(default)s)')

  ### Numerical precision and SV stuff ###
  parser.add_argument(
    '--adam_eps', type=float, default=1e-8,
    help='epsilon value to use for Adam (default: %(default)s)')
  parser.add_argument(
    '--BN_eps', type=float, default=1e-5,
    help='epsilon value to use for BatchNorm (default: %(default)s)')
  parser.add_argument(
    '--SN_eps', type=float, default=1e-8,
    help='epsilon value to use for Spectral Norm (default: %(default)s)')
  parser.add_argument(
    '--num_G_SVs', type=int, default=1,
    help='Number of SVs to track in G (default: %(default)s)')
  parser.add_argument(
    '--num_D_SVs', type=int, default=1,
    help='Number of SVs to track in D (default: %(default)s)')
  parser.add_argument(
    '--num_G_SV_itrs', type=int, default=1,
    help='Number of SV itrs in G (default: %(default)s)')
  parser.add_argument(
    '--num_D_SV_itrs', type=int, default=1,
    help='Number of SV itrs in D (default: %(default)s)')

  ### Ortho reg stuff ###
  parser.add_argument(
    '--G_ortho', type=float, default=0.0,  # 1e-4 is default for BigGAN
    help='Modified ortho reg coefficient in G (default: %(default)s)')
  parser.add_argument(
    '--D_ortho', type=float, default=0.0,
    help='Modified ortho reg coefficient in D (default: %(default)s)')
  parser.add_argument(
    '--toggle_grads', action='store_true', default=True,
    help='Toggle D and G\'s "requires_grad" settings when not training them? '
         '(default: %(default)s)')

  ### Which train function ###
  parser.add_argument(
    '--which_train_fn', type=str, default='GAN',
    help='How2trainyourbois (default: %(default)s)')

  ### Resume training stuff ###
  parser.add_argument(
    '--load_weights', type=str, default='',
    help='Suffix for which weights to load (e.g. best0, copy0) '
         '(default: %(default)s)')
  parser.add_argument(
    '--resume', action='store_true', default=False,
    help='Resume training? (default: %(default)s)')

  ### Log stuff ###
  parser.add_argument(
    '--logstyle', type=str, default='%3.3e',
    help='What style to use when logging training metrics? '
         'One of: %%#.#f/ %%#.#e (float/exp, text), '
         'pickle (python pickle), '
         'npz (numpy zip), '
         'mat (MATLAB .mat file) (default: %(default)s)')
  parser.add_argument(
    '--log_G_spectra', action='store_true', default=False,
    help='Log the top 3 singular values in each SN layer in G? '
         '(default: %(default)s)')
  parser.add_argument(
    '--log_D_spectra', action='store_true', default=False,
    help='Log the top 3 singular values in each SN layer in D? '
         '(default: %(default)s)')
  parser.add_argument(
    '--sv_log_interval', type=int, default=10,
    help='Iteration interval for logging singular values '
         '(default: %(default)s)')

  parser.add_argument('--text', type=str)

  return parser

# Arguments for sample.py; not presently used in train.py
def add_sample_parser(parser):
  parser.add_argument(
    '--sample_npz', action='store_true', default=False,
    help='Sample "sample_num_npz" images and save to npz? '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_num_npz', type=int, default=50000,
    help='Number of images to sample when sampling NPZs '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_sheets', action='store_true', default=False,
    help='Produce class-conditional sample sheets and stick them in '
         'the samples root? (default: %(default)s)')
  parser.add_argument(
    '--sample_interps', action='store_true', default=False,
    help='Produce interpolation sheets and stick them in '
         'the samples root? (default: %(default)s)')
  parser.add_argument(
    '--sample_sheet_folder_num', type=int, default=-1,
    help='Number to use for the folder for these sample sheets '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_random', action='store_true', default=False,
    help='Produce a single random sheet? (default: %(default)s)')
  parser.add_argument(
    '--sample_trunc_curves', type=str, default='',
    help='Get inception metrics with a range of variances? '
         'To use this, specify a startpoint, step, and endpoint, e.g. '
         '--sample_trunc_curves 0.2_0.1_1.0 for a startpoint of 0.2, '
         'endpoint of 1.0, and stepsize of 0.1. Note that this is '
         'not exactly identical to using tf.truncated_normal, but should '
         'have approximately the same effect. (default: %(default)s)')
  parser.add_argument(
    '--sample_inception_metrics', action='store_true', default=False,
    help='Calculate Inception metrics with sample.py? (default: %(default)s)')
  return parser
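
# Illustrative sketch only (hypothetical helper, not referenced elsewhere in
# this repo): how a script would typically combine the parsers above into the
# flat `config` dict that the rest of these utilities expect.
def _example_build_config():
  parser = prepare_parser()
  parser = add_sample_parser(parser)  # only needed for sampling scripts
  config = vars(parser.parse_args())
  config = update_config_roots(config)
  return config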

# Convenience dicts
dset_dict = {'I32': dset.ImageFolder, 'I64': dset.ImageFolder,
             'I128': dset.ImageFolder, 'I256': dset.ImageFolder,
             'I32_hdf5': dset.ILSVRC_HDF5, 'I64_hdf5': dset.ILSVRC_HDF5,
             'I128_hdf5': dset.ILSVRC_HDF5, 'I256_hdf5': dset.ILSVRC_HDF5,
             'C10': dset.CIFAR10, 'C100': dset.CIFAR100}
imsize_dict = {'I32': 32, 'I32_hdf5': 32,
               'I64': 64, 'I64_hdf5': 64,
               'I128': 128, 'I128_hdf5': 128,
               'I256': 256, 'I256_hdf5': 256,
               'C10': 32, 'C100': 32}
root_dict = {'I32': 'ImageNet', 'I32_hdf5': 'ILSVRC32.hdf5',
             'I64': 'ImageNet', 'I64_hdf5': 'ILSVRC64.hdf5',
             'I128': 'ImageNet', 'I128_hdf5': 'ILSVRC128.hdf5',
             'I256': 'ImageNet', 'I256_hdf5': 'ILSVRC256.hdf5',
             'C10': 'cifar', 'C100': 'cifar'}
nclass_dict = {'I32': 1000, 'I32_hdf5': 1000,
               'I64': 1000, 'I64_hdf5': 1000,
               'I128': 1000, 'I128_hdf5': 1000,
               'I256': 1000, 'I256_hdf5': 1000,
               'C10': 10, 'C100': 100}
# Number of classes to put per sample sheet
classes_per_sheet_dict = {'I32': 50, 'I32_hdf5': 50,
                          'I64': 50, 'I64_hdf5': 50,
                          'I128': 20, 'I128_hdf5': 20,
                          'I256': 20, 'I256_hdf5': 20,
                          'C10': 10, 'C100': 100}
activation_dict = {'inplace_relu': nn.ReLU(inplace=True),
                   'relu': nn.ReLU(inplace=False),
                   'ir': nn.ReLU(inplace=True)}

class CenterCropLongEdge(object):
  """Crops the given PIL Image on the long edge.
  Args:
    size (sequence or int): Desired output size of the crop. If size is an
      int instead of sequence like (h, w), a square crop (size, size) is
      made.
  """
  def __call__(self, img):
    """
    Args:
      img (PIL Image): Image to be cropped.
    Returns:
      PIL Image: Cropped image.
    """
    return transforms.functional.center_crop(img, min(img.size))

  def __repr__(self):
    return self.__class__.__name__


class RandomCropLongEdge(object):
  """Crops the given PIL Image on the long edge with a random start point.
  Args:
    size (sequence or int): Desired output size of the crop. If size is an
      int instead of sequence like (h, w), a square crop (size, size) is
      made.
  """
  def __call__(self, img):
    """
    Args:
      img (PIL Image): Image to be cropped.
    Returns:
      PIL Image: Cropped image.
    """
    size = (min(img.size), min(img.size))
    # Only step forward along this edge if it's the long edge
    i = (0 if size[0] == img.size[0]
         else np.random.randint(low=0, high=img.size[0] - size[0]))
    j = (0 if size[1] == img.size[1]
         else np.random.randint(low=0, high=img.size[1] - size[1]))
    return transforms.functional.crop(img, i, j, size[0], size[1])

  def __repr__(self):
    return self.__class__.__name__

# multi-epoch Dataset sampler to avoid memory leakage and enable resumption of
# training from the same sample regardless of if we stop mid-epoch
class MultiEpochSampler(torch.utils.data.Sampler):
  r"""Samples elements randomly over multiple epochs

  Arguments:
      data_source (Dataset): dataset to sample from
      num_epochs (int) : Number of times to loop over the dataset
      start_itr (int) : which iteration to begin from
  """

  def __init__(self, data_source, num_epochs, start_itr=0, batch_size=128):
    self.data_source = data_source
    self.num_samples = len(self.data_source)
    self.num_epochs = num_epochs
    self.start_itr = start_itr
    self.batch_size = batch_size

    if not isinstance(self.num_samples, int) or self.num_samples <= 0:
      raise ValueError("num_samples should be a positive integer "
                       "value, but got num_samples={}".format(self.num_samples))

  def __iter__(self):
    n = len(self.data_source)
    # Determine number of epochs
    num_epochs = int(np.ceil((n * self.num_epochs
                              - (self.start_itr * self.batch_size)) / float(n)))
    # Sample all the indices, and then grab the last num_epochs index sets;
    # This ensures if we're starting at epoch 4, we're still grabbing epoch 4's
    # indices
    out = [torch.randperm(n) for epoch in range(self.num_epochs)][-num_epochs:]
    # Ignore the first start_itr % n indices of the first epoch
    out[0] = out[0][(self.start_itr * self.batch_size % n):]
    # if self.replacement:
    #   return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    # return iter(.tolist())
    output = torch.cat(out).tolist()
    print('Length dataset output is %d' % len(output))
    return iter(output)

  def __len__(self):
    return len(self.data_source) * self.num_epochs - self.start_itr * self.batch_size

# Convenience function to centralize all data loaders
def get_data_loaders(dataset, data_root=None, augment=False, batch_size=64,
                     num_workers=8, shuffle=True, load_in_mem=False, hdf5=False,
                     pin_memory=True, drop_last=True, start_itr=0,
                     num_epochs=500, use_multiepoch_sampler=False,
                     **kwargs):

  # Append /FILENAME.hdf5 to root if using hdf5
  data_root += '/%s' % root_dict[dataset]
  print('Using dataset root location %s' % data_root)

  which_dataset = dset_dict[dataset]
  norm_mean = [0.5, 0.5, 0.5]
  norm_std = [0.5, 0.5, 0.5]
  image_size = imsize_dict[dataset]
  # For image folder datasets, name of the file where we store the precomputed
  # image locations to avoid having to walk the dirs every time we load.
  dataset_kwargs = {'index_filename': '%s_imgs.npz' % dataset}

  # HDF5 datasets have their own inbuilt transform, no need to train_transform
  if 'hdf5' in dataset:
    train_transform = None
  else:
    if augment:
      print('Data will be augmented...')
      if dataset in ['C10', 'C100']:
        train_transform = [transforms.RandomCrop(32, padding=4),
                           transforms.RandomHorizontalFlip()]
      else:
        train_transform = [RandomCropLongEdge(),
                           transforms.Resize(image_size),
                           transforms.RandomHorizontalFlip()]
    else:
      print('Data will not be augmented...')
      if dataset in ['C10', 'C100']:
        train_transform = []
      else:
        train_transform = [CenterCropLongEdge(), transforms.Resize(image_size)]
      # train_transform = [transforms.Resize(image_size), transforms.CenterCrop]
    train_transform = transforms.Compose(train_transform + [
                        transforms.ToTensor(),
                        transforms.Normalize(norm_mean, norm_std)])
  train_set = which_dataset(root=data_root, transform=train_transform,
                            load_in_mem=load_in_mem, **dataset_kwargs)

  # Prepare loader; the loaders list is for forward compatibility with
  # using validation / test splits.
  loaders = []
  if use_multiepoch_sampler:
    print('Using multiepoch sampler from start_itr %d...' % start_itr)
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}
    sampler = MultiEpochSampler(train_set, num_epochs, start_itr, batch_size)
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              sampler=sampler, **loader_kwargs)
  else:
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory,
                     'drop_last': drop_last}  # Default, drop last incomplete batch
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=shuffle, **loader_kwargs)
  loaders.append(train_loader)
  return loaders
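
# Illustrative sketch only (hypothetical helper, not referenced elsewhere):
# roughly how a training script would call get_data_loaders from a parsed
# config dict, overriding the batch size and the resume iteration.
def _example_get_loaders(config, D_batch_size, start_itr=0):
  loader_config = dict(config, batch_size=D_batch_size, start_itr=start_itr)
  return get_data_loaders(**loader_config)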

# Utility file to seed rngs
def seed_rng(seed):
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  np.random.seed(seed)


# Utility to peg all roots to a base root
# If a base root folder is provided, peg all other root folders to it.
def update_config_roots(config):
  if config['base_root']:
    print('Pegging all root folders to base root %s' % config['base_root'])
    for key in ['data', 'weights', 'logs', 'samples']:
      config['%s_root' % key] = '%s/%s' % (config['base_root'], key)
  return config


# Utility to prepare root folders if they don't exist; parent folder must exist
def prepare_root(config):
  for key in ['weights_root', 'logs_root', 'samples_root']:
    if not os.path.exists(config[key]):
      print('Making directory %s for %s...' % (config[key], key))
      os.mkdir(config[key])

# Simple wrapper that applies EMA to a model. Could be better done in 1.0 using
# the parameters() and buffers() module functions, but for now this works
# with state_dicts using .copy_
class ema(object):
  def __init__(self, source, target, decay=0.9999, start_itr=0):
    self.source = source
    self.target = target
    self.decay = decay
    # Optional parameter indicating what iteration to start the decay at
    self.start_itr = start_itr
    # Initialize target's params to be source's
    self.source_dict = self.source.state_dict()
    self.target_dict = self.target.state_dict()
    print('Initializing EMA parameters to be source parameters...')
    with torch.no_grad():
      for key in self.source_dict:
        self.target_dict[key].data.copy_(self.source_dict[key].data)
        # target_dict[key].data = source_dict[key].data # Doesn't work!

  def update(self, itr=None):
    # If an iteration counter is provided and itr is less than the start itr,
    # peg the ema weights to the underlying weights.
    if itr and itr < self.start_itr:
      decay = 0.0
    else:
      decay = self.decay
    with torch.no_grad():
      for key in self.source_dict:
        self.target_dict[key].data.copy_(self.target_dict[key].data * decay
                                         + self.source_dict[key].data * (1 - decay))
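
# Illustrative sketch only (hypothetical helper, not referenced elsewhere):
# the ema wrapper is built once around G and its EMA copy G_ema, and update()
# is then called once per G parameter update with the current iteration.
def _example_setup_ema(G, G_ema, config):
  if config['ema']:
    return ema(G, G_ema, config['ema_decay'], config['ema_start'])
  return None
# In the training loop, after each G step:
#   if ema_updater is not None:
#     ema_updater.update(state_dict['itr'])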

# Apply modified ortho reg to a model
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def ortho(model, strength=1e-4, blacklist=[]):
  with torch.no_grad():
    for param in model.parameters():
      # Only apply this to parameters with at least 2 axes, and not in the blacklist
      if len(param.shape) < 2 or any([param is item for item in blacklist]):
        continue
      w = param.view(param.shape[0], -1)
      grad = (2 * torch.mm(torch.mm(w, w.t())
              * (1. - torch.eye(w.shape[0], device=w.device)), w))
      param.grad.data += strength * grad.view(param.shape)


# Default ortho reg
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def default_ortho(model, strength=1e-4, blacklist=[]):
  with torch.no_grad():
    for param in model.parameters():
      # Only apply this to parameters with at least 2 axes & not in blacklist
      if len(param.shape) < 2 or param in blacklist:
        continue
      w = param.view(param.shape[0], -1)
      grad = (2 * torch.mm(torch.mm(w, w.t())
              - torch.eye(w.shape[0], device=w.device), w))
      param.grad.data += strength * grad.view(param.shape)
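
# Illustrative sketch only (hypothetical helper, not referenced elsewhere):
# modified ortho reg is typically applied between loss.backward() and the
# optimizer step, skipping G's shared embedding parameters via the blacklist.
def _example_apply_G_ortho(G, config):
  if config['G_ortho'] > 0.0:
    ortho(G, config['G_ortho'],
          blacklist=[param for param in G.shared.parameters()])
  # G.optim.step() would follow in the actual training step.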

# Convenience utility to switch off requires_grad
def toggle_grad(model, on_or_off):
  for param in model.parameters():
    param.requires_grad = on_or_off


# Function to join strings or ignore them
# Base string is the string to link "strings," while strings
# is a list of strings or Nones.
def join_strings(base_string, strings):
  return base_string.join([item for item in strings if item])

# Save a model's weights, optimizer, and the state_dict
def save_weights(G, D, state_dict, weights_root, experiment_name,
                 name_suffix=None, G_ema=None):
  root = '/'.join([weights_root, experiment_name])
  if not os.path.exists(root):
    os.mkdir(root)
  if name_suffix:
    print('Saving weights to %s/%s...' % (root, name_suffix))
  else:
    print('Saving weights to %s...' % root)
  torch.save(G.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['G', name_suffix])))
  torch.save(G.optim.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix])))
  torch.save(D.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['D', name_suffix])))
  torch.save(D.optim.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix])))
  torch.save(state_dict,
             '%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))
  if G_ema is not None:
    torch.save(G_ema.state_dict(),
               '%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix])))


# Load a model's weights, optimizer, and the state_dict
def load_weights(G, D, state_dict, weights_root, experiment_name,
                 name_suffix=None, G_ema=None, strict=True, load_optim=True):
  root = '/'.join([weights_root, experiment_name])
  if name_suffix:
    print('Loading %s weights from %s...' % (name_suffix, root))
  else:
    print('Loading weights from %s...' % root)
  if G is not None:
    G.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['G', name_suffix]))),
      strict=strict)
    if load_optim:
      G.optim.load_state_dict(
        torch.load('%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix]))))
  if D is not None:
    D.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['D', name_suffix]))),
      strict=strict)
    if load_optim:
      D.optim.load_state_dict(
        torch.load('%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix]))))
  # Load state dict
  for item in state_dict:
    state_dict[item] = torch.load('%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))[item]
  if G_ema is not None:
    G_ema.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix]))),
      strict=strict)
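
# Illustrative sketch only (hypothetical helper, not referenced elsewhere):
# resuming a run simply reloads everything that save_weights wrote, including
# the in-place update of the training state_dict.
def _example_resume(G, D, G_ema, state_dict, config, experiment_name):
  if config['resume']:
    load_weights(G, D, state_dict, config['weights_root'], experiment_name,
                 name_suffix=config['load_weights'] or None, G_ema=G_ema)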

''' MetricsLogger originally stolen from VoxNet source code.
    Used for logging inception metrics'''
class MetricsLogger(object):
  def __init__(self, fname, reinitialize=False):
    self.fname = fname
    self.reinitialize = reinitialize
    if os.path.exists(self.fname):
      if self.reinitialize:
        print('{} exists, deleting...'.format(self.fname))
        os.remove(self.fname)

  def log(self, record=None, **kwargs):
    """
    Assumption: no newlines in the input.
    """
    if record is None:
      record = {}
    record.update(kwargs)
    record['_stamp'] = time.time()
    with open(self.fname, 'a') as f:
      f.write(json.dumps(record, ensure_ascii=True) + '\n')

# Logstyle is either:
# '%#.#f' for floating point representation in text
# '%#.#e' for exponent representation in text
# 'npz' for output to npz # NOT YET SUPPORTED
# 'pickle' for output to a python pickle # NOT YET SUPPORTED
# 'mat' for output to a MATLAB .mat file # NOT YET SUPPORTED
class MyLogger(object):
  def __init__(self, fname, reinitialize=False, logstyle='%3.3f'):
    self.root = fname
    if not os.path.exists(self.root):
      os.mkdir(self.root)
    self.reinitialize = reinitialize
    self.metrics = []
    self.logstyle = logstyle  # One of '%3.3f' or like '%3.3e'

  # Delete log if re-starting and log already exists
  def reinit(self, item):
    if os.path.exists('%s/%s.log' % (self.root, item)):
      if self.reinitialize:
        # Only print the removal message once for the singular value logs
        if 'sv' in item:
          if not any('sv' in item for item in self.metrics):
            print('Deleting singular value logs...')
        else:
          print('{} exists, deleting...'.format('%s/%s.log' % (self.root, item)))
        os.remove('%s/%s.log' % (self.root, item))

  # Log in plaintext; this is designed for being read in MATLAB (sorry not sorry)
  def log(self, itr, **kwargs):
    for arg in kwargs:
      if arg not in self.metrics:
        if self.reinitialize:
          self.reinit(arg)
        self.metrics += [arg]
      if self.logstyle == 'pickle':
        print('Pickle not currently supported...')
        # with open('%s/%s.log' % (self.root, arg), 'a') as f:
        #   pickle.dump(kwargs[arg], f)
      elif self.logstyle == 'mat':
        print('.mat logstyle not currently supported...')
      else:
        with open('%s/%s.log' % (self.root, arg), 'a') as f:
          f.write('%d: %s\n' % (itr, self.logstyle % kwargs[arg]))


# Write some metadata to the logs directory
def write_metadata(logs_root, experiment_name, config, state_dict):
  with open(('%s/%s/metalog.txt' %
             (logs_root, experiment_name)), 'w') as writefile:
    writefile.write('datetime: %s\n' % str(datetime.datetime.now()))
    writefile.write('config: %s\n' % str(config))
    writefile.write('state: %s\n' % str(state_dict))
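
# Illustrative sketch only (hypothetical helper, not referenced elsewhere):
# the two loggers above are typically instantiated once per experiment, with
# MyLogger receiving per-iteration training metrics and MetricsLogger
# receiving the periodic IS/FID test results.
def _example_prepare_loggers(config, experiment_name):
  train_log = MyLogger('%s/%s' % (config['logs_root'], experiment_name),
                       reinitialize=not config['resume'],
                       logstyle=config['logstyle'])
  test_log = MetricsLogger('%s/%s/log.jsonl' % (config['logs_root'],
                                                experiment_name),
                           reinitialize=not config['resume'])
  return train_log, test_log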
"""
|
|
Very basic progress indicator to wrap an iterable in.
|
|
|
|
Author: Jan Schlüter
|
|
Andy's adds: time elapsed in addition to ETA, makes it possible to add
|
|
estimated time to 1k iters instead of estimated time to completion.
|
|
"""
|
|
def progress(items, desc='', total=None, min_delay=0.1, displaytype='s1k'):
|
|
"""
|
|
Returns a generator over `items`, printing the number and percentage of
|
|
items processed and the estimated remaining processing time before yielding
|
|
the next item. `total` gives the total number of items (required if `items`
|
|
has no length), and `min_delay` gives the minimum time in seconds between
|
|
subsequent prints. `desc` gives an optional prefix text (end with a space).
|
|
"""
|
|
total = total or len(items)
|
|
t_start = time.time()
|
|
t_last = 0
|
|
for n, item in enumerate(items):
|
|
t_now = time.time()
|
|
if t_now - t_last > min_delay:
|
|
print("\r%s%d/%d (%6.2f%%)" % (
|
|
desc, n+1, total, n / float(total) * 100), end=" ")
|
|
if n > 0:
|
|
|
|
if displaytype == 's1k': # minutes/seconds for 1000 iters
|
|
next_1000 = n + (1000 - n%1000)
|
|
t_done = t_now - t_start
|
|
t_1k = t_done / n * next_1000
|
|
outlist = list(divmod(t_done, 60)) + list(divmod(t_1k - t_done, 60))
|
|
print("(TE/ET1k: %d:%02d / %d:%02d)" % tuple(outlist), end=" ")
|
|
else:# displaytype == 'eta':
|
|
t_done = t_now - t_start
|
|
t_total = t_done / n * total
|
|
outlist = list(divmod(t_done, 60)) + list(divmod(t_total - t_done, 60))
|
|
print("(TE/ETA: %d:%02d / %d:%02d)" % tuple(outlist), end=" ")
|
|
|
|
sys.stdout.flush()
|
|
t_last = t_now
|
|
yield item
|
|
t_total = time.time() - t_start
|
|
print("\r%s%d/%d (100.00%%) (took %d:%02d)" % ((desc, total, total) +
|
|
divmod(t_total, 60)))
|
|
|
|
|
|
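
# Illustrative sketch only (hypothetical helper, not referenced elsewhere):
# wrapping a dataloader in progress() is how the 'mine' --pbar option is
# typically realized (tqdm being the alternative).
def _example_wrap_loader(loader, epoch):
  return progress(loader, desc='Epoch %d: ' % epoch,
                  total=len(loader), displaytype='s1k')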

# Sample function for use with inception metrics
def sample(G, z_, y_, config):
  with torch.no_grad():
    z_.sample_()
    y_.sample_()
    if config['parallel']:
      G_z = nn.parallel.data_parallel(G, (z_, G.shared(y_)))
    else:
      G_z = G(z_, G.shared(y_))
    return G_z, y_
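
# Illustrative sketch only (hypothetical helper, not referenced elsewhere):
# the inception-metrics code expects a zero-argument sample function, so the
# generator and the noise/label Distributions are usually bound in advance.
def _example_make_sample_fn(G, z_, y_, config):
  import functools
  return functools.partial(sample, G=G, z_=z_, y_=y_, config=config)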

# Sample function for sample sheets
def sample_sheet(G, classes_per_sheet, num_classes, samples_per_class, parallel,
                 samples_root, experiment_name, folder_number, z_=None):
  # Prepare sample directory
  if not os.path.isdir('%s/%s' % (samples_root, experiment_name)):
    os.mkdir('%s/%s' % (samples_root, experiment_name))
  if not os.path.isdir('%s/%s/%d' % (samples_root, experiment_name, folder_number)):
    os.mkdir('%s/%s/%d' % (samples_root, experiment_name, folder_number))
  # loop over total number of sheets
  for i in range(num_classes // classes_per_sheet):
    ims = []
    y = torch.arange(i * classes_per_sheet, (i + 1) * classes_per_sheet, device='cuda')
    for j in range(samples_per_class):
      if (z_ is not None) and hasattr(z_, 'sample_') and classes_per_sheet <= z_.size(0):
        z_.sample_()
      else:
        z_ = torch.randn(classes_per_sheet, G.dim_z, device='cuda')
      with torch.no_grad():
        if parallel:
          o = nn.parallel.data_parallel(G, (z_[:classes_per_sheet], G.shared(y)))
        else:
          o = G(z_[:classes_per_sheet], G.shared(y))

      ims += [o.data.cpu()]
    # This line should properly unroll the images
    out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2],
                                       ims[0].shape[3]).data.float().cpu()
    # out_ims = torch.from_numpy(out_ims.numpy()) ### NOTE: xcliu for torchvision
    # The path for the samples
    image_filename = '%s/%s/%d/samples%d.jpg' % (samples_root, experiment_name,
                                                 folder_number, i)
    torchvision.utils.save_image(out_ims, image_filename,
                                 nrow=samples_per_class, normalize=True)

# Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..)
def interp(x0, x1, num_midpoints):
  lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype)
  return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1)))


# interp sheet function
# Supports full, class-wise and intra-class interpolation
def interp_sheet(G, num_per_sheet, num_midpoints, num_classes, parallel,
                 samples_root, experiment_name, folder_number, sheet_number=0,
                 fix_z=False, fix_y=False, device='cuda'):
  # Prepare zs and ys
  if fix_z:  # If fix Z, only sample 1 z per row
    zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device)
    zs = zs.repeat(1, num_midpoints + 2, 1).view(-1, G.dim_z)
  else:
    zs = interp(torch.randn(num_per_sheet, 1, G.dim_z, device=device),
                torch.randn(num_per_sheet, 1, G.dim_z, device=device),
                num_midpoints).view(-1, G.dim_z)
  if fix_y:  # If fix y, only sample 1 y per row
    ys = sample_1hot(num_per_sheet, num_classes)
    ys = G.shared(ys).view(num_per_sheet, 1, -1)
    ys = ys.repeat(1, num_midpoints + 2, 1).view(num_per_sheet * (num_midpoints + 2), -1)
  else:
    ys = interp(G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1),
                G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1),
                num_midpoints).view(num_per_sheet * (num_midpoints + 2), -1)
  # Run the net--note that we've already passed y through G.shared.
  if G.fp16:
    zs = zs.half()
  with torch.no_grad():
    if parallel:
      out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu()
    else:
      out_ims = G(zs, ys).data.cpu()
  interp_style = '' + ('Z' if not fix_z else '') + ('Y' if not fix_y else '')
  image_filename = '%s/%s/%d/interp%s%d.jpg' % (samples_root, experiment_name,
                                                folder_number, interp_style,
                                                sheet_number)
  torchvision.utils.save_image(out_ims, image_filename,
                               nrow=num_midpoints + 2, normalize=True)

# Convenience debugging function to print out gradnorms and shape from each layer
# May need to rewrite this so we can actually see which parameter is which
def print_grad_norms(net):
  gradsums = [[float(torch.norm(param.grad).item()),
               float(torch.norm(param).item()), param.shape]
              for param in net.parameters()]
  order = np.argsort([item[0] for item in gradsums])
  print(['%3.3e,%3.3e, %s' % (gradsums[item_index][0],
                              gradsums[item_index][1],
                              str(gradsums[item_index][2]))
         for item_index in order])


# Get singular values to log. This will use the state dict to find them
# and substitute underscores for dots.
def get_SVs(net, prefix):
  d = net.state_dict()
  return {('%s_%s' % (prefix, key)).replace('.', '_'):
          float(d[key].item())
          for key in d if 'sv' in key}

# Name an experiment based on its config
def name_from_config(config):
  name = '_'.join([
    item for item in [
      'Big%s' % config['which_train_fn'],
      config['dataset'],
      config['model'] if config['model'] != 'BigGAN' else None,
      'seed%d' % config['seed'],
      'Gch%d' % config['G_ch'],
      'Dch%d' % config['D_ch'],
      'Gd%d' % config['G_depth'] if config['G_depth'] > 1 else None,
      'Dd%d' % config['D_depth'] if config['D_depth'] > 1 else None,
      'bs%d' % config['batch_size'],
      'Gfp16' if config['G_fp16'] else None,
      'Dfp16' if config['D_fp16'] else None,
      'nDs%d' % config['num_D_steps'] if config['num_D_steps'] > 1 else None,
      'nDa%d' % config['num_D_accumulations'] if config['num_D_accumulations'] > 1 else None,
      'nGa%d' % config['num_G_accumulations'] if config['num_G_accumulations'] > 1 else None,
      'Glr%2.1e' % config['G_lr'],
      'Dlr%2.1e' % config['D_lr'],
      'GB%3.3f' % config['G_B1'] if config['G_B1'] != 0.0 else None,
      'GBB%3.3f' % config['G_B2'] if config['G_B2'] != 0.999 else None,
      'DB%3.3f' % config['D_B1'] if config['D_B1'] != 0.0 else None,
      'DBB%3.3f' % config['D_B2'] if config['D_B2'] != 0.999 else None,
      'Gnl%s' % config['G_nl'],
      'Dnl%s' % config['D_nl'],
      'Ginit%s' % config['G_init'],
      'Dinit%s' % config['D_init'],
      'G%s' % config['G_param'] if config['G_param'] != 'SN' else None,
      'D%s' % config['D_param'] if config['D_param'] != 'SN' else None,
      'Gattn%s' % config['G_attn'] if config['G_attn'] != '0' else None,
      'Dattn%s' % config['D_attn'] if config['D_attn'] != '0' else None,
      'Gortho%2.1e' % config['G_ortho'] if config['G_ortho'] > 0.0 else None,
      'Dortho%2.1e' % config['D_ortho'] if config['D_ortho'] > 0.0 else None,
      config['norm_style'] if config['norm_style'] != 'bn' else None,
      'cr' if config['cross_replica'] else None,
      'Gshared' if config['G_shared'] else None,
      'hier' if config['hier'] else None,
      'ema' if config['ema'] else None,
      config['name_suffix'] if config['name_suffix'] else None,
    ]
    if item is not None])
  # dogball
  if config['hashname']:
    return hashname(name)
  else:
    return name


# A simple function to produce a unique experiment name from the animal hashes.
def hashname(name):
  h = hash(name)
  a = h % len(animal_hash.a)
  h = h // len(animal_hash.a)
  b = h % len(animal_hash.b)
  h = h // len(animal_hash.b)
  c = h % len(animal_hash.c)
  return animal_hash.a[a] + animal_hash.b[b] + animal_hash.c[c]

# Get GPU memory, -i is the index
def query_gpu(indices):
  os.system('nvidia-smi -i %s --query-gpu=memory.free --format=csv' % indices)


# Convenience function to count the number of parameters in a module
def count_parameters(module):
  print('Number of parameters: {}'.format(
    sum([p.data.nelement() for p in module.parameters()])))


# Convenience function to sample an index, not actually a 1-hot
def sample_1hot(batch_size, num_classes, device='cuda'):
  return torch.randint(low=0, high=num_classes, size=(batch_size,),
                       device=device, dtype=torch.int64, requires_grad=False)

# A highly simplified convenience class for sampling from distributions
# One could also use PyTorch's inbuilt distributions package.
# Note that this class requires initialization to proceed as
#   x = Distribution(torch.randn(size))
#   x.init_distribution(dist_type, **dist_kwargs)
#   x = x.to(device, dtype)
# This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2
class Distribution(torch.Tensor):
  # Init the params of the distribution
  def init_distribution(self, dist_type, **kwargs):
    self.dist_type = dist_type
    self.dist_kwargs = kwargs
    if self.dist_type == 'normal':
      self.mean, self.var = kwargs['mean'], kwargs['var']
    elif self.dist_type == 'categorical':
      self.num_categories = kwargs['num_categories']

  def sample_(self):
    if self.dist_type == 'normal':
      self.normal_(self.mean, self.var)
    elif self.dist_type == 'categorical':
      self.random_(0, self.num_categories)
    # return self.variable

  # Silly hack: overwrite the to() method to wrap the new object
  # in a distribution as well
  def to(self, *args, **kwargs):
    new_obj = Distribution(self)
    new_obj.init_distribution(self.dist_type, **self.dist_kwargs)
    new_obj.data = super().to(*args, **kwargs)
    return new_obj


# Convenience function to prepare a z and y vector
def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda',
                fp16=False, z_var=1.0):
  z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False))
  z_.init_distribution('normal', mean=0, var=z_var)
  z_ = z_.to(device, torch.float16 if fp16 else torch.float32)

  if fp16:
    z_ = z_.half()

  y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False))
  y_.init_distribution('categorical', num_categories=nclasses)
  y_ = y_.to(device, torch.int64)
  return z_, y_

def initiate_standing_stats(net):
  for module in net.modules():
    if hasattr(module, 'accumulate_standing'):
      module.reset_stats()
      module.accumulate_standing = True


def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16):
  initiate_standing_stats(net)
  net.train()
  for i in range(num_accumulations):
    with torch.no_grad():
      z.normal_()
      y.random_(0, nclasses)
      x = net(z, net.shared(y))  # No need to parallelize here unless using syncbn
  # Set to eval mode
  net.eval()

# This version of Adam keeps an fp32 copy of the parameters and
# does all of the parameter updates in fp32, while still doing the
# forwards and backwards passes using fp16 (i.e. fp16 copies of the
# parameters and fp16 activations).
#
# Note that this calls .float().cuda() on the params.
import math
from torch.optim.optimizer import Optimizer
class Adam16(Optimizer):
  def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
    defaults = dict(lr=lr, betas=betas, eps=eps,
                    weight_decay=weight_decay)
    params = list(params)
    super(Adam16, self).__init__(params, defaults)

  # Safety modification to make sure we floatify our state
  def load_state_dict(self, state_dict):
    super(Adam16, self).load_state_dict(state_dict)
    for group in self.param_groups:
      for p in group['params']:
        self.state[p]['exp_avg'] = self.state[p]['exp_avg'].float()
        self.state[p]['exp_avg_sq'] = self.state[p]['exp_avg_sq'].float()
        self.state[p]['fp32_p'] = self.state[p]['fp32_p'].float()

  def step(self, closure=None):
    """Performs a single optimization step.
    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue

        grad = p.grad.data.float()
        state = self.state[p]

        # State initialization
        if len(state) == 0:
          state['step'] = 0
          # Exponential moving average of gradient values
          state['exp_avg'] = grad.new().resize_as_(grad).zero_()
          # Exponential moving average of squared gradient values
          state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
          # Fp32 copy of the weights
          state['fp32_p'] = p.data.float()

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        beta1, beta2 = group['betas']

        state['step'] += 1

        if group['weight_decay'] != 0:
          grad = grad.add(group['weight_decay'], state['fp32_p'])

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(1 - beta1, grad)
        exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

        denom = exp_avg_sq.sqrt().add_(group['eps'])

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

        state['fp32_p'].addcdiv_(-step_size, exp_avg, denom)
        p.data = state['fp32_p'].half()

    return loss
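
# Illustrative sketch only (hypothetical helper, not referenced elsewhere):
# Adam16 is meant to stand in for torch.optim.Adam when a model is trained
# with --G_fp16/--D_fp16, keeping optimizer math and master weights in fp32.
def _example_make_optimizer(model, config, fp16=False):
  if fp16:
    return Adam16(params=model.parameters(), lr=config['G_lr'],
                  betas=(config['G_B1'], config['G_B2']),
                  weight_decay=0, eps=config['adam_eps'])
  import torch.optim as optim
  return optim.Adam(params=model.parameters(), lr=config['G_lr'],
                    betas=(config['G_B1'], config['G_B2']),
                    weight_decay=0, eps=config['adam_eps'])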