"""Utilities for converting TFHub BigGAN generator weights to PyTorch.
|
|
Recommended usage:
|
|
To convert all BigGAN variants and generate test samples, use:
|
|
```bash
|
|
CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples
|
|
```
|
|
See `parse_args` for additional options.
|
|
"""
|
|
|
|
import argparse
import os
import sys

import h5py
import torch
import torch.nn as nn
from torchvision.utils import save_image
import tensorflow as tf
import tensorflow_hub as hub
import parse

# Import the reference BigGAN (v1) generator implementation from this folder.
import biggan_v1 as biggan_for_conversion

# Import the model from the main folder.
sys.path.append('..')
import BigGAN

DEVICE = 'cuda'
HDF5_TMPL = 'biggan-{}.h5'
PTH_TMPL = 'biggan-{}.pth'
MODULE_PATH_TMPL = 'https://tfhub.dev/deepmind/biggan-{}/2'
Z_DIMS = {
    128: 120,
    256: 140,
    512: 128}
RESOLUTIONS = list(Z_DIMS)

def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False):
    """Loads TFHub weights and saves them to an intermediate HDF5 file.

    Args:
        module_path ([Path-like]): Path to TFHub module.
        hdf5_path ([Path-like]): Path to output HDF5 file.
        redownload (bool): If True, rebuild the HDF5 file even if it already exists.

    Returns:
        [h5py.File]: Loaded HDF5 file containing module weights.
    """
    if os.path.exists(hdf5_path) and (not redownload):
        print('Loading BigGAN hdf5 file from:', hdf5_path)
        return h5py.File(hdf5_path, 'r')

    print('Loading BigGAN module from:', module_path)
    tf.reset_default_graph()
    hub.Module(module_path)
    print('Loaded BigGAN module from:', module_path)

    initializer = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(initializer)

    print('Saving BigGAN weights to:', hdf5_path)
    h5f = h5py.File(hdf5_path, 'w')
    for var in tf.global_variables():
        val = sess.run(var)
        h5f.create_dataset(var.name, data=val)
        print(f'Saving {var.name} with shape {val.shape}')
    h5f.close()
    return h5py.File(hdf5_path, 'r')

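# Note: dump_tfhub_to_hdf5 stores each dataset under the full TF variable name
# (including the trailing ':0' output index), e.g. a key of the illustrative form
# 'module/Generator/GBlock/conv0/w/ema_b999900:0'. TFHub2Pytorch below rebuilds
# these keys from its TF_ROOT and leaf-name constants when reading tensors.
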
class TFHub2Pytorch(object):

    TF_ROOT = 'module'

    NUM_GBLOCK = {
        128: 5,
        256: 6,
        512: 7
    }

    w = 'w'
    b = 'b'
    u = 'u0'
    v = 'u1'
    gamma = 'gamma'
    beta = 'beta'
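
    # Leaf names of the TF variables: 'u0'/'u1' are loaded into the spectral-norm
    # buffers weight_u/weight_v below, while 'w', 'b', 'gamma' and 'beta' are the
    # layer weights. With load_ema=True, __init__ appends '/ema_b999900' so that
    # the EMA copies of those weights are read instead of the raw ones.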

    def __init__(self, state_dict, tf_weights, resolution=256, load_ema=True, verbose=False):
        self.state_dict = state_dict
        self.tf_weights = tf_weights
        self.resolution = resolution
        self.verbose = verbose
        if load_ema:
            for name in ['w', 'b', 'gamma', 'beta']:
                setattr(self, name, getattr(self, name) + '/ema_b999900')

    def load(self):
        self.load_generator()
        return self.state_dict

    def load_generator(self):
        GENERATOR_ROOT = os.path.join(self.TF_ROOT, 'Generator')

        for i in range(self.NUM_GBLOCK[self.resolution]):
            name_tf = os.path.join(GENERATOR_ROOT, 'GBlock')
            name_tf += f'_{i}' if i != 0 else ''
            self.load_GBlock(f'GBlock.{i}.', name_tf)

        self.load_attention('attention.', os.path.join(GENERATOR_ROOT, 'attention'))
        self.load_linear('linear', os.path.join(self.TF_ROOT, 'linear'), bias=False)
        self.load_snlinear('G_linear', os.path.join(GENERATOR_ROOT, 'G_Z', 'G_linear'))
        self.load_colorize('colorize', os.path.join(GENERATOR_ROOT, 'conv_2d'))
        self.load_ScaledCrossReplicaBNs('ScaledCrossReplicaBN',
                                        os.path.join(GENERATOR_ROOT, 'ScaledCrossReplicaBN'))

    def load_linear(self, name_pth, name_tf, bias=True):
        self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0)
        if bias:
            self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b)

    def load_snlinear(self, name_pth, name_tf, bias=True):
        self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze()
        self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze()
        self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0)
        if bias:
            self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b)

    def load_colorize(self, name_pth, name_tf):
        self.load_snconv(name_pth, name_tf)

    def load_GBlock(self, name_pth, name_tf):
        self.load_convs(name_pth, name_tf)
        self.load_HyperBNs(name_pth, name_tf)

    def load_convs(self, name_pth, name_tf):
        self.load_snconv(name_pth + 'conv0', os.path.join(name_tf, 'conv0'))
        self.load_snconv(name_pth + 'conv1', os.path.join(name_tf, 'conv1'))
        self.load_snconv(name_pth + 'conv_sc', os.path.join(name_tf, 'conv_sc'))

    def load_snconv(self, name_pth, name_tf, bias=True):
        if self.verbose:
            print(f'loading: {name_pth} from {name_tf}')
        self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze()
        self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze()
        # TF stores conv kernels as [H, W, in, out]; permute to PyTorch's [out, in, H, W].
        self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1)
        if bias:
            self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b).squeeze()

    def load_conv(self, name_pth, name_tf, bias=True):
        self.state_dict[name_pth + '.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze()
        self.state_dict[name_pth + '.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze()
        self.state_dict[name_pth + '.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1)
        if bias:
            self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b)

    def load_HyperBNs(self, name_pth, name_tf):
        self.load_HyperBN(name_pth + 'HyperBN', os.path.join(name_tf, 'HyperBN'))
        self.load_HyperBN(name_pth + 'HyperBN_1', os.path.join(name_tf, 'HyperBN_1'))

    def load_ScaledCrossReplicaBNs(self, name_pth, name_tf):
        self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.beta).squeeze()
        self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.gamma).squeeze()
        self.state_dict[name_pth + '.running_mean'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_mean')
        self.state_dict[name_pth + '.running_var'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_var')
        self.state_dict[name_pth + '.num_batches_tracked'] = torch.tensor(
            self.tf_weights[os.path.join(name_tf + 'bn', 'accumulation_counter:0')][()], dtype=torch.float32)

    def load_HyperBN(self, name_pth, name_tf):
        if self.verbose:
            print(f'loading: {name_pth} from {name_tf}')
        beta = name_pth + '.beta_embed.module'
        gamma = name_pth + '.gamma_embed.module'
        self.state_dict[beta + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.u).squeeze()
        self.state_dict[gamma + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.u).squeeze()
        self.state_dict[beta + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.v).squeeze()
        self.state_dict[gamma + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.v).squeeze()
        self.state_dict[beta + '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.w).permute(1, 0)
        self.state_dict[gamma + '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.w).permute(1, 0)

        cr_bn_name = name_tf.replace('HyperBN', 'CrossReplicaBN')
        self.state_dict[name_pth + '.bn.running_mean'] = self.load_tf_tensor(cr_bn_name, 'accumulated_mean')
        self.state_dict[name_pth + '.bn.running_var'] = self.load_tf_tensor(cr_bn_name, 'accumulated_var')
        self.state_dict[name_pth + '.bn.num_batches_tracked'] = torch.tensor(
            self.tf_weights[os.path.join(cr_bn_name, 'accumulation_counter:0')][()], dtype=torch.float32)

    def load_attention(self, name_pth, name_tf):
        self.load_snconv(name_pth + 'theta', os.path.join(name_tf, 'theta'), bias=False)
        self.load_snconv(name_pth + 'phi', os.path.join(name_tf, 'phi'), bias=False)
        self.load_snconv(name_pth + 'g', os.path.join(name_tf, 'g'), bias=False)
        self.load_snconv(name_pth + 'o_conv', os.path.join(name_tf, 'o_conv'), bias=False)
        self.state_dict[name_pth + 'gamma'] = self.load_tf_tensor(name_tf, self.gamma)

    def load_tf_tensor(self, prefix, var, device='0'):
        # TF variable names carry a trailing ':<output index>' (usually ':0'),
        # which is how they were stored in the HDF5 file by dump_tfhub_to_hdf5.
        name = os.path.join(prefix, var) + f':{device}'
        return torch.from_numpy(self.tf_weights[name][:])

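# Example usage of the converter class (a sketch that mirrors convert_biggan below):
#   tf_weights = dump_tfhub_to_hdf5(MODULE_PATH_TMPL.format(256), 'biggan-256.h5')
#   G_temp = biggan_for_conversion.Generator256()
#   state_dict_v1 = TFHub2Pytorch(G_temp.state_dict(), tf_weights, resolution=256).load()
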
# Convert from v1: this function maps the v1 (TFHub-layout) state dict produced
# above onto the parameter names used by BigGAN.Generator in the main folder.
def convert_from_v1(hub_dict, resolution=128):
    weightname_dict = {'weight_u': 'u0', 'weight_bar': 'weight', 'bias': 'bias'}
    convnum_dict = {'conv0': 'conv1', 'conv1': 'conv2', 'conv_sc': 'conv_sc'}
    attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution]
    hub2me = {'linear.weight': 'shared.weight',  # This is actually the shared weight
              # Linear stuff
              'G_linear.module.weight_bar': 'linear.weight',
              'G_linear.module.bias': 'linear.bias',
              'G_linear.module.weight_u': 'linear.u0',
              # Output layer stuff
              'ScaledCrossReplicaBN.weight': 'output_layer.0.gain',
              'ScaledCrossReplicaBN.bias': 'output_layer.0.bias',
              'ScaledCrossReplicaBN.running_mean': 'output_layer.0.stored_mean',
              'ScaledCrossReplicaBN.running_var': 'output_layer.0.stored_var',
              'colorize.module.weight_bar': 'output_layer.2.weight',
              'colorize.module.bias': 'output_layer.2.bias',
              'colorize.module.weight_u': 'output_layer.2.u0',
              # Attention stuff
              'attention.gamma': 'blocks.%d.1.gamma' % attention_blocknum,
              'attention.theta.module.weight_u': 'blocks.%d.1.theta.u0' % attention_blocknum,
              'attention.theta.module.weight_bar': 'blocks.%d.1.theta.weight' % attention_blocknum,
              'attention.phi.module.weight_u': 'blocks.%d.1.phi.u0' % attention_blocknum,
              'attention.phi.module.weight_bar': 'blocks.%d.1.phi.weight' % attention_blocknum,
              'attention.g.module.weight_u': 'blocks.%d.1.g.u0' % attention_blocknum,
              'attention.g.module.weight_bar': 'blocks.%d.1.g.weight' % attention_blocknum,
              'attention.o_conv.module.weight_u': 'blocks.%d.1.o.u0' % attention_blocknum,
              'attention.o_conv.module.weight_bar': 'blocks.%d.1.o.weight' % attention_blocknum,
              }
    # Loop over the hub dict and build the rest of the hub2me map
    for name in hub_dict.keys():
        if 'GBlock' in name:
            if 'HyperBN' not in name:  # It's a conv
                out = parse.parse('GBlock.{:d}.{}.module.{}', name)
                blocknum, convnum, weightname = out
                if weightname not in weightname_dict:
                    continue  # Skip weight_v entries
                # Increment conv number by 1
                out_name = 'blocks.%d.0.%s.%s' % (blocknum, convnum_dict[convnum], weightname_dict[weightname])
            else:  # HyperBN, not a conv
                BNnum = 2 if 'HyperBN_1' in name else 1
                if 'embed' in name:
                    out = parse.parse('GBlock.{:d}.{}.module.{}', name)
                    blocknum, gamma_or_beta, weightname = out
                    if weightname not in weightname_dict:  # Ignore weight_v
                        continue
                    out_name = 'blocks.%d.0.bn%d.%s.%s' % (blocknum, BNnum, 'gain' if 'gamma' in gamma_or_beta else 'bias', weightname_dict[weightname])
                else:
                    out = parse.parse('GBlock.{:d}.{}.bn.{}', name)
                    blocknum, dummy, mean_or_var = out
                    if 'num_batches_tracked' in mean_or_var:
                        continue
                    out_name = 'blocks.%d.0.bn%d.%s' % (blocknum, BNnum, 'stored_mean' if 'mean' in mean_or_var else 'stored_var')
            hub2me[name] = out_name
    # Invert the hub2me map
    me2hub = {hub2me[item]: item for item in hub2me}
    new_dict = {}
    dimz_dict = {128: 20, 256: 20, 512: 16}
    for item in me2hub:
        # Swap input dim ordering on batchnorm bois to account for my arbitrary change
        # of ordering when concatenating Ys and Zs
        if ('bn' in item and 'weight' in item) and ('gain' in item or 'bias' in item) and ('output_layer' not in item):
            new_dict[item] = torch.cat([hub_dict[me2hub[item]][:, -128:], hub_dict[me2hub[item]][:, :dimz_dict[resolution]]], 1)
        # Reshape the first linear weight, bias, and u0
        elif item == 'linear.weight':
            new_dict[item] = hub_dict[me2hub[item]].contiguous().view(4, 4, 96 * 16, -1).permute(2, 0, 1, 3).contiguous().view(-1, dimz_dict[resolution])
        elif item == 'linear.bias':
            new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 16).permute(2, 0, 1).contiguous().view(-1)
        elif item == 'linear.u0':
            new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 16).permute(2, 0, 1).contiguous().view(1, -1)
        elif me2hub[item] == 'linear.weight':  # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER
            # Transpose shared weight so that it's an embedding
            new_dict[item] = hub_dict[me2hub[item]].t()
        elif 'weight_u' in me2hub[item]:  # Unsqueeze u0s
            new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0)
        else:
            new_dict[item] = hub_dict[me2hub[item]]
    return new_dict

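# Worked example of the renaming above (illustrative): the v1 key
# 'GBlock.0.conv0.module.weight_bar' becomes 'blocks.0.0.conv1.weight', and
# 'GBlock.0.HyperBN.gamma_embed.module.weight_bar' becomes 'blocks.0.0.bn1.gain.weight'.
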
def get_config(resolution):
    attn_dict = {128: '64', 256: '128', 512: '64'}
    dim_z_dict = {128: 120, 256: 140, 512: 128}
    config = {'G_param': 'SN', 'D_param': 'SN',
              'G_ch': 96, 'D_ch': 96,
              'D_wide': True, 'G_shared': True,
              'shared_dim': 128, 'dim_z': dim_z_dict[resolution],
              'hier': True, 'cross_replica': False,
              'mybn': False, 'G_activation': nn.ReLU(inplace=True),
              'G_attn': attn_dict[resolution],
              'norm_style': 'bn',
              'G_init': 'ortho', 'skip_init': True, 'no_optim': True,
              'G_fp16': False, 'G_mixed_precision': False,
              'accumulate_stats': False, 'num_standing_accumulations': 16,
              'G_eval_mode': True,
              'BN_eps': 1e-04, 'SN_eps': 1e-04,
              'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution,
              'n_classes': 1000}
    return config

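# Sketch: reloading an already-converted checkpoint without re-running the TF
# conversion (assumes the .pth written by convert_biggan below exists):
#   config = get_config(256)
#   G = BigGAN.Generator(**config)
#   G.load_state_dict(torch.load('pretrained_weights/biggan-256.pth'), strict=False)
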
def convert_biggan(resolution, weight_dir, redownload=False, no_ema=False, verbose=False):
    module_path = MODULE_PATH_TMPL.format(resolution)
    hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution))
    pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution))

    tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload)
    G_temp = getattr(biggan_for_conversion, f'Generator{resolution}')()
    state_dict_temp = G_temp.state_dict()

    converter = TFHub2Pytorch(state_dict_temp, tf_weights, resolution=resolution,
                              load_ema=(not no_ema), verbose=verbose)
    state_dict_v1 = converter.load()
    state_dict = convert_from_v1(state_dict_v1, resolution)
    # Get the config, build the model
    config = get_config(resolution)
    G = BigGAN.Generator(**config)
    G.load_state_dict(state_dict, strict=False)  # Ignore missing sv0 entries
    torch.save(state_dict, pth_path)

    # output_location = 'pretrained_weights/TFHub-PyTorch-128.pth'

    return G

def generate_sample(G, z_dim, batch_size, filename, parallel=False):
    G.eval()
    G.to(DEVICE)
    with torch.no_grad():
        # Note: the noise dimension is taken from G.dim_z, so z_dim is effectively unused here.
        z = torch.randn(batch_size, G.dim_z).to(DEVICE)
        y = torch.randint(low=0, high=1000, size=(batch_size,),
                          device=DEVICE, dtype=torch.int64, requires_grad=False)
        if parallel:
            # G.shared(y) looks up the shared class embedding converted above.
            images = nn.parallel.data_parallel(G, (z, G.shared(y)))
        else:
            images = G(z, G.shared(y))
        save_image(images, filename, scale_each=True, normalize=True)

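# Example (sketch): convert one model and draw a sample grid, mirroring the
# __main__ block below:
#   G = convert_biggan(256, 'pretrained_weights')
#   generate_sample(G, Z_DIMS[256], batch_size=16,
#                   filename='pretrained_samples/biggan256_samples.jpg')
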
def parse_args():
    usage = 'Parser for conversion script.'
    parser = argparse.ArgumentParser(description=usage)
    parser.add_argument(
        '--resolution', '-r', type=int, default=None, choices=[128, 256, 512],
        help='Resolution of TFHub module to convert. Converts all resolutions if None.')
    parser.add_argument(
        '--redownload', action='store_true', default=False,
        help='Redownload weights and overwrite current hdf5 file, if present.')
    parser.add_argument(
        '--weights_dir', type=str, default='pretrained_weights')
    parser.add_argument(
        '--samples_dir', type=str, default='pretrained_samples')
    parser.add_argument(
        '--no_ema', action='store_true', default=False,
        help='Do not load ema weights.')
    parser.add_argument(
        '--verbose', action='store_true', default=False,
        help='Additional logging.')
    parser.add_argument(
        '--generate_samples', action='store_true', default=False,
        help='Generate test sample with pretrained model.')
    parser.add_argument(
        '--batch_size', type=int, default=64,
        help='Batch size used for test sample.')
    parser.add_argument(
        '--parallel', action='store_true', default=False,
        help='Parallelize G?')
    args = parser.parse_args()
    return args

if __name__ == '__main__':

    args = parse_args()
    os.makedirs(args.weights_dir, exist_ok=True)
    os.makedirs(args.samples_dir, exist_ok=True)

    if args.resolution is not None:
        G = convert_biggan(args.resolution, args.weights_dir,
                           redownload=args.redownload,
                           no_ema=args.no_ema, verbose=args.verbose)
        if args.generate_samples:
            filename = os.path.join(args.samples_dir, f'biggan{args.resolution}_samples.jpg')
            print('Generating samples...')
            generate_sample(G, Z_DIMS[args.resolution], args.batch_size, filename, args.parallel)
    else:
        for res in RESOLUTIONS:
            G = convert_biggan(res, args.weights_dir,
                               redownload=args.redownload,
                               no_ema=args.no_ema, verbose=args.verbose)
            if args.generate_samples:
                filename = os.path.join(args.samples_dir, f'biggan{res}_samples.jpg')
                print('Generating samples...')
                generate_sample(G, Z_DIMS[res], args.batch_size, filename, args.parallel)