From c8b8388df52607802e6e3a9e3c5aa07ac9773fbc Mon Sep 17 00:00:00 2001 From: zhaojinghao Date: Wed, 3 Aug 2022 10:16:48 +0800 Subject: [PATCH] initial commit --- Dockerfile | 17 + README.md | 7 + ...iative_embedding_hrnet_w32_coco_512x512.py | 1134 ++++++++++++++++ backbone/backbone_infer.py | 13 + detection/detection.py | 22 + mnist/MNIST_cnn.py | 71 + ocr/ocr.py | 22 + requirements.txt | 17 + run.py | 143 ++ segmentation/segment_pred.py | 66 + text2image/BigGAN_utils/BigGAN.py | 484 +++++++ text2image/BigGAN_utils/BigGANdeep.py | 534 ++++++++ text2image/BigGAN_utils/LICENSE | 21 + text2image/BigGAN_utils/README.md | 1 + .../__pycache__/biggan_v1.cpython-38.pyc | Bin 0 -> 11044 bytes text2image/BigGAN_utils/TFHub/biggan_v1.py | 389 ++++++ text2image/BigGAN_utils/TFHub/converter.py | 396 ++++++ text2image/BigGAN_utils/__init__.py | 2 + .../__pycache__/BigGAN.cpython-37.pyc | Bin 0 -> 15764 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 0 -> 221 bytes .../__pycache__/animal_hash.cpython-37.pyc | Bin 0 -> 35920 bytes .../__pycache__/datasets.cpython-37.pyc | Bin 0 -> 10075 bytes .../__pycache__/layers.cpython-37.pyc | Bin 0 -> 12676 bytes .../__pycache__/utils.cpython-37.pyc | Bin 0 -> 35369 bytes text2image/BigGAN_utils/animal_hash.py | 439 ++++++ text2image/BigGAN_utils/binary_utils.py | 14 + .../calculate_inception_moments.py | 91 ++ text2image/BigGAN_utils/datasets.py | 362 +++++ text2image/BigGAN_utils/inception_tf13.py | 138 ++ text2image/BigGAN_utils/inception_utils.py | 310 +++++ text2image/BigGAN_utils/layers.py | 459 +++++++ .../logs/BigGAN_ch96_bs256x8.jsonl | 68 + text2image/BigGAN_utils/logs/compare_IS.m | 89 ++ text2image/BigGAN_utils/logs/metalog.txt | 3 + .../BigGAN_utils/logs/process_inception_log.m | 19 + .../BigGAN_utils/logs/process_training.m | 109 ++ text2image/BigGAN_utils/losses.py | 33 + text2image/BigGAN_utils/make_hdf5.py | 110 ++ text2image/BigGAN_utils/sample.py | 183 +++ .../scripts/launch_BigGAN_bs256x8.sh | 17 + .../scripts/launch_BigGAN_bs512x4.sh | 17 + .../scripts/launch_BigGAN_ch64_bs256x8.sh | 17 + .../scripts/launch_BigGAN_deep.sh | 18 + .../scripts/launch_SAGAN_bs128x2_ema.sh | 13 + .../BigGAN_utils/scripts/launch_SNGAN.sh | 14 + .../BigGAN_utils/scripts/launch_cifar_ema.sh | 11 + .../scripts/sample_BigGAN_bs256x8.sh | 20 + .../BigGAN_utils/scripts/sample_cifar_ema.sh | 11 + .../BigGAN_utils/scripts/utils/duplicate.sh | 14 + .../scripts/utils/prepare_data.sh | 3 + .../BigGAN_utils/sync_batchnorm/__init__.py | 12 + .../__pycache__/__init__.cpython-37.pyc | Bin 0 -> 348 bytes .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 383 bytes .../__pycache__/batchnorm.cpython-37.pyc | Bin 0 -> 13007 bytes .../__pycache__/batchnorm.cpython-38.pyc | Bin 0 -> 13039 bytes .../__pycache__/comm.cpython-37.pyc | Bin 0 -> 4733 bytes .../__pycache__/comm.cpython-38.pyc | Bin 0 -> 4819 bytes .../__pycache__/replicate.cpython-37.pyc | Bin 0 -> 3409 bytes .../__pycache__/replicate.cpython-38.pyc | Bin 0 -> 3468 bytes .../BigGAN_utils/sync_batchnorm/batchnorm.py | 349 +++++ .../sync_batchnorm/batchnorm_reimpl.py | 74 + .../BigGAN_utils/sync_batchnorm/comm.py | 137 ++ .../BigGAN_utils/sync_batchnorm/replicate.py | 94 ++ .../BigGAN_utils/sync_batchnorm/unittest.py | 29 + text2image/BigGAN_utils/train.py | 227 ++++ text2image/BigGAN_utils/train_fns.py | 187 +++ text2image/BigGAN_utils/utils.py | 1195 +++++++++++++++++ text2image/BigGAN_utils/weights/README.md | 2 + text2image/DiffAugment_pytorch.py | 102 ++ text2image/LICENSE | 21 + text2image/README.md | 65 + 
text2image/__init__.py | 0 text2image/fusedream_generator.py | 35 + text2image/fusedream_utils.py | 308 +++++ text2image/run_text2img.py | 30 + 75 files changed, 8788 insertions(+) create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 backbone/associative_embedding_hrnet_w32_coco_512x512.py create mode 100644 backbone/backbone_infer.py create mode 100644 detection/detection.py create mode 100644 mnist/MNIST_cnn.py create mode 100644 ocr/ocr.py create mode 100644 requirements.txt create mode 100644 run.py create mode 100644 segmentation/segment_pred.py create mode 100644 text2image/BigGAN_utils/BigGAN.py create mode 100644 text2image/BigGAN_utils/BigGANdeep.py create mode 100644 text2image/BigGAN_utils/LICENSE create mode 100644 text2image/BigGAN_utils/README.md create mode 100644 text2image/BigGAN_utils/TFHub/__pycache__/biggan_v1.cpython-38.pyc create mode 100644 text2image/BigGAN_utils/TFHub/biggan_v1.py create mode 100644 text2image/BigGAN_utils/TFHub/converter.py create mode 100644 text2image/BigGAN_utils/__init__.py create mode 100644 text2image/BigGAN_utils/__pycache__/BigGAN.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/__pycache__/__init__.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/__pycache__/animal_hash.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/__pycache__/datasets.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/__pycache__/layers.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/__pycache__/utils.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/animal_hash.py create mode 100644 text2image/BigGAN_utils/binary_utils.py create mode 100644 text2image/BigGAN_utils/calculate_inception_moments.py create mode 100644 text2image/BigGAN_utils/datasets.py create mode 100644 text2image/BigGAN_utils/inception_tf13.py create mode 100644 text2image/BigGAN_utils/inception_utils.py create mode 100644 text2image/BigGAN_utils/layers.py create mode 100644 text2image/BigGAN_utils/logs/BigGAN_ch96_bs256x8.jsonl create mode 100644 text2image/BigGAN_utils/logs/compare_IS.m create mode 100644 text2image/BigGAN_utils/logs/metalog.txt create mode 100644 text2image/BigGAN_utils/logs/process_inception_log.m create mode 100644 text2image/BigGAN_utils/logs/process_training.m create mode 100644 text2image/BigGAN_utils/losses.py create mode 100644 text2image/BigGAN_utils/make_hdf5.py create mode 100644 text2image/BigGAN_utils/sample.py create mode 100644 text2image/BigGAN_utils/scripts/launch_BigGAN_bs256x8.sh create mode 100644 text2image/BigGAN_utils/scripts/launch_BigGAN_bs512x4.sh create mode 100644 text2image/BigGAN_utils/scripts/launch_BigGAN_ch64_bs256x8.sh create mode 100644 text2image/BigGAN_utils/scripts/launch_BigGAN_deep.sh create mode 100644 text2image/BigGAN_utils/scripts/launch_SAGAN_bs128x2_ema.sh create mode 100644 text2image/BigGAN_utils/scripts/launch_SNGAN.sh create mode 100644 text2image/BigGAN_utils/scripts/launch_cifar_ema.sh create mode 100644 text2image/BigGAN_utils/scripts/sample_BigGAN_bs256x8.sh create mode 100644 text2image/BigGAN_utils/scripts/sample_cifar_ema.sh create mode 100644 text2image/BigGAN_utils/scripts/utils/duplicate.sh create mode 100644 text2image/BigGAN_utils/scripts/utils/prepare_data.sh create mode 100644 text2image/BigGAN_utils/sync_batchnorm/__init__.py create mode 100644 text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-38.pyc create mode 100644 
text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc create mode 100644 text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-38.pyc create mode 100644 text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc create mode 100644 text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-38.pyc create mode 100644 text2image/BigGAN_utils/sync_batchnorm/batchnorm.py create mode 100644 text2image/BigGAN_utils/sync_batchnorm/batchnorm_reimpl.py create mode 100644 text2image/BigGAN_utils/sync_batchnorm/comm.py create mode 100644 text2image/BigGAN_utils/sync_batchnorm/replicate.py create mode 100644 text2image/BigGAN_utils/sync_batchnorm/unittest.py create mode 100644 text2image/BigGAN_utils/train.py create mode 100644 text2image/BigGAN_utils/train_fns.py create mode 100644 text2image/BigGAN_utils/utils.py create mode 100644 text2image/BigGAN_utils/weights/README.md create mode 100644 text2image/DiffAugment_pytorch.py create mode 100644 text2image/LICENSE create mode 100644 text2image/README.md create mode 100644 text2image/__init__.py create mode 100644 text2image/fusedream_generator.py create mode 100644 text2image/fusedream_utils.py create mode 100644 text2image/run_text2img.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..be9f28c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.7.13-slim + +WORKDIR /app +ADD . /app/ + +RUN pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir +RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir +RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir +RUN mim install mmcv-full +RUN pip install -e ./text2image/CLIP/ +RUN pip install -e ./backbone/mmpose/ +RUN pip install -e ./ocr/tr/ +RUN rm -rf ./text2image/CLIP/ +RUN rm -rf ./ocr/tr/ +RUN rm -rf ./backbone/mmpose/ + +CMD ["python3", "run.py"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..f1c03ad --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +## 机器视觉模块 +### backbone: 关键点检测 +### detection: 目标检测 +### mnist: mnist手写体识别 +### ocr: 光学字符识别 +### segmentation: (光伏)图像分割 +### text2image: (文本)图像生成 \ No newline at end of file diff --git a/backbone/associative_embedding_hrnet_w32_coco_512x512.py b/backbone/associative_embedding_hrnet_w32_coco_512x512.py new file mode 100644 index 0000000..4774d01 --- /dev/null +++ b/backbone/associative_embedding_hrnet_w32_coco_512x512.py @@ -0,0 +1,1134 @@ +checkpoint_config = dict(interval=50) +log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) +log_level = 'INFO' +load_from = None +resume_from = None +dist_params = dict(backend='nccl') +workflow = [('train', 1)] +opencv_num_threads = 0 +mp_start_method = 'fork' +dataset_info = dict( + dataset_name='coco', + paper_info=dict( + author= + 'Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/'), + keypoint_info=dict({ + 0: + dict(name='nose', id=0, color=[51, 153, 255], type='upper', swap=''), + 1: + dict( + 
name='left_eye', + id=1, + color=[51, 153, 255], + type='upper', + swap='right_eye'), + 2: + dict( + name='right_eye', + id=2, + color=[51, 153, 255], + type='upper', + swap='left_eye'), + 3: + dict( + name='left_ear', + id=3, + color=[51, 153, 255], + type='upper', + swap='right_ear'), + 4: + dict( + name='right_ear', + id=4, + color=[51, 153, 255], + type='upper', + swap='left_ear'), + 5: + dict( + name='left_shoulder', + id=5, + color=[0, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + name='right_shoulder', + id=6, + color=[255, 128, 0], + type='upper', + swap='left_shoulder'), + 7: + dict( + name='left_elbow', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_elbow'), + 8: + dict( + name='right_elbow', + id=8, + color=[255, 128, 0], + type='upper', + swap='left_elbow'), + 9: + dict( + name='left_wrist', + id=9, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 10: + dict( + name='right_wrist', + id=10, + color=[255, 128, 0], + type='upper', + swap='left_wrist'), + 11: + dict( + name='left_hip', + id=11, + color=[0, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='right_hip', + id=12, + color=[255, 128, 0], + type='lower', + swap='left_hip'), + 13: + dict( + name='left_knee', + id=13, + color=[0, 255, 0], + type='lower', + swap='right_knee'), + 14: + dict( + name='right_knee', + id=14, + color=[255, 128, 0], + type='lower', + swap='left_knee'), + 15: + dict( + name='left_ankle', + id=15, + color=[0, 255, 0], + type='lower', + swap='right_ankle'), + 16: + dict( + name='right_ankle', + id=16, + color=[255, 128, 0], + type='lower', + swap='left_ankle') + }), + skeleton_info=dict({ + 0: + dict(link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]), + 1: + dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]), + 2: + dict(link=('right_ankle', 'right_knee'), id=2, color=[255, 128, 0]), + 3: + dict(link=('right_knee', 'right_hip'), id=3, color=[255, 128, 0]), + 4: + dict(link=('left_hip', 'right_hip'), id=4, color=[51, 153, 255]), + 5: + dict(link=('left_shoulder', 'left_hip'), id=5, color=[51, 153, 255]), + 6: + dict(link=('right_shoulder', 'right_hip'), id=6, color=[51, 153, 255]), + 7: + dict( + link=('left_shoulder', 'right_shoulder'), + id=7, + color=[51, 153, 255]), + 8: + dict(link=('left_shoulder', 'left_elbow'), id=8, color=[0, 255, 0]), + 9: + dict( + link=('right_shoulder', 'right_elbow'), id=9, color=[255, 128, 0]), + 10: + dict(link=('left_elbow', 'left_wrist'), id=10, color=[0, 255, 0]), + 11: + dict(link=('right_elbow', 'right_wrist'), id=11, color=[255, 128, 0]), + 12: + dict(link=('left_eye', 'right_eye'), id=12, color=[51, 153, 255]), + 13: + dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]), + 14: + dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]), + 15: + dict(link=('left_eye', 'left_ear'), id=15, color=[51, 153, 255]), + 16: + dict(link=('right_eye', 'right_ear'), id=16, color=[51, 153, 255]), + 17: + dict(link=('left_ear', 'left_shoulder'), id=17, color=[51, 153, 255]), + 18: + dict( + link=('right_ear', 'right_shoulder'), id=18, color=[51, 153, 255]) + }), + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, 1.0, 1.2, + 1.2, 1.5, 1.5 + ], + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, + 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089 + ]) +evaluation = dict(interval=50, metric='mAP', save_best='AP') +optimizer = dict(type='Adam', lr=0.0015) +optimizer_config = dict(grad_clip=None) +lr_config = dict( + policy='step', 
+ warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[200, 260]) +total_epochs = 300 +channel_cfg = dict( + dataset_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]) +data_cfg = dict( + image_size=512, + base_size=256, + base_sigma=2, + heatmap_size=[128], + num_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ], + num_scales=1, + scale_aware_sigma=False) +model = dict( + type='AssociativeEmbedding', + pretrained= + 'https://download.openmmlab.com/mmpose/pretrain_models/hrnet_w32-36af842e.pth', + backbone=dict( + type='HRNet', + in_channels=3, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(32, 64)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(32, 64, 128)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(32, 64, 128, 256)))), + keypoint_head=dict( + type='AESimpleHead', + in_channels=32, + num_joints=17, + num_deconv_layers=0, + tag_per_joint=True, + with_ae_loss=[True], + extra=dict(final_conv_kernel=1), + loss_keypoint=dict( + type='MultiLossFactory', + num_joints=17, + num_stages=1, + ae_loss_type='exp', + with_ae_loss=[True], + push_loss_factor=[0.001], + pull_loss_factor=[0.001], + with_heatmaps_loss=[True], + heatmaps_loss_factor=[1.0])), + train_cfg=dict(), + test_cfg=dict( + num_joints=17, + max_num_people=30, + scale_factor=[1], + with_heatmaps=[True], + with_ae=[True], + project2image=True, + align_corners=False, + nms_kernel=5, + nms_padding=2, + tag_per_joint=True, + detection_threshold=0.1, + tag_threshold=1, + use_detection_val=True, + ignore_too_much=False, + adjust=True, + refine=True, + flip_test=True)) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='BottomUpRandomAffine', + rot_factor=30, + scale_factor=[0.75, 1.5], + scale_type='short', + trans_factor=40), + dict(type='BottomUpRandomFlip', flip_prob=0.5), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict(type='BottomUpGenerateTarget', sigma=2, max_num_people=30), + dict( + type='Collect', + keys=['img', 'joints', 'targets', 'masks'], + meta_keys=[]) +] +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='BottomUpGetImgSize', test_scale_factor=[1]), + dict( + type='BottomUpResizeAlign', + transforms=[ + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'aug_data', 'test_scale_factor', 'base_size', + 'center', 'scale', 'flip_index' + ]) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='BottomUpGetImgSize', test_scale_factor=[1]), + dict( + type='BottomUpResizeAlign', + transforms=[ + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'aug_data', 'test_scale_factor', 'base_size', + 'center', 'scale', 'flip_index' + ]) +] +data_root = 
'data/coco' +data = dict( + workers_per_gpu=2, + train_dataloader=dict(samples_per_gpu=24), + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + train=dict( + type='BottomUpCocoDataset', + ann_file='data/coco/annotations/person_keypoints_train2017.json', + img_prefix='data/coco/train2017/', + data_cfg=dict( + image_size=512, + base_size=256, + base_sigma=2, + heatmap_size=[128], + num_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ], + num_scales=1, + scale_aware_sigma=False), + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='BottomUpRandomAffine', + rot_factor=30, + scale_factor=[0.75, 1.5], + scale_type='short', + trans_factor=40), + dict(type='BottomUpRandomFlip', flip_prob=0.5), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict(type='BottomUpGenerateTarget', sigma=2, max_num_people=30), + dict( + type='Collect', + keys=['img', 'joints', 'targets', 'masks'], + meta_keys=[]) + ], + dataset_info=dict( + dataset_name='coco', + paper_info=dict( + author= + 'Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/'), + keypoint_info=dict({ + 0: + dict( + name='nose', + id=0, + color=[51, 153, 255], + type='upper', + swap=''), + 1: + dict( + name='left_eye', + id=1, + color=[51, 153, 255], + type='upper', + swap='right_eye'), + 2: + dict( + name='right_eye', + id=2, + color=[51, 153, 255], + type='upper', + swap='left_eye'), + 3: + dict( + name='left_ear', + id=3, + color=[51, 153, 255], + type='upper', + swap='right_ear'), + 4: + dict( + name='right_ear', + id=4, + color=[51, 153, 255], + type='upper', + swap='left_ear'), + 5: + dict( + name='left_shoulder', + id=5, + color=[0, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + name='right_shoulder', + id=6, + color=[255, 128, 0], + type='upper', + swap='left_shoulder'), + 7: + dict( + name='left_elbow', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_elbow'), + 8: + dict( + name='right_elbow', + id=8, + color=[255, 128, 0], + type='upper', + swap='left_elbow'), + 9: + dict( + name='left_wrist', + id=9, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 10: + dict( + name='right_wrist', + id=10, + color=[255, 128, 0], + type='upper', + swap='left_wrist'), + 11: + dict( + name='left_hip', + id=11, + color=[0, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='right_hip', + id=12, + color=[255, 128, 0], + type='lower', + swap='left_hip'), + 13: + dict( + name='left_knee', + id=13, + color=[0, 255, 0], + type='lower', + swap='right_knee'), + 14: + dict( + name='right_knee', + id=14, + color=[255, 128, 0], + type='lower', + swap='left_knee'), + 15: + dict( + name='left_ankle', + id=15, + color=[0, 255, 0], + type='lower', + swap='right_ankle'), + 16: + dict( + name='right_ankle', + id=16, + color=[255, 128, 0], + type='lower', + swap='left_ankle') + }), + skeleton_info=dict({ + 0: + dict( + link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]), + 1: + dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]), + 2: + dict( + link=('right_ankle', 'right_knee'), + id=2, + color=[255, 
128, 0]), + 3: + dict( + link=('right_knee', 'right_hip'), + id=3, + color=[255, 128, 0]), + 4: + dict( + link=('left_hip', 'right_hip'), id=4, color=[51, 153, + 255]), + 5: + dict( + link=('left_shoulder', 'left_hip'), + id=5, + color=[51, 153, 255]), + 6: + dict( + link=('right_shoulder', 'right_hip'), + id=6, + color=[51, 153, 255]), + 7: + dict( + link=('left_shoulder', 'right_shoulder'), + id=7, + color=[51, 153, 255]), + 8: + dict( + link=('left_shoulder', 'left_elbow'), + id=8, + color=[0, 255, 0]), + 9: + dict( + link=('right_shoulder', 'right_elbow'), + id=9, + color=[255, 128, 0]), + 10: + dict( + link=('left_elbow', 'left_wrist'), + id=10, + color=[0, 255, 0]), + 11: + dict( + link=('right_elbow', 'right_wrist'), + id=11, + color=[255, 128, 0]), + 12: + dict( + link=('left_eye', 'right_eye'), + id=12, + color=[51, 153, 255]), + 13: + dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]), + 14: + dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]), + 15: + dict( + link=('left_eye', 'left_ear'), id=15, color=[51, 153, + 255]), + 16: + dict( + link=('right_eye', 'right_ear'), + id=16, + color=[51, 153, 255]), + 17: + dict( + link=('left_ear', 'left_shoulder'), + id=17, + color=[51, 153, 255]), + 18: + dict( + link=('right_ear', 'right_shoulder'), + id=18, + color=[51, 153, 255]) + }), + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, + 1.0, 1.2, 1.2, 1.5, 1.5 + ], + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, + 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089 + ])), + val=dict( + type='BottomUpCocoDataset', + ann_file='data/coco/annotations/person_keypoints_val2017.json', + img_prefix='data/coco/val2017/', + data_cfg=dict( + image_size=512, + base_size=256, + base_sigma=2, + heatmap_size=[128], + num_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ], + num_scales=1, + scale_aware_sigma=False), + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='BottomUpGetImgSize', test_scale_factor=[1]), + dict( + type='BottomUpResizeAlign', + transforms=[ + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'aug_data', 'test_scale_factor', 'base_size', + 'center', 'scale', 'flip_index' + ]) + ], + dataset_info=dict( + dataset_name='coco', + paper_info=dict( + author= + 'Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/'), + keypoint_info=dict({ + 0: + dict( + name='nose', + id=0, + color=[51, 153, 255], + type='upper', + swap=''), + 1: + dict( + name='left_eye', + id=1, + color=[51, 153, 255], + type='upper', + swap='right_eye'), + 2: + dict( + name='right_eye', + id=2, + color=[51, 153, 255], + type='upper', + swap='left_eye'), + 3: + dict( + name='left_ear', + id=3, + color=[51, 153, 255], + type='upper', + swap='right_ear'), + 4: + dict( + name='right_ear', + id=4, + color=[51, 153, 255], + type='upper', + swap='left_ear'), + 5: + dict( + name='left_shoulder', + id=5, + color=[0, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + name='right_shoulder', + id=6, + 
color=[255, 128, 0], + type='upper', + swap='left_shoulder'), + 7: + dict( + name='left_elbow', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_elbow'), + 8: + dict( + name='right_elbow', + id=8, + color=[255, 128, 0], + type='upper', + swap='left_elbow'), + 9: + dict( + name='left_wrist', + id=9, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 10: + dict( + name='right_wrist', + id=10, + color=[255, 128, 0], + type='upper', + swap='left_wrist'), + 11: + dict( + name='left_hip', + id=11, + color=[0, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='right_hip', + id=12, + color=[255, 128, 0], + type='lower', + swap='left_hip'), + 13: + dict( + name='left_knee', + id=13, + color=[0, 255, 0], + type='lower', + swap='right_knee'), + 14: + dict( + name='right_knee', + id=14, + color=[255, 128, 0], + type='lower', + swap='left_knee'), + 15: + dict( + name='left_ankle', + id=15, + color=[0, 255, 0], + type='lower', + swap='right_ankle'), + 16: + dict( + name='right_ankle', + id=16, + color=[255, 128, 0], + type='lower', + swap='left_ankle') + }), + skeleton_info=dict({ + 0: + dict( + link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]), + 1: + dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]), + 2: + dict( + link=('right_ankle', 'right_knee'), + id=2, + color=[255, 128, 0]), + 3: + dict( + link=('right_knee', 'right_hip'), + id=3, + color=[255, 128, 0]), + 4: + dict( + link=('left_hip', 'right_hip'), id=4, color=[51, 153, + 255]), + 5: + dict( + link=('left_shoulder', 'left_hip'), + id=5, + color=[51, 153, 255]), + 6: + dict( + link=('right_shoulder', 'right_hip'), + id=6, + color=[51, 153, 255]), + 7: + dict( + link=('left_shoulder', 'right_shoulder'), + id=7, + color=[51, 153, 255]), + 8: + dict( + link=('left_shoulder', 'left_elbow'), + id=8, + color=[0, 255, 0]), + 9: + dict( + link=('right_shoulder', 'right_elbow'), + id=9, + color=[255, 128, 0]), + 10: + dict( + link=('left_elbow', 'left_wrist'), + id=10, + color=[0, 255, 0]), + 11: + dict( + link=('right_elbow', 'right_wrist'), + id=11, + color=[255, 128, 0]), + 12: + dict( + link=('left_eye', 'right_eye'), + id=12, + color=[51, 153, 255]), + 13: + dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]), + 14: + dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]), + 15: + dict( + link=('left_eye', 'left_ear'), id=15, color=[51, 153, + 255]), + 16: + dict( + link=('right_eye', 'right_ear'), + id=16, + color=[51, 153, 255]), + 17: + dict( + link=('left_ear', 'left_shoulder'), + id=17, + color=[51, 153, 255]), + 18: + dict( + link=('right_ear', 'right_shoulder'), + id=18, + color=[51, 153, 255]) + }), + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, + 1.0, 1.2, 1.2, 1.5, 1.5 + ], + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, + 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089 + ])), + test=dict( + type='BottomUpCocoDataset', + ann_file='data/coco/annotations/person_keypoints_val2017.json', + img_prefix='data/coco/val2017/', + data_cfg=dict( + image_size=512, + base_size=256, + base_sigma=2, + heatmap_size=[128], + num_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ], + num_scales=1, + scale_aware_sigma=False), + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='BottomUpGetImgSize', test_scale_factor=[1]), + dict( + type='BottomUpResizeAlign', + transforms=[ + 
dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'aug_data', 'test_scale_factor', 'base_size', + 'center', 'scale', 'flip_index' + ]) + ], + dataset_info=dict( + dataset_name='coco', + paper_info=dict( + author= + 'Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/'), + keypoint_info=dict({ + 0: + dict( + name='nose', + id=0, + color=[51, 153, 255], + type='upper', + swap=''), + 1: + dict( + name='left_eye', + id=1, + color=[51, 153, 255], + type='upper', + swap='right_eye'), + 2: + dict( + name='right_eye', + id=2, + color=[51, 153, 255], + type='upper', + swap='left_eye'), + 3: + dict( + name='left_ear', + id=3, + color=[51, 153, 255], + type='upper', + swap='right_ear'), + 4: + dict( + name='right_ear', + id=4, + color=[51, 153, 255], + type='upper', + swap='left_ear'), + 5: + dict( + name='left_shoulder', + id=5, + color=[0, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + name='right_shoulder', + id=6, + color=[255, 128, 0], + type='upper', + swap='left_shoulder'), + 7: + dict( + name='left_elbow', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_elbow'), + 8: + dict( + name='right_elbow', + id=8, + color=[255, 128, 0], + type='upper', + swap='left_elbow'), + 9: + dict( + name='left_wrist', + id=9, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 10: + dict( + name='right_wrist', + id=10, + color=[255, 128, 0], + type='upper', + swap='left_wrist'), + 11: + dict( + name='left_hip', + id=11, + color=[0, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='right_hip', + id=12, + color=[255, 128, 0], + type='lower', + swap='left_hip'), + 13: + dict( + name='left_knee', + id=13, + color=[0, 255, 0], + type='lower', + swap='right_knee'), + 14: + dict( + name='right_knee', + id=14, + color=[255, 128, 0], + type='lower', + swap='left_knee'), + 15: + dict( + name='left_ankle', + id=15, + color=[0, 255, 0], + type='lower', + swap='right_ankle'), + 16: + dict( + name='right_ankle', + id=16, + color=[255, 128, 0], + type='lower', + swap='left_ankle') + }), + skeleton_info=dict({ + 0: + dict( + link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]), + 1: + dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]), + 2: + dict( + link=('right_ankle', 'right_knee'), + id=2, + color=[255, 128, 0]), + 3: + dict( + link=('right_knee', 'right_hip'), + id=3, + color=[255, 128, 0]), + 4: + dict( + link=('left_hip', 'right_hip'), id=4, color=[51, 153, + 255]), + 5: + dict( + link=('left_shoulder', 'left_hip'), + id=5, + color=[51, 153, 255]), + 6: + dict( + link=('right_shoulder', 'right_hip'), + id=6, + color=[51, 153, 255]), + 7: + dict( + link=('left_shoulder', 'right_shoulder'), + id=7, + color=[51, 153, 255]), + 8: + dict( + link=('left_shoulder', 'left_elbow'), + id=8, + color=[0, 255, 0]), + 9: + dict( + link=('right_shoulder', 'right_elbow'), + id=9, + color=[255, 128, 0]), + 10: + dict( + link=('left_elbow', 'left_wrist'), + id=10, + color=[0, 255, 0]), + 11: + dict( + link=('right_elbow', 'right_wrist'), + id=11, + color=[255, 128, 0]), + 12: + dict( + link=('left_eye', 'right_eye'), + id=12, + color=[51, 153, 255]), + 13: + 
dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]), + 14: + dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]), + 15: + dict( + link=('left_eye', 'left_ear'), id=15, color=[51, 153, + 255]), + 16: + dict( + link=('right_eye', 'right_ear'), + id=16, + color=[51, 153, 255]), + 17: + dict( + link=('left_ear', 'left_shoulder'), + id=17, + color=[51, 153, 255]), + 18: + dict( + link=('right_ear', 'right_shoulder'), + id=18, + color=[51, 153, 255]) + }), + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, + 1.0, 1.2, 1.2, 1.5, 1.5 + ], + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, + 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089 + ]))) diff --git a/backbone/backbone_infer.py b/backbone/backbone_infer.py new file mode 100644 index 0000000..d4e2432 --- /dev/null +++ b/backbone/backbone_infer.py @@ -0,0 +1,13 @@ +from mmpose.apis import inference_bottom_up_pose_model, vis_pose_result +import cv2 +import os + +def run_backbone_infer(img_ndarr, pose_model): + # test a single image + pose_results, _ = inference_bottom_up_pose_model(pose_model, img_ndarr) + rst = vis_pose_result(pose_model, img_ndarr, pose_results, dataset='TopDownCocoWholeBodyDataset', thickness=2, radius=8, bbox_color='white') + return pose_results, rst + + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/detection/detection.py b/detection/detection.py new file mode 100644 index 0000000..33750f8 --- /dev/null +++ b/detection/detection.py @@ -0,0 +1,22 @@ +import torch +import matplotlib.pyplot as plt + +def detector(img, model): + """_summary_ + + Args: + img (str or numpy.ndarray): 图片路径或者像素矩阵 + model (_type_): 预加载的模型 + + Returns: + rtn(numpy.ndarray): 渲染后的图片像素点 + pred(numpy.ndarray): 检测而出的目标的坐标点、置信度和类别,shape=[n, 6] + """ + result = model(img) + + return result.render()[0], result.pred[0].cpu().numpy() + +if __name__ == '__main__': + # model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True) + pass + \ No newline at end of file diff --git a/mnist/MNIST_cnn.py b/mnist/MNIST_cnn.py new file mode 100644 index 0000000..dbf0803 --- /dev/null +++ b/mnist/MNIST_cnn.py @@ -0,0 +1,71 @@ +import os.path + +import numpy as np +import tensorflow as tf +physical_devices = tf.config.list_physical_devices('GPU') +tf.config.experimental.set_memory_growth(physical_devices[0], True) + +num_classes = 10 +input_shape = (28, 28, 1) + +def pre_process(): + + # the data, split between train and test sets + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + + # convert class vectors to binary class matrices + y_train = tf.keras.utils.to_categorical(y_train, num_classes) + y_test = tf.keras.utils.to_categorical(y_test, num_classes) + + return x_train, y_train, x_test, y_test + + +def createModel(neure, kernel_size, pool_size, activation): + model = tf.keras.Sequential( + [ + tf.keras.Input(shape=(input_shape)), + tf.keras.layers.Conv2D(neure, kernel_size=(kernel_size, kernel_size), activation=activation), + tf.keras.layers.MaxPooling2D(pool_size=(pool_size, pool_size)), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(num_classes, activation="softmax"), + ] + ) + + model.summary() + + 
return model + +def trainModel(model, x_train, y_train, epochs, loss): + batch_size = 128 + epochs = epochs + + model.compile(loss=loss, optimizer="adam", metrics=["accuracy"]) + + model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1) + + return model + +# 模型预测 +def predictModel(x_test, model): + predicted_data = model.predict(x_test) + return predicted_data + +def train(neure, kernel_size, pool_size, activation, epochs, loss): + x_train, y_train, x_test, y_test = pre_process() + model = createModel(neure, kernel_size, pool_size, activation) + model = trainModel(model, x_train, y_train, epochs, loss) + score = model.evaluate(x_test, y_test, verbose=0) + accuracy = score[1] + + model.save(os.path.abspath("./appweb/self_model/mnist_cnn.h5")) + + return accuracy \ No newline at end of file diff --git a/ocr/ocr.py b/ocr/ocr.py new file mode 100644 index 0000000..f9ac5b0 --- /dev/null +++ b/ocr/ocr.py @@ -0,0 +1,22 @@ +import tr +import cv2 + +def run_tr(image): + """_summary_ + + Args: + image (np.ndarray): 像素矩阵 + + Returns: + list: 字组成的列表 + """ + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + rst = tr.run(gray_image) + if len(rst) == 0: + return [] + return [x[1] for x in rst] + + +if __name__ == '__main__': + pass + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ebeb26a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +Flask==2.1.0 +h5py +ftfy==6.1.1 +logzero==1.7.0 +lpips==0.1.4 +numpy==1.21.6 +pandas==1.3.5 +parse==1.19.0 +Pillow==9.2.0 +scipy==1.4.1 +six==1.15.0 +tqdm==4.64.0 +opencv-python==4.6.0.66 +seaborn==0.11.2 +segmentation-models-pytorch==0.2.1 +albumentations==1.2.1 +openmim==0.2.0 \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000..f48a0ab --- /dev/null +++ b/run.py @@ -0,0 +1,143 @@ +# -*-coding:utf-8-*- +import os +import sys +from logzero import logger +current_path = os.path.dirname(__file__) +logger.info(current_path) +sys.path.append(f"{current_path}/text2image/") +sys.path.append(f"{current_path}/text2image/BigGAN_utils/") +import json +import base64 +from flask import Flask, request, make_response +import cv2 +# from io import BytesIO +# import torch +from mmpose.apis import init_pose_model +# from text2image.run_text2img import text2image +# from detection.detection import detector +# from segmentation.segment_pred import run_seg +# from ocr.ocr import run_tr +from backbone.backbone_infer import run_backbone_infer + +DEVICE = 'cpu' + +# model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5x', source='local', pretrained=True) +# model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5s', source='local', pretrained=True) +# model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE) +pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py' +pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth' +pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device='cpu') # or device='cuda:0' + +app=Flask(__name__) + +# @app.route('/text2image/',methods=["POST"]) +# def run_text2img(): +# if request.method == "POST": +# text = request.form.get('text') +# logger.info(f"{text}") +# img = text2image(text) +# output_buffer = BytesIO() +# img.save(output_buffer, format='png') +# byte_data = output_buffer.getvalue() +# b64_code = base64.b64encode(byte_data).decode('utf-8') +# resp 
= make_response(b64_code) +# resp.status_code = 200 +# return resp +# else: +# resp = make_response() +# resp.status_code=405 +# return resp + + +# @app.route('/detection/', methods=["POST"]) +# def run_detection(): +# if request.method == "POST": +# img = request.files.get('image') +# model_type = request.form.get('model_type') +# try: +# img = cv2.imread(img) +# except: +# resp = make_response() +# resp.status_code = 406 +# return resp +# if model_type.lower().strip() == 'yolov5x': +# rst, _ = detector(img, model_5x) +# else: +# rst, _ = detector(img, model_5s) +# logger.info(rst.shape) +# img_str = cv2.imencode('.png', rst)[1].tobytes() +# b64_code = base64.b64encode(img_str).decode('utf-8') +# resp = make_response(b64_code) +# resp.status_code = 200 +# return b64_code +# else: +# resp = make_response() +# resp.status_code=405 +# return resp + +# @app.route('/ocr/', methods=["POST"]) +# def run_ocr(): +# resp = make_response() +# if request.method == "POST": +# img = request.files.get('image') +# try: +# img = cv2.imread(img) +# except: +# resp.status_code = 406 +# return resp +# text = run_tr(img) +# resp.status_code = 200 +# resp.data = json.dumps({'result':text}) +# return resp +# else: +# resp.status_code=405 +# return resp + + +# @app.route('/segmentation/', methods=["POST"]) +# def run_segmentation(): +# if request.method == "POST": +# img_upload = request.files.get('image') +# try: +# img = cv2.imread(img_upload) +# except: +# resp = make_response() +# resp.status_code = 406 +# return resp +# result = run_seg(img, model_seg) +# img_str = cv2.imencode('.png', result)[1].tobytes() +# b64_code = base64.b64encode(img_str).decode('utf-8') +# resp = make_response(b64_code) +# resp.status_code = 200 +# return resp +# else: +# resp = make_response() +# resp.status_code=405 +# return resp + +@app.route('/backbone/', methods=["POST"]) +def run_backbone(): + if request.method == "POST": + img_upload = request.files.get('image') + try: + img = cv2.imread(img_upload) + except: + resp = make_response() + resp.status_code = 406 + return resp + pose, result = run_backbone_infer(img, pose_model) + img_str = cv2.imencode('.png', result)[1].tobytes() + b64_code = base64.b64encode(img_str).decode('utf-8') + resp = make_response(b64_code) + resp.status_code = 200 + return resp + else: + resp = make_response() + resp.status_code=405 + return resp + + +if __name__ == '__main__': + img = cv2.imread('./1.jpg') + pose, rst = run_backbone_infer(img, pose_model) + cv2.imwrite('./1_bb.jpg', rst) \ No newline at end of file diff --git a/segmentation/segment_pred.py b/segmentation/segment_pred.py new file mode 100644 index 0000000..075b9f6 --- /dev/null +++ b/segmentation/segment_pred.py @@ -0,0 +1,66 @@ +from asyncio.log import logger +import os +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import albumentations as albu +import torch +import segmentation_models_pytorch as smp +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as BaseDataset +from PIL import Image + +DEVICE = 'cpu' +ENCODER = 'se_resnext50_32x4d' +ENCODER_WEIGHTS = 'imagenet' + +# --------------------------------------------------------------- +### 加载数据 + +def get_validation_augmentation(): + """调整图像使得图片的分辨率长宽能被32整除""" + test_transform = [ + albu.PadIfNeeded(256, 256) + ] + return albu.Compose(test_transform) + + +def to_tensor(x, **kwargs): + return x.transpose(2, 0, 1).astype('float32') + + +def get_preprocessing(preprocessing_fn): + """进行图像预处理操作 + + Args: + preprocessing_fn (callbale): 
数据规范化的函数 + (针对每种预训练的神经网络) + Return: + transform: albumentations.Compose + """ + + _transform = [ + albu.Lambda(image=preprocessing_fn), + albu.Lambda(image=to_tensor), + ] + return albu.Compose(_transform) + + +def run_seg(img, best_model): + # 测试集 + preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS) + augmentator = get_validation_augmentation() + preprocessor = get_preprocessing(preprocessing_fn) + # --------------------------------------------------------------- + img = augmentator(image=img)['image'] + img = preprocessor(image=img)['image'] + # 加载最佳模型 + x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0) + pr_mask = best_model.predict(x_tensor) + pr_mask = (pr_mask.squeeze().cpu().numpy().round()) + return (pr_mask - 1) * (-220) + + +if __name__ == '__main__': + best_model = torch.load('/home/zhaojh/workspace/computer_vision/segmentation/models/best_model_pvgc.pth', map_location=DEVICE) + input_img = cv2.imread('/home/zhaojh/datasets/photovoltaic/PV03/PV03_Ground_Cropland/test/PV03_316626_1211836.bmp') \ No newline at end of file diff --git a/text2image/BigGAN_utils/BigGAN.py b/text2image/BigGAN_utils/BigGAN.py new file mode 100644 index 0000000..3367ba0 --- /dev/null +++ b/text2image/BigGAN_utils/BigGAN.py @@ -0,0 +1,484 @@ +import numpy as np +import math +import functools + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + +import layers +from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d + + +# Architectures for G +# Attention is passed in in the format '32_64' to mean applying an attention +# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64. +def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): + arch = {} + arch[512] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2, 1]], + 'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1, 1]], + 'upsample' : [True] * 7, + 'resolution' : [8, 16, 32, 64, 128, 256, 512], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,10)}} + arch[256] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2]], + 'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1]], + 'upsample' : [True] * 6, + 'resolution' : [8, 16, 32, 64, 128, 256], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,9)}} + arch[128] = {'in_channels' : [ch * item for item in [16, 16, 8, 4, 2]], + 'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]], + 'upsample' : [True] * 5, + 'resolution' : [8, 16, 32, 64, 128], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,8)}} + arch[64] = {'in_channels' : [ch * item for item in [16, 16, 8, 4]], + 'out_channels' : [ch * item for item in [16, 8, 4, 2]], + 'upsample' : [True] * 4, + 'resolution' : [8, 16, 32, 64], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,7)}} + arch[32] = {'in_channels' : [ch * item for item in [4, 4, 4]], + 'out_channels' : [ch * item for item in [4, 4, 4]], + 'upsample' : [True] * 3, + 'resolution' : [8, 16, 32], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,6)}} + + return arch + +class Generator(nn.Module): + def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128, + G_kernel_size=3, G_attn='64', n_classes=1000, + num_G_SVs=1, 
num_G_SV_itrs=1, + G_shared=True, shared_dim=0, hier=False, + cross_replica=False, mybn=False, + G_activation=nn.ReLU(inplace=False), + G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, + BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False, + G_init='ortho', skip_init=False, no_optim=False, + G_param='SN', norm_style='bn', + **kwargs): + super(Generator, self).__init__() + # Channel width mulitplier + self.ch = G_ch + # Dimensionality of the latent space + self.dim_z = dim_z + # The initial spatial dimensions + self.bottom_width = bottom_width + # Resolution of the output + self.resolution = resolution + # Kernel size? + self.kernel_size = G_kernel_size + # Attention? + self.attention = G_attn + # number of classes, for use in categorical conditional generation + self.n_classes = n_classes + # Use shared embeddings? + self.G_shared = G_shared + # Dimensionality of the shared embedding? Unused if not using G_shared + self.shared_dim = shared_dim if shared_dim > 0 else dim_z + # Hierarchical latent space? + self.hier = hier + # Cross replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # nonlinearity for residual blocks + self.activation = G_activation + # Initialization style + self.init = G_init + # Parameterization style + self.G_param = G_param + # Normalization style + self.norm_style = norm_style + # Epsilon for BatchNorm? + self.BN_eps = BN_eps + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # fp16? + self.fp16 = G_fp16 + # Architecture dict + self.arch = G_arch(self.ch, self.attention)[resolution] + + # If using hierarchical latents, adjust z + if self.hier: + # Number of places z slots into + self.num_slots = len(self.arch['in_channels']) + 1 + self.z_chunk_size = (self.dim_z // self.num_slots) + # Recalculate latent dimensionality for even splitting into chunks + self.dim_z = self.z_chunk_size * self.num_slots + else: + self.num_slots = 1 + self.z_chunk_size = 0 + + # Which convs, batchnorms, and linear layers to use + if self.G_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + self.which_embedding = nn.Embedding + bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared + else self.which_embedding) + self.which_bn = functools.partial(layers.ccbn, + which_linear=bn_linear, + cross_replica=self.cross_replica, + mybn=self.mybn, + input_size=(self.shared_dim + self.z_chunk_size if self.G_shared + else self.n_classes), + norm_style=self.norm_style, + eps=self.BN_eps) + + + # Prepare model + # If not using shared embeddings, self.shared is just a passthrough + self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared + else layers.identity()) + # First linear layer + self.linear = self.which_linear(self.dim_z // self.num_slots, + self.arch['in_channels'][0] * (self.bottom_width **2)) + + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + # while the inner loop is over a given block + self.blocks = [] + 
for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + which_bn=self.which_bn, + activation=self.activation, + upsample=(functools.partial(F.interpolate, scale_factor=2) + if self.arch['upsample'][index] else None))]] + + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print(self.arch['resolution'], self.arch['attention']) + print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] + + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + # output layer: batchnorm-relu-conv. + # Consider using a non-spectral conv here + self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], + cross_replica=self.cross_replica, + mybn=self.mybn), + self.activation, + self.which_conv(self.arch['out_channels'][-1], 3)) + + # Initialize weights. Optionally skip init for testing. + if not skip_init: + self.init_weights() + + # Set up optimizer + # If this is an EMA copy, no need for an optim, so just return now + if no_optim: + return + self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps + if G_mixed_precision: + print('Using fp16 adam in G...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for G''s initialized parameters: %d' % self.param_count) + + # Note on this forward function: we pass in a y vector which has + # already been passed through G.shared to enable easy class-wise + # interpolation later. If we passed in the one-hot and then ran it through + # G.shared in this forward function, it would be harder to handle. 
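+  # In this copy, forward additionally accepts an optional w_y (see signature below):
+  # when provided, it is softmaxed along dim 1 into mixing weights and y is collapsed
+  # to the weighted sum of the stacked class embeddings before the usual hierarchical pass.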
+ def forward(self, z, y, w_y=None): + if w_y is not None: + s_y = torch.softmax(w_y, dim=1) + + cur_y = s_y * y + y = cur_y.sum(dim=1, keepdim=False) + + # If hierarchical, concatenate zs and ys + if self.hier: + zs = torch.split(z, self.z_chunk_size, 1) + z = zs[0] + ys = [torch.cat([y, item], 1) for item in zs[1:]] + else: + ys = [y] * len(self.blocks) + + # First linear layer + h = self.linear(z) + # Reshape + h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) + + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + # Second inner loop in case block has multiple layers + for block in blocklist: + h = block(h, ys[index]) + + # Apply batchnorm-relu-conv-tanh at output + return torch.tanh(self.output_layer(h)) + + # Note on this forward function: we pass in a y vector which has + # already been passed through G.shared to enable easy class-wise + # interpolation later. If we passed in the one-hot and then ran it through + # G.shared in this forward function, it would be harder to handle. + def forward_org(self, z, y): + # If hierarchical, concatenate zs and ys + if self.hier: + zs = torch.split(z, self.z_chunk_size, 1) + z = zs[0] + ys = [torch.cat([y, item], 1) for item in zs[1:]] + else: + ys = [y] * len(self.blocks) + + # First linear layer + h = self.linear(z) + # Reshape + h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) + + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + # Second inner loop in case block has multiple layers + for block in blocklist: + h = block(h, ys[index]) + + # Apply batchnorm-relu-conv-tanh at output + return torch.tanh(self.output_layer(h)) + + +# Discriminator architecture, same paradigm as G's above +def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], + 'downsample' : [True] * 6 + [False], + 'resolution' : [128, 64, 32, 16, 8, 4, 4 ], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]], + 'downsample' : [True] * 5 + [False], + 'resolution' : [64, 32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]], + 'downsample' : [True] * 4 + [False], + 'resolution' : [32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,7)}} + arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]], + 'out_channels' : [item * ch for item in [4, 4, 4, 4]], + 'downsample' : [True, True, False, False], + 'resolution' : [16, 16, 16, 16], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,6)}} + return arch + +class Discriminator(nn.Module): + + def __init__(self, D_ch=64, D_wide=True, resolution=128, + D_kernel_size=3, D_attn='64', n_classes=1000, + num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), + D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8, + SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False, + D_init='ortho', skip_init=False, D_param='SN', **kwargs): + super(Discriminator, self).__init__() + # Width multiplier + self.ch 
= D_ch + # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? + self.D_wide = D_wide + # Resolution + self.resolution = resolution + # Kernel size + self.kernel_size = D_kernel_size + # Attention? + self.attention = D_attn + # Number of classes + self.n_classes = n_classes + # Activation + self.activation = D_activation + # Initialization style + self.init = D_init + # Parameterization style + self.D_param = D_param + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # Fp16? + self.fp16 = D_fp16 + # Architecture + self.arch = D_arch(self.ch, self.attention)[resolution] + + # Which convs, batchnorms, and linear layers to use + # No option to turn off SN in D right now + if self.D_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_embedding = functools.partial(layers.SNEmbedding, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + # Prepare model + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + wide=self.D_wide, + activation=self.activation, + preactivation=(index > 0), + downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], + self.which_conv)] + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + # Linear output layer. The output dimension is typically 1, but may be + # larger if we're e.g. 
turning this into a VAE with an inference output + self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) + # Embedding for projection discrimination + self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]) + + # Initialize weights + if not skip_init: + self.init_weights() + + # Set up optimizer + self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps + if D_mixed_precision: + print('Using fp16 adam in D...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for D''s initialized parameters: %d' % self.param_count) + + def forward(self, x, y=None): + # Stick x into h for cleaner for loops without flow control + h = x + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + # Apply global sum pooling as in SN-GAN + h = torch.sum(self.activation(h), [2, 3]) + # Get initial class-unconditional output + out = self.linear(h) + # Get projection of final featureset onto class vectors and add to evidence + out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) + return out + +# Parallelized G_D to minimize cross-gpu communication +# Without this, Generator outputs would get all-gathered and then rebroadcast. +class G_D(nn.Module): + def __init__(self, G, D): + super(G_D, self).__init__() + self.G = G + self.D = D + + def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, + split_D=False): + # If training G, enable grad tape + with torch.set_grad_enabled(train_G): + # Get Generator output given noise + G_z = self.G(z, self.G.shared(gy)) + # Cast as necessary + if self.G.fp16 and not self.D.fp16: + G_z = G_z.float() + if self.D.fp16 and not self.G.fp16: + G_z = G_z.half() + # Split_D means to run D once with real data and once with fake, + # rather than concatenating along the batch dimension. + if split_D: + D_fake = self.D(G_z, gy) + if x is not None: + D_real = self.D(x, dy) + return D_fake, D_real + else: + if return_G_z: + return D_fake, G_z + else: + return D_fake + # If real data is provided, concatenate it with the Generator's output + # along the batch dimension for improved efficiency. 
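+      # (Running D once on the concatenated batch of size G_z.shape[0] + x.shape[0]
+      # and splitting the logits back into D_fake / D_real below avoids a second
+      # forward pass through D.)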
+ else: + D_input = torch.cat([G_z, x], 0) if x is not None else G_z + D_class = torch.cat([gy, dy], 0) if dy is not None else gy + # Get Discriminator output + D_out = self.D(D_input, D_class) + if x is not None: + return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real + else: + if return_G_z: + return D_out, G_z + else: + return D_out diff --git a/text2image/BigGAN_utils/BigGANdeep.py b/text2image/BigGAN_utils/BigGANdeep.py new file mode 100644 index 0000000..95763c3 --- /dev/null +++ b/text2image/BigGAN_utils/BigGANdeep.py @@ -0,0 +1,534 @@ +import numpy as np +import math +import functools + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + +import layers +from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d + +# BigGAN-deep: uses a different resblock and pattern + + +# Architectures for G +# Attention is passed in in the format '32_64' to mean applying an attention +# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64. + +# Channel ratio is the ratio of +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv=nn.Conv2d, which_bn=layers.bn, activation=None, + upsample=None, channel_ratio=4): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.hidden_channels = self.in_channels // channel_ratio + self.which_conv, self.which_bn = which_conv, which_bn + self.activation = activation + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels, + kernel_size=1, padding=0) + self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv4 = self.which_conv(self.hidden_channels, self.out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(self.in_channels) + self.bn2 = self.which_bn(self.hidden_channels) + self.bn3 = self.which_bn(self.hidden_channels) + self.bn4 = self.which_bn(self.hidden_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + # Project down to channel ratio + h = self.conv1(self.activation(self.bn1(x, y))) + # Apply next BN-ReLU + h = self.activation(self.bn2(h, y)) + # Drop channels in x if necessary + if self.in_channels != self.out_channels: + x = x[:, :self.out_channels] + # Upsample both h and x at this point + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + # 3x3 convs + h = self.conv2(h) + h = self.conv3(self.activation(self.bn3(h, y))) + # Final 1x1 conv + h = self.conv4(self.activation(self.bn4(h, y))) + return h + x + +def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2]], + 'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1]], + 'upsample' : [True] * 6, + 'resolution' : [8, 16, 32, 64, 128, 256], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,9)}} + arch[128] = {'in_channels' : [ch * item for item in [16, 16, 8, 4, 2]], + 'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]], + 'upsample' : [True] * 5, + 'resolution' : [8, 16, 32, 64, 128], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,8)}} + arch[64] = {'in_channels' : [ch * item for item in [16, 16, 8, 4]], + 'out_channels' : [ch * item for item 
in [16, 8, 4, 2]], + 'upsample' : [True] * 4, + 'resolution' : [8, 16, 32, 64], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,7)}} + arch[32] = {'in_channels' : [ch * item for item in [4, 4, 4]], + 'out_channels' : [ch * item for item in [4, 4, 4]], + 'upsample' : [True] * 3, + 'resolution' : [8, 16, 32], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,6)}} + + return arch + +class Generator(nn.Module): + def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128, + G_kernel_size=3, G_attn='64', n_classes=1000, + num_G_SVs=1, num_G_SV_itrs=1, + G_shared=True, shared_dim=0, hier=False, + cross_replica=False, mybn=False, + G_activation=nn.ReLU(inplace=False), + G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, + BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False, + G_init='ortho', skip_init=False, no_optim=False, + G_param='SN', norm_style='bn', + **kwargs): + super(Generator, self).__init__() + # Channel width mulitplier + self.ch = G_ch + # Number of resblocks per stage + self.G_depth = G_depth + # Dimensionality of the latent space + self.dim_z = dim_z + # The initial spatial dimensions + self.bottom_width = bottom_width + # Resolution of the output + self.resolution = resolution + # Kernel size? + self.kernel_size = G_kernel_size + # Attention? + self.attention = G_attn + # number of classes, for use in categorical conditional generation + self.n_classes = n_classes + # Use shared embeddings? + self.G_shared = G_shared + # Dimensionality of the shared embedding? Unused if not using G_shared + self.shared_dim = shared_dim if shared_dim > 0 else dim_z + # Hierarchical latent space? + self.hier = hier + # Cross replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # nonlinearity for residual blocks + self.activation = G_activation + # Initialization style + self.init = G_init + # Parameterization style + self.G_param = G_param + # Normalization style + self.norm_style = norm_style + # Epsilon for BatchNorm? + self.BN_eps = BN_eps + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # fp16? 
+ self.fp16 = G_fp16 + # Architecture dict + self.arch = G_arch(self.ch, self.attention)[resolution] + + + # Which convs, batchnorms, and linear layers to use + if self.G_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + self.which_embedding = nn.Embedding + bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared + else self.which_embedding) + self.which_bn = functools.partial(layers.ccbn, + which_linear=bn_linear, + cross_replica=self.cross_replica, + mybn=self.mybn, + input_size=(self.shared_dim + self.dim_z if self.G_shared + else self.n_classes), + norm_style=self.norm_style, + eps=self.BN_eps) + + + # Prepare model + # If not using shared embeddings, self.shared is just a passthrough + self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared + else layers.identity()) + # First linear layer + self.linear = self.which_linear(self.dim_z + self.shared_dim, self.arch['in_channels'][0] * (self.bottom_width **2)) + + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + # while the inner loop is over a given block + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['in_channels'][index] if g_index==0 else self.arch['out_channels'][index], + which_conv=self.which_conv, + which_bn=self.which_bn, + activation=self.activation, + upsample=(functools.partial(F.interpolate, scale_factor=2) + if self.arch['upsample'][index] and g_index == (self.G_depth-1) else None))] + for g_index in range(self.G_depth)] + + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] + + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + # output layer: batchnorm-relu-conv. + # Consider using a non-spectral conv here + self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], + cross_replica=self.cross_replica, + mybn=self.mybn), + self.activation, + self.which_conv(self.arch['out_channels'][-1], 3)) + + # Initialize weights. Optionally skip init for testing. 
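+    # (skip_init=True is also used when loading pretrained weights; e.g. the TFHub
+    # converter later in this patch sets skip_init/no_optim in its config, since a
+    # fresh init would be overwritten by the loaded state dict anyway.)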
+ if not skip_init: + self.init_weights() + + # Set up optimizer + # If this is an EMA copy, no need for an optim, so just return now + if no_optim: + return + self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps + if G_mixed_precision: + print('Using fp16 adam in G...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for G''s initialized parameters: %d' % self.param_count) + + # Note on this forward function: we pass in a y vector which has + # already been passed through G.shared to enable easy class-wise + # interpolation later. If we passed in the one-hot and then ran it through + # G.shared in this forward function, it would be harder to handle. + # NOTE: The z vs y dichotomy here is for compatibility with not-y + def forward(self, z, y): + # If hierarchical, concatenate zs and ys + if self.hier: + z = torch.cat([y, z], 1) + y = z + # First linear layer + h = self.linear(z) + # Reshape + h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + # Second inner loop in case block has multiple layers + for block in blocklist: + h = block(h, y) + + # Apply batchnorm-relu-conv-tanh at output + return torch.tanh(self.output_layer(h)) + +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, which_conv=layers.SNConv2d, wide=True, + preactivation=True, activation=None, downsample=None, + channel_ratio=4): + super(DBlock, self).__init__() + self.in_channels, self.out_channels = in_channels, out_channels + # If using wide D (as in SA-GAN and BigGAN), change the channel pattern + self.hidden_channels = self.out_channels // channel_ratio + self.which_conv = which_conv + self.preactivation = preactivation + self.activation = activation + self.downsample = downsample + + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels, + kernel_size=1, padding=0) + self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv4 = self.which_conv(self.hidden_channels, self.out_channels, + kernel_size=1, padding=0) + + self.learnable_sc = True if (in_channels != out_channels) else False + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels - in_channels, + kernel_size=1, padding=0) + def shortcut(self, x): + if self.downsample: + x = self.downsample(x) + if self.learnable_sc: + x = torch.cat([x, self.conv_sc(x)], 1) + return x + + def forward(self, x): + # 1x1 bottleneck conv + h = self.conv1(F.relu(x)) + # 3x3 convs + h = 
self.conv2(self.activation(h)) + h = self.conv3(self.activation(h)) + # relu before downsample + h = self.activation(h) + # downsample + if self.downsample: + h = self.downsample(h) + # final 1x1 conv + h = self.conv4(h) + return h + self.shortcut(x) + +# Discriminator architecture, same paradigm as G's above +def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16]], + 'out_channels' : [item * ch for item in [2, 4, 8, 8, 16, 16]], + 'downsample' : [True] * 6 + [False], + 'resolution' : [128, 64, 32, 16, 8, 4, 4 ], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[128] = {'in_channels' : [item * ch for item in [1, 2, 4, 8, 16]], + 'out_channels' : [item * ch for item in [2, 4, 8, 16, 16]], + 'downsample' : [True] * 5 + [False], + 'resolution' : [64, 32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[64] = {'in_channels' : [item * ch for item in [1, 2, 4, 8]], + 'out_channels' : [item * ch for item in [2, 4, 8, 16]], + 'downsample' : [True] * 4 + [False], + 'resolution' : [32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,7)}} + arch[32] = {'in_channels' : [item * ch for item in [4, 4, 4]], + 'out_channels' : [item * ch for item in [4, 4, 4]], + 'downsample' : [True, True, False, False], + 'resolution' : [16, 16, 16, 16], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,6)}} + return arch + +class Discriminator(nn.Module): + + def __init__(self, D_ch=64, D_wide=True, D_depth=2, resolution=128, + D_kernel_size=3, D_attn='64', n_classes=1000, + num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), + D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8, + SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False, + D_init='ortho', skip_init=False, D_param='SN', **kwargs): + super(Discriminator, self).__init__() + # Width multiplier + self.ch = D_ch + # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? + self.D_wide = D_wide + # How many resblocks per stage? + self.D_depth = D_depth + # Resolution + self.resolution = resolution + # Kernel size + self.kernel_size = D_kernel_size + # Attention? + self.attention = D_attn + # Number of classes + self.n_classes = n_classes + # Activation + self.activation = D_activation + # Initialization style + self.init = D_init + # Parameterization style + self.D_param = D_param + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # Fp16? 
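+    # (If D runs in fp16, the G_D wrapper at the bottom of this file casts the
+    # generator output to half precision before feeding it to D.)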
+ self.fp16 = D_fp16 + # Architecture + self.arch = D_arch(self.ch, self.attention)[resolution] + + + # Which convs, batchnorms, and linear layers to use + # No option to turn off SN in D right now + if self.D_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_embedding = functools.partial(layers.SNEmbedding, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + + + # Prepare model + # Stem convolution + self.input_conv = self.which_conv(3, self.arch['in_channels'][0]) + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[DBlock(in_channels=self.arch['in_channels'][index] if d_index==0 else self.arch['out_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + wide=self.D_wide, + activation=self.activation, + preactivation=True, + downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] and d_index==0 else None)) + for d_index in range(self.D_depth)]] + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], + self.which_conv)] + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + # Linear output layer. The output dimension is typically 1, but may be + # larger if we're e.g. 
turning this into a VAE with an inference output + self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) + # Embedding for projection discrimination + self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]) + + # Initialize weights + if not skip_init: + self.init_weights() + + # Set up optimizer + self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps + if D_mixed_precision: + print('Using fp16 adam in D...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for D''s initialized parameters: %d' % self.param_count) + + def forward(self, x, y=None): + # Run input conv + h = self.input_conv(x) + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + # Apply global sum pooling as in SN-GAN + h = torch.sum(self.activation(h), [2, 3]) + # Get initial class-unconditional output + out = self.linear(h) + # Get projection of final featureset onto class vectors and add to evidence + out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) + return out + +# Parallelized G_D to minimize cross-gpu communication +# Without this, Generator outputs would get all-gathered and then rebroadcast. +class G_D(nn.Module): + def __init__(self, G, D): + super(G_D, self).__init__() + self.G = G + self.D = D + + def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, + split_D=False): + # If training G, enable grad tape + with torch.set_grad_enabled(train_G): + # Get Generator output given noise + G_z = self.G(z, self.G.shared(gy)) + # Cast as necessary + if self.G.fp16 and not self.D.fp16: + G_z = G_z.float() + if self.D.fp16 and not self.G.fp16: + G_z = G_z.half() + # Split_D means to run D once with real data and once with fake, + # rather than concatenating along the batch dimension. + if split_D: + D_fake = self.D(G_z, gy) + if x is not None: + D_real = self.D(x, dy) + return D_fake, D_real + else: + if return_G_z: + return D_fake, G_z + else: + return D_fake + # If real data is provided, concatenate it with the Generator's output + # along the batch dimension for improved efficiency. 
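+      # (This branch, like the rest of this G_D wrapper, matches the G_D in
+      # BigGAN.py; torch.split below recovers D_fake and D_real from the
+      # combined logits.)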
+ else: + D_input = torch.cat([G_z, x], 0) if x is not None else G_z + D_class = torch.cat([gy, dy], 0) if dy is not None else gy + # Get Discriminator output + D_out = self.D(D_input, D_class) + if x is not None: + return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real + else: + if return_G_z: + return D_out, G_z + else: + return D_out diff --git a/text2image/BigGAN_utils/LICENSE b/text2image/BigGAN_utils/LICENSE new file mode 100644 index 0000000..c2dd721 --- /dev/null +++ b/text2image/BigGAN_utils/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Andy Brock + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/text2image/BigGAN_utils/README.md b/text2image/BigGAN_utils/README.md new file mode 100644 index 0000000..3830c36 --- /dev/null +++ b/text2image/BigGAN_utils/README.md @@ -0,0 +1 @@ +Download pre-trained weights from (https://drive.google.com/drive/folders/1nJ3HmgYgeA9NZr-oU-enqbYeO7zBaANs?usp=sharing) and put them in `./weights/` diff --git a/text2image/BigGAN_utils/TFHub/__pycache__/biggan_v1.cpython-38.pyc b/text2image/BigGAN_utils/TFHub/__pycache__/biggan_v1.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1101bce69984299c3a2951e8e48a921eaac61dab GIT binary patch literal 11044 zcmeHNTaX;rS?>Gv^vuq#R;!gP$+oa%@Z=&{N~};|Vq0=nDNq*URj{*|bY`b}XJ<7x zpYB;n%WNR59CK5OP~nnLz$_F6;$Bpdf>6X&T=K{hPyIxN2eZWs51}Xmh55e!%yqSL z8NwU%RR4X>e=gm9&i{Y^rRU=_Gl7BQ&CkD4`P_NK_y<-d4+oX!a3=o@5*d-%Fjh^u zTB{bWR>N*OtB%PwyJ0omRaeSR!)xYNb5eFu_E+;LKN964|BksDM2|*6R6wZ^6{DFu z#%l4Z5tX9ZJ4Q4cyPGq+_G$^7x#$QuM+Tf(aOR_<;2a%r=D=Bqj)8M*z&QfWVsspw z;{(n-I47cqzN>0!b+pQ9T$0Lc_dTKp!^n29Do08 zm76>~JQR4!b6!HBXGg{jUlk*x6{N{~Cw`cX- z1-wn!S>FH8x8HtypEoR9nUmlNvZAWAqINS((n=baT`XiL^E59jyd-TaY*HMR*=j2( zds(nnNvj*7OyF*ZQN6hu$Zk})-k3R#qs)~lGjFTX=*EdGnkIbfY<|!`j>7FQsWdx{ zm|KwWMGfC{<0?2L~U43nL400)-x=|2QQ6l(A#>%s%8jGWRU~svM{|rD6&D z=v5oIO^7!)Z)JA7n;v{5FN(P{i|C%n`kO^leH&_3oKweecajf-9gF%w^=95M)=fNu zb0@!QM{eXni5yWcW7UoFQGmJ^LA`Lzy@5Ab_3@~sthmyNtEs9qmT~vYyBXJO8>xC2 zQ%{KjZ)y+D`=>e?Z&WH zcDU`0b@lC-Qhf)>lOXq?lwpXKPm(YkZ@+p1ZQS34!eE#s)03ZXZXcUmjdR0Co00<` zL--yd2Y-Y@3Q5?+`)^v5hh1yOoZwi(u_idSaO|GFNriD9GIp?HeHRMj?pZr->S3nb zuK9xT!%yybsh{TiLC@_M2n>7HuDN3kc|A+&=4-D&`ikf?1C9w*%%!E_6JalM{aGsL zDPz~xyHNI)Gb`-k+{&DAt)fIm_o19`m*=uvEr!OWYP^)06IId*Hf^KQ!R~C;$buc2n-X%chn2M1ZDeMe*-<;qtmamhOZ1FcKEBm~PKTAG?5p=;=U8WjCidvYV#yvJp8eP^#3v5vZ3;$h&EL!id}z;}wy7 zthqXe`4w!cp`HS{zEpP9Mb<_#o0Xo&!q?ZjY5WnT+At%sN(UA}l`)6|lM;wwTKe@N z!6kG1q3PG2P?)Lrru@1%<-Pw7h4EWcetV|KrU}`!AX@g8bgKU|JtwqLcBSm1EMp+slznqEpe}unXjQ+!a^xL1dd{8& 
zoT0(dyc~GYm?@qQo<|#y`pCRa~nlD2|!soZiZ3ImvUJ#JyRM(hWrn7adr2 zZ;CF`yiY*YdrXk(MNmNp(rww?REofr-N%| z)-tahR?$;bSbc~i-R)SllgzDEn$1drgV@Dzkcb5j)QPOS4zL2gEx7O+$VYfcPrwNML%Dgu z6=!*=Y-e6&eZAg_6)`;Xn+SYcX}5XLR4;>34#r@S6~(*74vsd6=j z2g|&(XvkMl7;8vsY3+a}*o5|8PCXWxEkHLk#S^V?2Yu)R4Ern{=;K9z5{-87Q`q8mSvuP+p~o_9t^<=*m|8g@Jt8p`1Y(8 zjKfu%h)FrUf-|8!^hBvQ4Yeq$jqC~Hf2z4osG+oJs}H8|J&|&wWaw5;(rs*oK|${v z@739UGMwqjDPG~#d_?iqsC8NI-7piOV4ta1}ox`nXjopysMz0a`2PmugL z3AX_r-c~K7!v@>4OD@-cfW?_xZ*{t9mWTXkTP3iB#1BO4Fhu%oj7{DLV!&Q{43{KE zCS^Gw52rKYmMq8IE)Di@h*8MG`UjzjFQZx{PmmD)9&EZljaI%(!qlGcdI;F{&(YU! zl+G&lsp6bEC@egtO}U0A=J%NABTOOBH!|fxR=gU+R93JP7cPG1KCylOA8n9Np|mt* z!f9&=QfM`j$Dq%X0JRCBqR`Py0}h;?qs~Uo4l;{kU$il|U}X>pki%3I`T>k01mQ?y zS$(9gbnn3!WP+poU1!AAmTJtEdIUj*{Tg$n9u*F7r5+V0xiwe96J$E^96`z?L4Mba zibx3n-qIO8hdgB|wKiw<=x8P?-6i17!40tx*vOpGEPpT$#b|C2YUx1d$nyS=VY!wB z3DpXU`&5Gc1rp?ZjQ!In-LuqPTr{ezv&2psH@WI?UBr#}QChHk&wZnPEh^DxSSv|u3d${2{RmD$yeR^~?WTDSH<;QSgo zir~T%vgV;!julvomXFfn@b~Y&PZ(v+Zu{7z*d7j`Tr#fS6;Az^AgM${*k?#8A{xjk zA|d)jgfU1{WIbevhvW`o7DAr4=k#-sYZuauG9(ppDgbM|j{fKo>_V22e*LajtsM_% z4sjF{Q-vFr`m_KfBn{&XBwTYn2U0?K%)^+vy*$!q?rWb!uM%P+_L}X_=@NKH20Wub z&l0}zt1iZxu~1K7DEbXE5W$GfknXi?$J(p8a1!)5B0S^6w&f zIzZL|h7NcdK@a0W!ilPbeo`AE?SMChv*gvh0iGh7pTN_?2v23i1fFV4K8UB9JH*q% zeR!&V2@5xYr|OGbkmo(gk-FWvXIxnn}gK9 zdv)jHjn)0nIp@1k_phMiH*gNs{gfi7igW7tR@MFK3l~K)C!T3a-D!wQIK_GT-v6n( z!;Pb2PB>5QJ2+3dL!Bqhol<=`T9aJF0t1&ybEO^)AYs&Z3KT|M4GTiWDZ>o4+yEFD zJ&ql%U<3>@{}33EIhN5fX6R(;Sqq4+3oxHSRGf0MW|j{J1>%AF6_RN*IH3OOci80+ zMSq;~`5rXBgHJ-=2D?7d;j{Ooy#Q!V1Hn~*x=l@&E&cz!mJ0X8Uu^w#O{{kHu znGBWv5k=0sRCZqyIs8z=NC|PuRHT&8W?rZ#RaI}+Ta1<_o~Jql58O*Wlg{EQihgHG z(Ger?bC3kZ>NKJwrfQ5((l4ot6pzP*QqNkzhhX_&f{!*x&E)osL91@NzlR?opN~u) z5;g4EAJ60M>7;2ODIRAqyPm;Bbj4(=uK99b;iZL0YEVf>mj-k=LiY{EBhXLR@C^V@ zKkTP#hw%zi{d7&!BA;rHNhf?t!c#L2;g{70DimOPm&NpSTz+dcq-{9LOJoo$9g zf9ilMB~u*&v)8#ymoNiiHn1HN85V?n$d=)YnExN1cLFtfRe|=cH2qhvxM=Z@9NUYB zFYVNRbFX<`!I?0Lgy?FRCubaJ!gIE{|NOG`A!ffNB1A-HqXjt(mhiV5m|sSNFiyOJ z8#xk5+WcFC^x^j~@cRKUS=oVrrw-ff~}76k@9Qc5C1{H_)LA7JcWVKYf-|mlM;#{FtfIgO)vTotL+I`d@xLZ6_c)TdIBf%@Mje{ z&@FmN_}2g(rRDNs<|01r+*U_X)nRQRsn$EU&+)%{iTuY;f3xxB&qU{1E&Z`poPAle z8IELzMgAF=c{=r>6+y(CIb$7@G6)pcP)vqA(8Q{YB7n?)M*r>L!+iGS42W;~ND#`; Z4xGUA-T6!NkIm1`U!Ffde{TN5e*v_mL?QqH literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/TFHub/biggan_v1.py b/text2image/BigGAN_utils/TFHub/biggan_v1.py new file mode 100644 index 0000000..56e99b8 --- /dev/null +++ b/text2image/BigGAN_utils/TFHub/biggan_v1.py @@ -0,0 +1,389 @@ +# BigGAN V1: +# This is now deprecated code used for porting the TFHub modules to pytorch, +# included here for reference only. 
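+# (converter.py later in this patch imports this module as `biggan_for_conversion`
+# and instantiates Generator128/256/512 so that the TFHub variables can be copied
+# in name-by-name before convert_from_v1() remaps them onto the main
+# BigGAN.Generator.)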
+import numpy as np +import torch +from scipy.stats import truncnorm +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F + + +def l2normalize(v, eps=1e-4): + return v / (v.norm() + eps) + + +def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None): + state = None if seed is None else np.random.RandomState(seed) + values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state) + return truncation * values + + +def denorm(x): + out = (x + 1) / 2 + return out.clamp_(0, 1) + + +class SpectralNorm(nn.Module): + def __init__(self, module, name='weight', power_iterations=1): + super(SpectralNorm, self).__init__() + self.module = module + self.name = name + self.power_iterations = power_iterations + if not self._made_params(): + self._make_params() + + def _update_u_v(self): + u = getattr(self.module, self.name + "_u") + v = getattr(self.module, self.name + "_v") + w = getattr(self.module, self.name + "_bar") + + height = w.data.shape[0] + _w = w.view(height, -1) + for _ in range(self.power_iterations): + v = l2normalize(torch.matmul(_w.t(), u)) + u = l2normalize(torch.matmul(_w, v)) + + sigma = u.dot((_w).mv(v)) + setattr(self.module, self.name, w / sigma.expand_as(w)) + + def _made_params(self): + try: + getattr(self.module, self.name + "_u") + getattr(self.module, self.name + "_v") + getattr(self.module, self.name + "_bar") + return True + except AttributeError: + return False + + def _make_params(self): + w = getattr(self.module, self.name) + + height = w.data.shape[0] + width = w.view(height, -1).data.shape[1] + + u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) + v = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) + u.data = l2normalize(u.data) + v.data = l2normalize(v.data) + w_bar = Parameter(w.data) + + del self.module._parameters[self.name] + self.module.register_parameter(self.name + "_u", u) + self.module.register_parameter(self.name + "_v", v) + self.module.register_parameter(self.name + "_bar", w_bar) + + def forward(self, *args): + self._update_u_v() + return self.module.forward(*args) + + +class SelfAttention(nn.Module): + """ Self Attention Layer""" + + def __init__(self, in_dim, activation=F.relu): + super().__init__() + self.chanel_in = in_dim + self.activation = activation + + self.theta = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False)) + self.phi = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False)) + self.pool = nn.MaxPool2d(2, 2) + self.g = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1, bias=False)) + self.o_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim // 2, out_channels=in_dim, kernel_size=1, bias=False)) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + m_batchsize, C, width, height = x.size() + N = height * width + + theta = self.theta(x) + phi = self.phi(x) + phi = self.pool(phi) + phi = phi.view(m_batchsize, -1, N // 4) + theta = theta.view(m_batchsize, -1, N) + theta = theta.permute(0, 2, 1) + attention = self.softmax(torch.bmm(theta, phi)) + g = self.pool(self.g(x)).view(m_batchsize, -1, N // 4) + attn_g = torch.bmm(g, attention.permute(0, 2, 1)).view(m_batchsize, -1, width, height) + out = self.o_conv(attn_g) + return self.gamma * out + x + + +class ConditionalBatchNorm2d(nn.Module): + def __init__(self, num_features, num_classes, eps=1e-4, momentum=0.1): + super().__init__() 
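+        # Class-conditional BN: the BatchNorm itself is affine=False; per-channel
+        # gamma/beta are predicted from the condition vector y by the
+        # spectrally-normalised linear layers below (gamma is offset by +1 in forward).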
+ self.num_features = num_features + self.bn = nn.BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum) + self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False)) + self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False)) + + def forward(self, x, y): + out = self.bn(x) + gamma = self.gamma_embed(y) + 1 + beta = self.beta_embed(y) + out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) + return out + + +class GBlock(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size=[3, 3], + padding=1, + stride=1, + n_class=None, + bn=True, + activation=F.relu, + upsample=True, + downsample=False, + z_dim=148, + ): + super().__init__() + + self.conv0 = SpectralNorm( + nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=True if bn else True) + ) + self.conv1 = SpectralNorm( + nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding, bias=True if bn else True) + ) + + self.skip_proj = False + if in_channel != out_channel or upsample or downsample: + self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0)) + self.skip_proj = True + + self.upsample = upsample + self.downsample = downsample + self.activation = activation + self.bn = bn + if bn: + self.HyperBN = ConditionalBatchNorm2d(in_channel, z_dim) + self.HyperBN_1 = ConditionalBatchNorm2d(out_channel, z_dim) + + def forward(self, input, condition=None): + out = input + + if self.bn: + out = self.HyperBN(out, condition) + out = self.activation(out) + if self.upsample: + out = F.interpolate(out, scale_factor=2) + out = self.conv0(out) + if self.bn: + out = self.HyperBN_1(out, condition) + out = self.activation(out) + out = self.conv1(out) + + if self.downsample: + out = F.avg_pool2d(out, 2) + + if self.skip_proj: + skip = input + if self.upsample: + skip = F.interpolate(skip, scale_factor=2) + skip = self.conv_sc(skip) + if self.downsample: + skip = F.avg_pool2d(skip, 2) + else: + skip = input + return out + skip + + +class Generator128(nn.Module): + def __init__(self, code_dim=120, n_class=1000, chn=96, debug=False): + super().__init__() + + self.linear = nn.Linear(n_class, 128, bias=False) + + if debug: + chn = 8 + + self.first_view = 16 * chn + + self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn)) + + z_dim = code_dim + 28 + + self.GBlock = nn.ModuleList([ + GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim), + GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), + GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim), + GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim), + GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), + ]) + + self.sa_id = 4 + self.num_split = len(self.GBlock) + 1 + self.attention = SelfAttention(2 * chn) + self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4) + self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) + + def forward(self, input, class_id): + codes = torch.chunk(input, self.num_split, 1) + class_emb = self.linear(class_id) # 128 + + out = self.G_linear(codes[0]) + out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) + for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): + if i == self.sa_id: + out = self.attention(out) + condition = torch.cat([code, class_emb], 1) + out = GBlock(out, condition) + + out = self.ScaledCrossReplicaBN(out) + out = F.relu(out) + out = self.colorize(out) + return torch.tanh(out) + + +class Generator256(nn.Module): + def __init__(self, code_dim=140, 
n_class=1000, chn=96, debug=False): + super().__init__() + + self.linear = nn.Linear(n_class, 128, bias=False) + + if debug: + chn = 8 + + self.first_view = 16 * chn + + self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn)) + + self.GBlock = nn.ModuleList([ + GBlock(16 * chn, 16 * chn, n_class=n_class), + GBlock(16 * chn, 8 * chn, n_class=n_class), + GBlock(8 * chn, 8 * chn, n_class=n_class), + GBlock(8 * chn, 4 * chn, n_class=n_class), + GBlock(4 * chn, 2 * chn, n_class=n_class), + GBlock(2 * chn, 1 * chn, n_class=n_class), + ]) + + self.sa_id = 5 + self.num_split = len(self.GBlock) + 1 + self.attention = SelfAttention(2 * chn) + self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4) + self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) + + def forward(self, input, class_id): + codes = torch.chunk(input, self.num_split, 1) + class_emb = self.linear(class_id) # 128 + + out = self.G_linear(codes[0]) + out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) + for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): + if i == self.sa_id: + out = self.attention(out) + condition = torch.cat([code, class_emb], 1) + out = GBlock(out, condition) + + out = self.ScaledCrossReplicaBN(out) + out = F.relu(out) + out = self.colorize(out) + return torch.tanh(out) + + +class Generator512(nn.Module): + def __init__(self, code_dim=128, n_class=1000, chn=96, debug=False): + super().__init__() + + self.linear = nn.Linear(n_class, 128, bias=False) + + if debug: + chn = 8 + + self.first_view = 16 * chn + + self.G_linear = SpectralNorm(nn.Linear(16, 4 * 4 * 16 * chn)) + + z_dim = code_dim + 16 + + self.GBlock = nn.ModuleList([ + GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim), + GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), + GBlock(8 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), + GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim), + GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim), + GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), + GBlock(1 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), + ]) + + self.sa_id = 4 + self.num_split = len(self.GBlock) + 1 + self.attention = SelfAttention(4 * chn) + self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn) + self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) + + def forward(self, input, class_id): + codes = torch.chunk(input, self.num_split, 1) + class_emb = self.linear(class_id) # 128 + + out = self.G_linear(codes[0]) + out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) + for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): + if i == self.sa_id: + out = self.attention(out) + condition = torch.cat([code, class_emb], 1) + out = GBlock(out, condition) + + out = self.ScaledCrossReplicaBN(out) + out = F.relu(out) + out = self.colorize(out) + return torch.tanh(out) + + +class Discriminator(nn.Module): + def __init__(self, n_class=1000, chn=96, debug=False): + super().__init__() + + def conv(in_channel, out_channel, downsample=True): + return GBlock(in_channel, out_channel, bn=False, upsample=False, downsample=downsample) + + if debug: + chn = 8 + self.debug = debug + + self.pre_conv = nn.Sequential( + SpectralNorm(nn.Conv2d(3, 1 * chn, 3, padding=1)), + nn.ReLU(), + SpectralNorm(nn.Conv2d(1 * chn, 1 * chn, 3, padding=1)), + nn.AvgPool2d(2), + ) + self.pre_skip = SpectralNorm(nn.Conv2d(3, 1 * chn, 1)) + + self.conv = nn.Sequential( + conv(1 * chn, 1 * chn, downsample=True), + conv(1 * chn, 2 * chn, downsample=True), + SelfAttention(2 * chn), + conv(2 
* chn, 2 * chn, downsample=True), + conv(2 * chn, 4 * chn, downsample=True), + conv(4 * chn, 8 * chn, downsample=True), + conv(8 * chn, 8 * chn, downsample=True), + conv(8 * chn, 16 * chn, downsample=True), + conv(16 * chn, 16 * chn, downsample=False), + ) + + self.linear = SpectralNorm(nn.Linear(16 * chn, 1)) + + self.embed = nn.Embedding(n_class, 16 * chn) + self.embed.weight.data.uniform_(-0.1, 0.1) + self.embed = SpectralNorm(self.embed) + + def forward(self, input, class_id): + + out = self.pre_conv(input) + out += self.pre_skip(F.avg_pool2d(input, 2)) + out = self.conv(out) + out = F.relu(out) + out = out.view(out.size(0), out.size(1), -1) + out = out.sum(2) + out_linear = self.linear(out).squeeze(1) + embed = self.embed(class_id) + + prod = (out * embed).sum(1) + + return out_linear + prod diff --git a/text2image/BigGAN_utils/TFHub/converter.py b/text2image/BigGAN_utils/TFHub/converter.py new file mode 100644 index 0000000..a0aa322 --- /dev/null +++ b/text2image/BigGAN_utils/TFHub/converter.py @@ -0,0 +1,396 @@ +"""Utilities for converting TFHub BigGAN generator weights to PyTorch. +Recommended usage: +To convert all BigGAN variants and generate test samples, use: +```bash +CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples +``` +See `parse_args` for additional options. +""" + +import argparse +import os +import sys + +import h5py +import torch +import torch.nn as nn +from torchvision.utils import save_image +import tensorflow as tf +import tensorflow_hub as hub +import parse + +# import reference biggan from this folder +import biggan_v1 as biggan_for_conversion + +# Import model from main folder +sys.path.append('..') +import BigGAN + + + + +DEVICE = 'cuda' +HDF5_TMPL = 'biggan-{}.h5' +PTH_TMPL = 'biggan-{}.pth' +MODULE_PATH_TMPL = 'https://tfhub.dev/deepmind/biggan-{}/2' +Z_DIMS = { + 128: 120, + 256: 140, + 512: 128} +RESOLUTIONS = list(Z_DIMS) + + +def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False): + """Loads TFHub weights and saves them to intermediate HDF5 file. + Args: + module_path ([Path-like]): Path to TFHub module. + hdf5_path ([Path-like]): Path to output HDF5 file. + Returns: + [h5py.File]: Loaded hdf5 file containing module weights. 
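+  Example (a sketch; assumes the TFHub module is reachable and a TF1-style
+  session is available):
+    h5f = dump_tfhub_to_hdf5(MODULE_PATH_TMPL.format(128),
+                             os.path.join('pretrained_weights', HDF5_TMPL.format(128)))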
+ """ + if os.path.exists(hdf5_path) and (not redownload): + print('Loading BigGAN hdf5 file from:', hdf5_path) + return h5py.File(hdf5_path, 'r') + + print('Loading BigGAN module from:', module_path) + tf.reset_default_graph() + hub.Module(module_path) + print('Loaded BigGAN module from:', module_path) + + initializer = tf.global_variables_initializer() + sess = tf.Session() + sess.run(initializer) + + print('Saving BigGAN weights to :', hdf5_path) + h5f = h5py.File(hdf5_path, 'w') + for var in tf.global_variables(): + val = sess.run(var) + h5f.create_dataset(var.name, data=val) + print(f'Saving {var.name} with shape {val.shape}') + h5f.close() + return h5py.File(hdf5_path, 'r') + + +class TFHub2Pytorch(object): + + TF_ROOT = 'module' + + NUM_GBLOCK = { + 128: 5, + 256: 6, + 512: 7 + } + + w = 'w' + b = 'b' + u = 'u0' + v = 'u1' + gamma = 'gamma' + beta = 'beta' + + def __init__(self, state_dict, tf_weights, resolution=256, load_ema=True, verbose=False): + self.state_dict = state_dict + self.tf_weights = tf_weights + self.resolution = resolution + self.verbose = verbose + if load_ema: + for name in ['w', 'b', 'gamma', 'beta']: + setattr(self, name, getattr(self, name) + '/ema_b999900') + + def load(self): + self.load_generator() + return self.state_dict + + def load_generator(self): + GENERATOR_ROOT = os.path.join(self.TF_ROOT, 'Generator') + + for i in range(self.NUM_GBLOCK[self.resolution]): + name_tf = os.path.join(GENERATOR_ROOT, 'GBlock') + name_tf += f'_{i}' if i != 0 else '' + self.load_GBlock(f'GBlock.{i}.', name_tf) + + self.load_attention('attention.', os.path.join(GENERATOR_ROOT, 'attention')) + self.load_linear('linear', os.path.join(self.TF_ROOT, 'linear'), bias=False) + self.load_snlinear('G_linear', os.path.join(GENERATOR_ROOT, 'G_Z', 'G_linear')) + self.load_colorize('colorize', os.path.join(GENERATOR_ROOT, 'conv_2d')) + self.load_ScaledCrossReplicaBNs('ScaledCrossReplicaBN', + os.path.join(GENERATOR_ROOT, 'ScaledCrossReplicaBN')) + + def load_linear(self, name_pth, name_tf, bias=True): + self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0) + if bias: + self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b) + + def load_snlinear(self, name_pth, name_tf, bias=True): + self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() + self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() + self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0) + if bias: + self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b) + + def load_colorize(self, name_pth, name_tf): + self.load_snconv(name_pth, name_tf) + + def load_GBlock(self, name_pth, name_tf): + self.load_convs(name_pth, name_tf) + self.load_HyperBNs(name_pth, name_tf) + + def load_convs(self, name_pth, name_tf): + self.load_snconv(name_pth + 'conv0', os.path.join(name_tf, 'conv0')) + self.load_snconv(name_pth + 'conv1', os.path.join(name_tf, 'conv1')) + self.load_snconv(name_pth + 'conv_sc', os.path.join(name_tf, 'conv_sc')) + + def load_snconv(self, name_pth, name_tf, bias=True): + if self.verbose: + print(f'loading: {name_pth} from {name_tf}') + self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() + self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() + self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, 
self.w).permute(3, 2, 0, 1) + if bias: + self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b).squeeze() + + def load_conv(self, name_pth, name_tf, bias=True): + + self.state_dict[name_pth + '.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() + self.state_dict[name_pth + '.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() + self.state_dict[name_pth + '.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1) + if bias: + self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b) + + def load_HyperBNs(self, name_pth, name_tf): + self.load_HyperBN(name_pth + 'HyperBN', os.path.join(name_tf, 'HyperBN')) + self.load_HyperBN(name_pth + 'HyperBN_1', os.path.join(name_tf, 'HyperBN_1')) + + def load_ScaledCrossReplicaBNs(self, name_pth, name_tf): + self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.beta).squeeze() + self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.gamma).squeeze() + self.state_dict[name_pth + '.running_mean'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_mean') + self.state_dict[name_pth + '.running_var'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_var') + self.state_dict[name_pth + '.num_batches_tracked'] = torch.tensor( + self.tf_weights[os.path.join(name_tf + 'bn', 'accumulation_counter:0')][()], dtype=torch.float32) + + def load_HyperBN(self, name_pth, name_tf): + if self.verbose: + print(f'loading: {name_pth} from {name_tf}') + beta = name_pth + '.beta_embed.module' + gamma = name_pth + '.gamma_embed.module' + self.state_dict[beta + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.u).squeeze() + self.state_dict[gamma + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.u).squeeze() + self.state_dict[beta + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.v).squeeze() + self.state_dict[gamma + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.v).squeeze() + self.state_dict[beta + '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.w).permute(1, 0) + self.state_dict[gamma + + '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.w).permute(1, 0) + + cr_bn_name = name_tf.replace('HyperBN', 'CrossReplicaBN') + self.state_dict[name_pth + '.bn.running_mean'] = self.load_tf_tensor(cr_bn_name, 'accumulated_mean') + self.state_dict[name_pth + '.bn.running_var'] = self.load_tf_tensor(cr_bn_name, 'accumulated_var') + self.state_dict[name_pth + '.bn.num_batches_tracked'] = torch.tensor( + self.tf_weights[os.path.join(cr_bn_name, 'accumulation_counter:0')][()], dtype=torch.float32) + + def load_attention(self, name_pth, name_tf): + + self.load_snconv(name_pth + 'theta', os.path.join(name_tf, 'theta'), bias=False) + self.load_snconv(name_pth + 'phi', os.path.join(name_tf, 'phi'), bias=False) + self.load_snconv(name_pth + 'g', os.path.join(name_tf, 'g'), bias=False) + self.load_snconv(name_pth + 'o_conv', os.path.join(name_tf, 'o_conv'), bias=False) + self.state_dict[name_pth + 'gamma'] = self.load_tf_tensor(name_tf, self.gamma) + + def load_tf_tensor(self, prefix, var, device='0'): + name = os.path.join(prefix, var) + f':{device}' + return torch.from_numpy(self.tf_weights[name][:]) + +# Convert from v1: This function maps +def convert_from_v1(hub_dict, resolution=128): + weightname_dict = {'weight_u': 'u0', 'weight_bar': 'weight', 'bias': 'bias'} + convnum_dict = {'conv0': 'conv1', 'conv1': 'conv2', 'conv_sc': 'conv_sc'} + 
attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution] + hub2me = {'linear.weight': 'shared.weight', # This is actually the shared weight + # Linear stuff + 'G_linear.module.weight_bar': 'linear.weight', + 'G_linear.module.bias': 'linear.bias', + 'G_linear.module.weight_u': 'linear.u0', + # output layer stuff + 'ScaledCrossReplicaBN.weight': 'output_layer.0.gain', + 'ScaledCrossReplicaBN.bias': 'output_layer.0.bias', + 'ScaledCrossReplicaBN.running_mean': 'output_layer.0.stored_mean', + 'ScaledCrossReplicaBN.running_var': 'output_layer.0.stored_var', + 'colorize.module.weight_bar': 'output_layer.2.weight', + 'colorize.module.bias': 'output_layer.2.bias', + 'colorize.module.weight_u': 'output_layer.2.u0', + # Attention stuff + 'attention.gamma': 'blocks.%d.1.gamma' % attention_blocknum, + 'attention.theta.module.weight_u': 'blocks.%d.1.theta.u0' % attention_blocknum, + 'attention.theta.module.weight_bar': 'blocks.%d.1.theta.weight' % attention_blocknum, + 'attention.phi.module.weight_u': 'blocks.%d.1.phi.u0' % attention_blocknum, + 'attention.phi.module.weight_bar': 'blocks.%d.1.phi.weight' % attention_blocknum, + 'attention.g.module.weight_u': 'blocks.%d.1.g.u0' % attention_blocknum, + 'attention.g.module.weight_bar': 'blocks.%d.1.g.weight' % attention_blocknum, + 'attention.o_conv.module.weight_u': 'blocks.%d.1.o.u0' % attention_blocknum, + 'attention.o_conv.module.weight_bar':'blocks.%d.1.o.weight' % attention_blocknum, + } + + # Loop over the hub dict and build the hub2me map + for name in hub_dict.keys(): + if 'GBlock' in name: + if 'HyperBN' not in name: # it's a conv + out = parse.parse('GBlock.{:d}.{}.module.{}',name) + blocknum, convnum, weightname = out + if weightname not in weightname_dict: + continue # else hyperBN in + out_name = 'blocks.%d.0.%s.%s' % (blocknum, convnum_dict[convnum], weightname_dict[weightname]) # Increment conv number by 1 + else: # hyperbn not conv + BNnum = 2 if 'HyperBN_1' in name else 1 + if 'embed' in name: + out = parse.parse('GBlock.{:d}.{}.module.{}',name) + blocknum, gamma_or_beta, weightname = out + if weightname not in weightname_dict: # Ignore weight_v + continue + out_name = 'blocks.%d.0.bn%d.%s.%s' % (blocknum, BNnum, 'gain' if 'gamma' in gamma_or_beta else 'bias', weightname_dict[weightname]) + else: + out = parse.parse('GBlock.{:d}.{}.bn.{}',name) + blocknum, dummy, mean_or_var = out + if 'num_batches_tracked' in mean_or_var: + continue + out_name = 'blocks.%d.0.bn%d.%s' % (blocknum, BNnum, 'stored_mean' if 'mean' in mean_or_var else 'stored_var') + hub2me[name] = out_name + + + # Invert the hub2me map + me2hub = {hub2me[item]: item for item in hub2me} + new_dict = {} + dimz_dict = {128: 20, 256: 20, 512:16} + for item in me2hub: + # Swap input dim ordering on batchnorm bois to account for my arbitrary change of ordering when concatenating Ys and Zs + if ('bn' in item and 'weight' in item) and ('gain' in item or 'bias' in item) and ('output_layer' not in item): + new_dict[item] = torch.cat([hub_dict[me2hub[item]][:, -128:], hub_dict[me2hub[item]][:, :dimz_dict[resolution]]], 1) + # Reshape the first linear weight, bias, and u0 + elif item == 'linear.weight': + new_dict[item] = hub_dict[me2hub[item]].contiguous().view(4, 4, 96 * 16, -1).permute(2,0,1,3).contiguous().view(-1,dimz_dict[resolution]) + elif item == 'linear.bias': + new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 16).permute(2,0,1).contiguous().view(-1) + elif item == 'linear.u0': + new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 
16).permute(2,0,1).contiguous().view(1, -1) + elif me2hub[item] == 'linear.weight': # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER + # Transpose shared weight so that it's an embedding + new_dict[item] = hub_dict[me2hub[item]].t() + elif 'weight_u' in me2hub[item]: # Unsqueeze u0s + new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0) + else: + new_dict[item] = hub_dict[me2hub[item]] + return new_dict + +def get_config(resolution): + attn_dict = {128: '64', 256: '128', 512: '64'} + dim_z_dict = {128: 120, 256: 140, 512: 128} + config = {'G_param': 'SN', 'D_param': 'SN', + 'G_ch': 96, 'D_ch': 96, + 'D_wide': True, 'G_shared': True, + 'shared_dim': 128, 'dim_z': dim_z_dict[resolution], + 'hier': True, 'cross_replica': False, + 'mybn': False, 'G_activation': nn.ReLU(inplace=True), + 'G_attn': attn_dict[resolution], + 'norm_style': 'bn', + 'G_init': 'ortho', 'skip_init': True, 'no_optim': True, + 'G_fp16': False, 'G_mixed_precision': False, + 'accumulate_stats': False, 'num_standing_accumulations': 16, + 'G_eval_mode': True, + 'BN_eps': 1e-04, 'SN_eps': 1e-04, + 'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution, + 'n_classes': 1000} + return config + + +def convert_biggan(resolution, weight_dir, redownload=False, no_ema=False, verbose=False): + module_path = MODULE_PATH_TMPL.format(resolution) + hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution)) + pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution)) + + tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload) + G_temp = getattr(biggan_for_conversion, f'Generator{resolution}')() + state_dict_temp = G_temp.state_dict() + + converter = TFHub2Pytorch(state_dict_temp, tf_weights, resolution=resolution, + load_ema=(not no_ema), verbose=verbose) + state_dict_v1 = converter.load() + state_dict = convert_from_v1(state_dict_v1, resolution) + # Get the config, build the model + config = get_config(resolution) + G = BigGAN.Generator(**config) + G.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries + torch.save(state_dict, pth_path) + + # output_location ='pretrained_weights/TFHub-PyTorch-128.pth' + + return G + + +def generate_sample(G, z_dim, batch_size, filename, parallel=False): + + G.eval() + G.to(DEVICE) + with torch.no_grad(): + z = torch.randn(batch_size, G.dim_z).to(DEVICE) + y = torch.randint(low=0, high=1000, size=(batch_size,), + device=DEVICE, dtype=torch.int64, requires_grad=False) + if parallel: + images = nn.parallel.data_parallel(G, (z, G.shared(y))) + else: + images = G(z, G.shared(y)) + save_image(images, filename, scale_each=True, normalize=True) + +def parse_args(): + usage = 'Parser for conversion script.' + parser = argparse.ArgumentParser(description=usage) + parser.add_argument( + '--resolution', '-r', type=int, default=None, choices=[128, 256, 512], + help='Resolution of TFHub module to convert. 
Converts all resolutions if None.') + parser.add_argument( + '--redownload', action='store_true', default=False, + help='Redownload weights and overwrite current hdf5 file, if present.') + parser.add_argument( + '--weights_dir', type=str, default='pretrained_weights') + parser.add_argument( + '--samples_dir', type=str, default='pretrained_samples') + parser.add_argument( + '--no_ema', action='store_true', default=False, + help='Do not load ema weights.') + parser.add_argument( + '--verbose', action='store_true', default=False, + help='Additionally logging.') + parser.add_argument( + '--generate_samples', action='store_true', default=False, + help='Generate test sample with pretrained model.') + parser.add_argument( + '--batch_size', type=int, default=64, + help='Batch size used for test sample.') + parser.add_argument( + '--parallel', action='store_true', default=False, + help='Parallelize G?') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + + args = parse_args() + os.makedirs(args.weights_dir, exist_ok=True) + os.makedirs(args.samples_dir, exist_ok=True) + + if args.resolution is not None: + G = convert_biggan(args.resolution, args.weights_dir, + redownload=args.redownload, + no_ema=args.no_ema, verbose=args.verbose) + if args.generate_samples: + filename = os.path.join(args.samples_dir, f'biggan{args.resolution}_samples.jpg') + print('Generating samples...') + generate_sample(G, Z_DIMS[args.resolution], args.batch_size, filename, args.parallel) + else: + for res in RESOLUTIONS: + G = convert_biggan(res, args.weights_dir, + redownload=args.redownload, + no_ema=args.no_ema, verbose=args.verbose) + if args.generate_samples: + filename = os.path.join(args.samples_dir, f'biggan{res}_samples.jpg') + print('Generating samples...') + generate_sample(G, Z_DIMS[res], args.batch_size, filename, args.parallel) diff --git a/text2image/BigGAN_utils/__init__.py b/text2image/BigGAN_utils/__init__.py new file mode 100644 index 0000000..493d56b --- /dev/null +++ b/text2image/BigGAN_utils/__init__.py @@ -0,0 +1,2 @@ +import sys +sys.path.append('./BigGAN_utils/') diff --git a/text2image/BigGAN_utils/__pycache__/BigGAN.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/BigGAN.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a6ab4b23dd52380c2849b99f14ff75c70416527 GIT binary patch literal 15764 zcmds8TaX*sS?-%!Esf?b&uz!{_`dSm=K% z?pfPI4`i{Ebs$ch3y?q_W|9yRNJuIL6crwTz+);8sZ;^|KtVzk8F--J1u8G_eWzP$ zX*^yhTNzl8s(+n6m;c=U^Pm6x-_gb1-n4?x%EDjIZ@g1ce!!R3$3Wm&JnmnC*ov*z zl)Nf`wY(;O^}H^Bjl3a$6Lm9hs?4j^67^(0DZ^^bs;Ba)D4edFnxZQ8o_t2e=*Zcd z?~79V^Zh7o)CTH<`60v&+bKJJTg?y8Dt3>Zxvkh4$E;>=>G@rD*6y|YkT!y}etQ6E z15w&2Vg~IY#0*6-W43xl$?e)iPn5Zw8W@#E#lvsva#0lPj^_xxj9$6fC@qR+qq6GQ z7mHqLajq%qr|r#WQ7X{RJQ0|W%THjQlAjWXh@S)=4y2C9{feq66=ltMRb5bR)z)rX zwmzfWQsuiL-xGYd&9Z7Ivx;r8lk6TlR8GFCpd347d9oX~G?{z&vkBj@`_3woPge5T z)#f`Q-$&bg$GZ8JP04k<&2^%SYj4E$5y@q;&2_hY-_z#0x07o*A$d-!Mu4v)4u=Hhobrh_exWnJ<++U7Ub#V>tU z@ePQ@V{;`;nfz#=oYY}FZWg4h_{#e-Yw8`vSF_5;wQDHwWksZrl+yxj`XofjbLx;D z4Wh&7AHDU&Vzcg?SY0eOtBWVDH^q{>Tr4>!O3nK63M8%Ya>cDQ8z(&HhIhJBFP5DX z7c1r23v-1PuTpcPm+9r3KT4miRa}qd&h27@*g8{LotP~YMQL&RY^_-;*4%Rlw6f)r zn*+!s`aw1=hWNV4E^V^(COi6*^;UK&>ad`Yiyn}iDl!bS$D)^{4`c*S)XH2^HbT}P zl0oJgBH`#ta&zH&8fet-mu0!s-E?`ks3$O)Q8pu!0{yavVOtHV%IkOk$A zvaZ7> zo(|nNPj;mgPG`rI1!@7(lW>=7&`2ToK&?sco3C%lpzvr9&uIbHX_l^wn?S5 z$%)G)f}TpFP+BZD8cxj(GR+k)ezR7V-C}*Y<^*ZsxXl`b5j9Vu+A2+*kh_ zSsWgVyR#$W5cA(baX21#BBu+f7}{j#B!Ol~q@B7dc_UBvskqHMGTkwMmDY_YHp)&e z5!#I0N`l0ai(^(`*_B$6mPTOErj=_S%UM&`4POVv0TtBr 
z6QD`o1hxDmXv$B6ru`mUT{6Vm{ho#Z~wm!;~am1i))$s|w zO(4W)Jj*wuk+IqOahv^N{Ge?tY4=od*n6Zp>FuuW!Sw6C<)@IAwavTi+1@B^svV!a z%iiy6Ov;q~)dMVLTX)HOueZN?kg51Rgx|y9ltvBn%!}|ox&ZEK4Yu*{+MbB_5q}Tf zF*;R@jm-a1{Py6P7*%9mwrB{?Fa(h!{isi_f9xykhdJ`qo(-*Gdb`(i8|>*(#P36F zdnCe9vQ>^=HDPClm8`PC`FxD$-omjPKlF}AquI;RyoDpx`@Xxj*N4DXk6~n{QBIDu zzgO}YwUI9hBUklNJeX1B=?mlUY2w|ocPFYRv2gm)XZn%`?!p;TR^NoPUg4O`KN#^F z3d@ZvC~0`7WE$hh6M1fFs^XophuiUJ#|T;q=|l_dY(g1V(Ermiy`ya#8f(*)v*3Hp zcuDd0uRm^cp~`YRIBrlHjv@XD^z+G9c{?lXmUCi{a&Fnbb@o;Eo#}RBL)+l2pYg}< z>FznCze%o_iF>+utBu2^Kj07gL;moB@m|WzQ+VEt=V?4|@kfyBbLi6;djS=s)XuI= z`BT+rz~>z2v5Ndvlj=iEVReT`@)HbE z&+w$UkFfr=;GKzEj1Y_^ZH>=PtUl7S|;1BlZg z)KA8IrRe6&Aai}OQd%t3Dh;P7g5Hv7x^6)@IQy20fl%W}t&EW(}U z01B)gyCC~G)joDl)rvPAF;!_y%_4FtJ{3*U48}+TJp0cj` z5T#B>0^CLoG~^vkd!^>Q0MAWi<_3v*ICPd`^EW=peT{r0ufpN>4l!LQ(2ZFrZ0iZq z#8W6G-VCxj`YsoQ>daKpF4kqArl+Sd&*m~cpl+bm1RCIYMfe4-JC*XHSFoK@ z@n$%k7vP{+S$2dt26lnfs#xylLBg)o3#&n9zUg_*df|G-_7(+)G?W3k-KXGps1;q; zf#}Q@+{L1B>>wTf7EnGg7As)iZn8Mf^6YYuoGr*Uins91yf{~Y#0BP+@Yi6cBO$ zDW}{OJjdkC*^3l_K)pmi4LpU*P1r_GkXROY4N@1{{-BHGn!e&(TOr?KO`K!|{W)AZ zvOhsa%0@Vqu6Ujqw2PkybU6m*1w@=W6Qt!p(;e=@CRp}3FHFkq3k=t( zEd<6aMpL{U%*4--aF+AEvxQD&L1WGH8sY_3DQ9A~aOK@@(920-Q0fiw4wj!I`B{?o z%m*2iZ!c2T2+0IR+euH08neHX3xy63$lH)c$#A~}qV)7> z=9cf!miI6{tD9O@?b8M{Q#JJrOjrZo@V5^smSL%5s)=|_rewqPL{?2}ed=VwMCu6g zvy`Eo)F<#ZqE72$26K|zgpSgNiO{6vW2t?}--pr}b#>@|?a`&hcxHKyz4{i2G>DWN zU&4l>ZYbFLKMRef`Kkxo_m<{96@?QWVe^(I2CyY1ZEZ;tm%Ze=<*U`y2DY-sSsAY* zzJ~ZTaaWLjd$6t|Jq(=KCq8QToEbC?{8*Gn55o3y&##}?5fkw$}ZZ2+5i z#->3O831yVOGld=&uD=@ck*=k7qfpb`wy$1IuG}~S*|rj(+kWS#h0OdRwv&EX_=DS z%~YdFjbCb(B^m@>*)>dUpK+(MN2BFHbt`+rE_y}mmrl*8W3{%Fy}G=eqW^mieWd7* z##7a~l($p%tH&=>i%*rBD-CaIp(&vAu0 zvG-!1aJ^!qeJM)ZecSy*XQ(2hDb zcN7eCccm_x2%-Ce5oTEB)-k9c?Ynqfc^-rKQfT>2Ba~K$_^Vm!L3MRx`_Y+eV{$w_ z^0=gv-QR)ux~(jw#4%5)stDt3cUN08eB%y`-nE3EsA`DQZR4)?VH#PwSitUW@Q?%3 zSl!UqOy9H<wJ$=}m30`Q>wUhp0q!v8QghOXgSyhOpI*X1e*&Ejl1q-Wj0kD#b;&RXXPAS- zv6co)HZ)j>b(LY2Z)l+;!}Z2|STG3~S*685FBMyI9DzZBdb2|{_96GblXE1=hvK<{ zOhDP%+X=CFmX-92PvIv>y3GZzUc4dh;Vs-D=#&=kBDqR3Pf`R43@qO3*q|Wka#spc z4z?PeqMg8iFfWQlWXokj$x~OwZ?Kepy>Jt936dzS2>fWPu!3&7fx0MOVTP34n+dH6 zYDC#T3g z7S?az1>&HHr#*EY>K0*$D+Ex7*h17I0feYR04d0^02+}1@`RU=0*KJMDFqNBzMjNm z*%|~8=Y=YjWeNhAfB@oz&JvLTnrlf2U{VSo1Q#cI!#AY>Ca&Fq0Ios+Rj6rYJ>x^1 zr2v|t0RA)tP0E1yb&Nx#E?bzR9DEK|Y#aBgv7qJd8{tAzlpic?y`f}L;w0v=op^~) zGxakh(&YIqhJKqwN>xfqmG~?pWb4)$+93HH2}MbKo@9F=>WPGi4Yk;n5#&{19Bd;% zySD9KPMyo8X_Gf_lc7)u(uG2(ToKL`3fESOwJ;^LVT~7^7v3dky$EVs@dc9K0SS5| zGpFR#YAsWTV-2DI6=tFBVx#lbe` z`|fAR+Yg84RGzTcv(M~Wi^WdusX6S?%f`o(~&jE%c=|||{4fJtKqjg6k zHY~Pi5Sj&BwyJ}1195ACTV3d97;M+yZiIxGu8M<5xfJ?lDcFLyGeY!9=48RH=F+N` zTP@Ad|G66rKNd=SB<}s{q5Gz%eT@K0v~|hkCqs#2{sxD!B|w~>yA3+C(9Q77Z7yDE7j&-vqVn|EZ+WB#hhZe0x9s)g4(`cK9^c^$a& z>#g-nN3F&dG+Vy@pOJ(&$358*Nr?FaZI;f|POMq*s)ops&_x^`qEKZ$9+G@7F@ZFDoZRG$TG0x*t#Xw zY{5db?16=>^f>TJGjU`f%MwQ(vq$Vv;>Z#|w%`L84|-=RZb?+?l|~ywz8V@TJfLpD zkAcjH+9H6W{a(s)ODTFi;#L;3%PMB?C7OQQB;jwyPbV zkSOg)6hFEpeoW$?<2SzJO?W_`tNk`ly&>-TCu=~Q;e;GVUBqnr?Ic<x?|FM|_P<|Z ztB7mko8xxB#5af5=|1ghLyT`;`yS5fyVv*edV)kZ$9Y_k=q74hCyFjh+Wn}pU!uSJ z5l;-ABiT-qDF^)Fd#d|Gq#u-sCNO`C9 z{%Dmb^kWjoME)v~@FT!6S=XcXu*5Ns0mnR6J^t`~(#4ydlhP+JJ+HZr8TXD&) zK&3P_L!%Q5UlG4g!XpSy0#gj{BawiR0bt9KxKzFuSVN3agoUJadH}g!3rL)~4@eZx zqDYq_=ZWG9o`(a8p^x?=^0p2+Z6CvyBhD{%DBz=AD&jRk#BG7Y83`PU8B{A_C2@(N zS&}E&lxgu+hQCT8-3&Z(1?CKaL1zmtC_$o5tZyd5`d(ydj)GLFq27KsUtc1556RDw zaAriFgjh_z=gM3gk`uqc=y{SOBv+ZioRJt?=n6?)c)5HTU?3nrhdCBVIGmwhF(JXc z&`a_>W2+=fBsG#cNrR+GvP^OfBsUg%8jTr(WnW`734#)}~qN5t$C)6QrOdo@L6mXRuLsb(1 zs#Zs~m(i|uLerP>{{IoG61UK?4ya1p#=H0+$c}(i;zMNrDo9I_++l>2fLILZ*8gIM 
zBAM=|HB|VMi5FGn(<@-9kiU0%b z#9dr=V=62`IODmC#hc<)EbC7q2K$=yU`xIDDQ1y_B(GZ(`T7lJvLdYHM(ET291~t4 z`8)}`7U> zosAK{ZQ-wtNnGGFVPE54X}F)6z^Gb*K3kY+PqmJctw#bt54L-yrGlO&!p-q7P*Qd+ zP-g>m28T*1p5J2hUxKtcCVMJ-$8q>m_Kvcppl7DKiklKM?frZ1Qi-C+4;Ly+K8MF0 zLm_z$2{_Hi-p+e*5vD{TgZMzSXnSc@L-1Z%qH4HL0j4xJ6d8|qaXCzPFC#w5_?Hx0 zd8hKC@*c%cicd?yw7?6%CGaSMtRD)6)i)5A#=Xh(f`(gfruYQHmJH*%Yr~Ine;w^3 z%A^BzM0;>M)B@~@|CEB$w`Ce`c(&8DHGrj93a~uEy*K0m%UVOu6z-B*)wI0PsE5B? zJggfz2AC>q@4bQ^U_!O&iQ$19XY(JnZR;(X1qg`AQ1V;n|y8E8i40no7$5 zeHCUXSgWv3aFcDx;V%r#T0Y6G8drK4kk_a3iJ8I!*4>{|IF6w*MbBtiExa5(q=$D) zuoY{#9io}WBv45nTqm(WEN!)iBGg)miMhE;moDKJ-Imoz*?Z{nqey?VC7_m?Nhqpd z96V)QYVznQxMvHB=>uZNg9M=4Wq`4bW!wiW;v$)}`lzS^OQO%yjYg;*G{oh5!np8C z&oX0%Te+s?T`ycM#~rV992OK? zSzw0O_9VC!nx<4cvZJwD3PW7Bfps>-;?7}Y6n6@Q8p#4k9NudfQZo#j=>x#cYU#g`kg3`3>NqRu~F^B^L%s_?%5En}Ti4=wu#vFzyhE#?u#sy3%ObZzqfg-^S zn#?bOvJ8Gz{CfIMnd$D1e(|LxnK{M!n#{MDiz|z7u@oehWZYs)EGS6LOIgWK#0pdm zCVqM8XXNLm>Q`kX=4WN-m**E{7Z)TZr|Ku?=N6Qfq!z`OWfo`V=joTER+JcJ<|d}6 lqS_E2pP83g5+AQuPcbDX>0TkGB4PqWXQIdkR=?AbH3osq+arDx-R&FM;mYq_=EI&NLJ zo?G8-;5Kv{xsBZ>Zd13J+uUv8wsa%hNH@xjc3ZhIZmb*UwszaNZQXb`!ENWZcN5(X zZjzhqrnnv5PHw8(+3n&oZdbRP+uiNq_H=u>tm|?)=Um=(yB=3?MK{g$x{~X2Wmj=k zx3{agx@)+mo9_DE3^(8g-9Bz#x1ZbJ9pDah2f2gYA?{Fjm^<7Z;f{1ixue}N?pSx6 zJKmk(X1WvIN$zBKiaXVv=1zBKxHH{Z?re9CJJ+4(&UY8M3*ANTVt0wV)LrH-cUQP8 z-Bs>tca6K&UFWWMH@F+!P3~rQi@Vj`=5BX)xI5il?rwLFyVu?4?spHk2i-&NVfTo8 z)IH`NcTcz{-Ba#q_l$eiJ?EZxFSr-oOYUX&ihI?)=3aMixHnzPz2)9^@3?o}d+vSr zf&0*XSU}abZR)y7IIIIq9z?!fYtPShHy09Ls4;#RS zun}wwo4}^98Eg()z?LuqM#3l<4O_t&7z^WIYuE<1h4C-}wu9|qBJ2Q@U@}aB9bqS! z3OmCtkbzxcH`pEafIVR^$U+z7z(F3mp$7_3glW(VCFp}PRG<9b90dOE31P8+*a3~xGhrfe7FEEgp1%}xCAbR%iwaj0N2k*lN@F9EzAHyf`DSQT>!x!)+d<9>_ zH}EZd2j9aF@FV;LKf^EZEBpq(!yoV``~`o*Kk#4pcjkZ40{=k^{0A+-Yyj;LgE%B0 z2`NZJ2Xw+Pm;>g7xnORX2j+$OV18Ht7KDXhVORtfg~ecTSOS)WrC@1T29|~8V0l;p zR)m#cWmpAPh1Fmbz^1SnYz|w%mM{WF!YCLG zTfrC@3*%sG*ao(R@h}0lgY97=>;RKsGE9LTVJDahJHsxJfn8xY*d6wOJz+1%LKoz~ zK_0rH2MSPxY0wKL=z}s;pbC3K4eHQZakupjIX2f%@F5F8ALz@cy$ z91cgok#H0o4adN-a2y;DC%{ZN5l(`W;S@L(PJ`3o3^)_cg0tZqI2X=?^Wg%x5H5m? 
z;S#tME`!VA3b+!kf~(;gxE8L1>){5t5pIH;;TE_RZiCz54!9HUg1g}!xEJn&`{4n2 z5FUbu;SqQg9)ri>33w8of~Vmbcov?6=ivo-5nh6q;T3olUW3=+4R{k;@D{uc@4&n8 z9=s19z=!Y=d<>t!r|=nk4qw2R@D+Rw-@v!<9efWzz>n|~{0zUqukaiE4u8O(@E80I z|Gd8{@x?@sB+SF^EF~l8}NlbU-HzgE?SMm<#5Hd0<|c59WsjU_n?2 z7KTM&QCJKXhb3T1SPGVgWnfuY4wi=%U`1F7R)$qzRagy%!|JdGtO;wu+OQ6+3+uu9 zumNlc8^Ok~32X|R!RD|9YzZS^B#eU5uoaAfu`mv{hHYS57!MO*JJ=p3!VWMACc_li z5q5&9urure8Q2wegWX{d*c0}GEObE*9OR)JdY}MBmp zhZ!&cgRl?m3;V(TZ~z<#2f@K`2pkHB!QpTO90^Ck(QphL3&+9na01MP6X7H{8BT#y z;WRiM&VV!FEI1p^fpg(JI3F&63*jQT7%qWJ;WD@!u7E4yD!3Z1fotJ9xE^kR8{sCn z8E%1F;WoG(?tnYtF1Q=+fqUUTxE~&X2jL-j7#@K~;W2m|o`5IeDR>&5foI`4cphGW z7vUv%8D4=`;Wc<2-helu1#iLI@D98S@4@@<0elD_!N>3kd6=7%UD;z>=^OEDg)RvalR14=ccm zuoA2ctH7$T8VrZkVGURl)`GQR9atCEgY{tp*bp{?jbRhm6gGp+VGGz2M!-lI1*2gr z7z1Nr9Bd8Sz_u_RCct*EJxqihU=mD*DX=5#1XE#W*ab4ME9?fl!yd3F>;+lqf*d%= zLpSt50g5mUdZ7e;P=*RrVQ;8G9U9Pt>Cg`|U;qYTAJ`Z6gZ<$EI1mnkgW(W36b^&K z;RrYqj)J4%7&sP=gX7@@mi^Z0=L3#a68-qcfwt8H{1jF!hLW*JOB^E zL+~&>0*}ID@HjjHPr_61G&}>(!gKIEyZ|r4OYkzh0l?upBH8E5M4d608iXz^bqs42RWW4OkP_g0*2CSQplV z^`bOoSa^5=@3E zup{gQQ(;}8T9T{X8qkF4 z&<`_U00vVLGy* zSOHdqm0)F91y+UCU^uJ}YrvYY7OV~Hz`C#=tPdN&hOiNA44ces?|CGnuiX?jgUw+J z*b+v-NEij9VJjE|V__U@4coxBFdinrcCbB6gdJcKOol12BkTlIVQ1I{GO#P`2D`%^ zuqW&VS?GcsILJdc^gscMFb#U41btA33RGcls6ibX(1hvG4>Mo@24NrA7xshw;Q%-g z4uXT>5I7VLgTvtnI1-M6qd~iX&l4UC$HDP%0?dRH;Uv&b;G2Zn4Qvsf3a5c~1m7V% z1I~oA;A}Vt&V}>fe7FEEgp1%}xCAbR%iwaj0<i^Z z0=L3#a68-q+8O+v@GiI;?ty#ZKDZwqfCu3rco-gmN8vGe9G-wD;VF0;o`GlKIZz$F z055`e2bLmy8D4=`;Wc<2-T>8=>gg?b8&pT{!h7&Od;lN9NANLx0-wTX@Hu<|UxMmG zI|L&LzkzSzJNO=cfFI!}_!)kIU*R`U9sB`*!e8(=`~&}me`hAp5(%_K0xgk1OC-<| z3A984Es;P=aC!tONJ9s7!Z4Tv=7hOmZkPw=h52B9SO6A;gA6tPAVG`mh0P2phr1unBC+ zyU!uq3^s=?U`rSQBViPbhOJ->j3xbDgyUdq*oN!-3AcsuFafrM?O`J90Fz)cOo1I? zCzuL5!!D2^kGBbTh23Cx*n_xt3HOA(APZfP0|$BNh8`$D5vD;el%NmFP=PA!4K=7k z1DY@$`e6nPz#!}c`@(*(KO6uD!a;B_90G@eb`R9<+Cfm;9|1?gQE)UI1INN~a6Fs< zGvP!y2~LJn;8ZvbPKPt#OgIb92DQt%pj`yD#rbdnTnHDz#c&B+3YUR)5>)3`z?E!P#a1-1Nx4^A%8{7_ez@4CWz8mg=d*MD%TR#8~!b9*dJOYoxWAHdU z0Z+nH@H9LF&%$%?JiGwfO;A0(1TVuY@G86pufrSgCbZx!cpKgU^@;c3efR)A1nnqX zPWUl=0-wTX@Hu<|U&2@LHGBi#!gugJ`~W|~Pw+GR0>8p<@H_kgf5KnzH~a(tg@0!z z(HBYdMG}3HL|-J)7fJL*5`B?GUnJ2N+CkyW8Pd=JoiGgMfH`3cGSd0{@79~OWG zVIf!;79p>X2^WRMU~yOimV~8XX;=o9h2>y*(*I4k0;~ut!OE}-tO~2aa9ADIfHh$) zSR2-Xbz!~0`h*+6hOiNA%&KZB&^tGQO<^`bOoSa^5=@3Eup{gQQ(;}8T9T|??>d;U0Zo_={V)RtU=a3!ePKV?9}a*6;ULhnhrpq57#t2qz>#ni z91X|7v2Yw54=2D(I1x^Qli?IN6;6ZG;S4wv&VsYy95@%wgY)46xDYOai{TQu6fT3y z;R?7Cu7a!K8n_m&1MNE8LwEz+2sgpaa0}cDx54dj2iysF!QG%|?uGl{es};Lgoof^ zcmy7W$KY{z0-l7Y;AwaUo`vV&d3XU{gqPrDcm-aC*Wh({1KxxdyajK=JMb>N2k*lN zpq+>B2|t35;S=~2K7-HU3-}Vgg0JBl_!hncJ@*6r2tUEk@C*D3zrpYD2mA?t!QX*@ z2(`6Y|gnL02x*!J*@}PS? 
zP=F##gI*{>AC#d2Rq|;N?hQ4lgRYw}9r|Gg48S1l1N*{$us<9C2f{&cFdPDh!eMYY zXm{cW!Xx1*I2w+DW8pYB9!`Lna3Y)pC&MXlDx3zVgLWv+BGfL0-;qe!_1VOo1Lwkd za6ViB7s5qwF!fWt4ya8`Q z3*Lga;T?Dv-h=ny1Naa=f{)=7_!K^a&*2OB625}3;T!lCzJu>U>yRYjkMI-x3|foK zP53MP2EW4}@F)BQf5SiUU-)-s8hw#QU!>6$X>>&zU6Dplq|p;;^hBEcwX2XuSESJu z__aYN41+mfPM8bkhIwFK(ytC{BhLr(!ve4%ECdU~BCsed28+WI>?X{UQ3;luJljf3 z;dqxxj>y*B^lYs+lOEAEP_Gxu#cVk$yU{mWhw(?aI~*nb?TlVk4VL zjVN_x8?{QkF0oYX$u=rVRH|lkx?Aqam%59jZyQl=@P=|DpKXwhNu^3P&kcT3BWitF zS1gq(T$5LP#Pmuo-^e6J%&3$qjZ!8)GTT+Dloh%jJwGyADwgwg9vzu2yJD_VX=K_* zW@}x#RV!z6CA(FtWphoJ&9sfoHh4In?IFIMkV5i1ovhkM=JUG9H|&vOxzVgu$~~ki z)*LsA6s??zjV#uAe|lu8QgMY!v+PJvv1}?`Im^7#*OeV0ArDu0u2QaaRdiYD8sNF6 z%1!;$GKrDR?(Qs~icl)KN)JUw=&m6VRqOMg)of5v-KBv{d{nk*MHrRs%a=0AQQ7K1 zSGJbTBu4R@%a%+ic2%00)TnH|Ug^))1~P4MRC+09 zp;+x}mU}5b_0^+nin&6mNX3mRW&1L*QI%|?sGy5)NsX%HdTW(zj;=DQLcNcw(4F$S zUE>S6UF$=l6LF`hY9v{|E(R(Mnsihx+eKnh=woV?T!rpr@s_GtE9I5CS~MQp1Q10j>#s+6x5R8O6eG&xTNwry4F z<)^QiiI2&aavDlw^0gXcEH*}smuVlPVVWF6)A+9)Q&KP202xCkjE!N8^QtkG8JUhT z>d1|Jmj+1t81>uu*ldZ9$+V57W!lH8k`iOHgZYxGZ)`r-%ZM18uXPovps|H)tw@p5 zW2x!JL`t3IX0fYFoo{ThmPO%^N16WHHn!YEq@!d~V=E=sZADFbQCethr9yqi$5yIR zPKmLVn(u{UEA>LQzc0&K}V&wp9PihkFPL}sjl%fcwK=FtMLQn89Y5u^IbDG0qw}BpOEdTqDm4I z7@x{bLPO1kKrcgsB@w7U4c`frv!FQ)p@GgLqTVQWcPnB*V>K}W)zGURIe`bf`fQ(2 zMA5ZRFm1sv$&wQ^52$x0B-I$AaRey>A}X~4dB+e)K+^i&&0?dX92)fM3I1~?G^xwP zc8vA{ZP~G1o4=`NeJ;-onn`S* zYgDTAwAl8vcQMnpeW{PgDxVIOsRLt>C zZkBr(fD&qTDd6P9LOxrU9&6J$jPs+(seNLxmodPvvPM^w3t7^ZvW)?KTw-E{`CA=n zVx`Q?InnpW)Wk|H*TjaQt`V!uZq(33`}RzFVl7)_97~JGko!@M6PckYP20pKJ$E8w zF%#RNnMHZUcW5$8(h~_AgBoEwG%LLniTgBrCO#?K)m5arCuMWkBz(}MY>(5O_DOoa zZ4xsdKeQG<%t7^XmQRUI!W2>Byt9cXiQp9@Qju2YlhTuzOt6YDwn$&;qI~U>G}ERh zRT_={JkuTDJ&BTM(vumlC=hAyIC38aA+GG0l4J>_Bv3$e-Q^j-z|d{seBvn>HO0vQ#tE!58^ z7rU^5Xx-T4URegQ$#hM+*yOV6Mget^oLsK-qq1qC$rZF^kqpogHO=Fbt404!RfZJR zHko0=oTq@f4BKdOBU>}gVG@%j#?@qvi5hCpq~0#tG&Z@RIYoiA(&TYkfsbjo6fYL&nz>ltPwH7MtQ#*OW?@qD-mO8oql?sWjEK z(^G1dYBA?kIC7KDF{PQMoM`VU&6*T~LK$T79Z|+o<2y1K>dj2&j(uHxRw?VB-pN`l zu~V~ziAw*UDt(4=8Aqt8_fKWmY1(O@nypteol|A}MaE`)D&t9Z!c_E!9>zSYGUKGD z7E3PQmCvJ2+ox7a-I)}>7$fQp@u@OuG2EvzkZc%p6BTdzDJ`+e+^&8%72Ag7nr-PE zQ)~3*wq3F`aV9VGuAKnWBjXidd$x-r#4zbGgW}n)I(BF)Nmwa08+ntcV&*a8&~efN zNmM!drkKm58Cm&i!?GY&a^Cq^)-1|cvCSGEpGion=UppKZpG<&3!B+eE5@>bH^s90 z3!GM(pxB_()0LCFW^dX@*|=%h97W5_k#$VNMKmJTAw}j6wy1veXK zV6A`&p2||Gg?xiT$FtpN8Kvq*t7glsRJMmXK~@w+mq^)V6J`vKMxvWY|Wy@$;nY6k$L)nmH6@$bNJc)D! z`WTWFy=trjnGt~+;M3a~1zkLYVMyKbj$v6Tng%uu4M=CrvKnGkY54c38>H#s#qMHG z_9L+k+JxFoup*`S`h1%_;@NshiYP^IB!z#Fg%OX@SLg(ditkAc>t?FSbQXJpS4umj zvP}|DXX78^K(1Y@y$dre)7B-$ zz~aDn6e4T%JXX=sf(

dn@{MJ$J%DPfiyKw_c zo2M$gmDE^RL=rZr6t3yV(t3pIujnUqmKv5j`TU^WDdsVkx~vB(h0G+iK?0E$L*aL^ z(xV*+w5-%6vu>=b*(WQzt63}iay09hmwaJgCdrs#gMh|K%f6TZ%deHtRH^wco-1J9(Vucmgk)!>FR&o6VilTQT?4ITu83)( zNja7)PUnUN%AG6gjfq@I-I%@-1Ze@bH&7xLFvWAF)nhV8`>8~_!JAv$AfOQmn~sSh zS8DcAr-@tzqn47Va+RJkZ9+>Xa}`WDjn?*@+CR;2iE$Y2(RXaj5nBnLkHnHtMMYvl zVgcalYYD3+M?2%ciF`y-~er60HF*PYtVmBrVukRcZL%o%|O{*$NJ!c(JN$N77H1*g6 z;fm(4dop%HUneECLB^`ive#*Im{e@2Uu){pGmU&-mAR>=Zm26?5p4665v8*gTP!yz ztUgR)XatirJ}kbb*0*|LtT0W>Do0202!^W73VIuIc(?MM&gwkmD^dBTqNlF-joFQI8FmVBCMLj_uzu}kQ*YKtaX5;kkrN-+c*)lG^< zb&|9;ot)ZbL1gfMENW5cm{K(n;}4C=%h?0=skRstfr*J*G?RdSv9I(?C8}Bw8NV)X zW(FsjxuW`Ci+H9luDi8eO23liyCyCY-KI@-m*Fl0!m*6PvL>mX@qJp)Ao)pIBIW=& zSz1J!ap<)bs*}2Ek-D0v8}oq@Wsmsx)>&jpy%SrvwzGGjHq|}I4OL04`?S+p4ziu7 zXcP5`ly^EibovGiX0KsPG@FLo#U5#Ko+(yKVb;Z6%(?^4qqPP)(iQ6}1GbPPwoSCi z(`&e5)>bv^PL3St0J6Ne>S01B5ijtgMY+)3>17>}B;7X5U&!Dt_)^xfJ-R~n68%?@ zR@NOwqi?I-i-YQhuENqrmVzW+&H%-!t*II#rrBi`je|R84u{S%z$`7D*Qp6`UQvEb zFRrE)5q}E9%1J{R!x&nFMCO zOukGz0(F;XV4Ahlk#{Ue!f=f-iLoV;%F`rX$4jKUSO>^s<=Du~#1X9)^Rhz7o(MVu zorGmDI|`aclO|$2X3~+Q_Vlt%b_rxN-$+Vxt8zN>Ob_hNTQ}<*x=>Z5=#`OHL7TBQ zwv5aFnSdy;AcjsHG{TJl7(I<(wYiC1?Yi z3V8{`o_-!!L$I)l23b(gZe40i&m zup6VqhDN-*#3l}1olthYK7}|9dU9u8uJ|3*Ze~eJ%fzZKj!?@Nd)T}}pu9XzSakMz z9kRL#*f=4NTBE6r&z2t2N`fKsL_pWhQ8;GTR-9iQ5+n)tZmRQ>F0X;b za#jORr8Rq>!RPos1`djlU{Kpy-JELhW->#Y4W{%voFS5@Hk##DnjdtKnJ^v5%%Hr1 zFNkr$QiI2BCMLh8ZF~?f-J9s4MWuGydNg)=H12sSX@dHQMM>YriPHXxBDBq)>BN>` zONiqcwF$TV8DEc|SbI3#u#b=RFpSWoJtZdZ6u%rnDdQxkh;~fEqooQfJH@b0(!F>O z*0VjtDvS6Ex+ug;7m$stDO=zqDMHScEeRx3`Y{BZ;t(|af`?L5bvpJ)nTF5T5+bm5?ZwbmC7M1Q!Cvv zsf#-ES|(N?h7QUah~CO?y}*SmnFPP;^~7?DO6?c&rK*io3DU!@IEPbI1P@?(40h^w zI14sbP1GeTB{sUOkqSjtHq@FgOHIBVVVbQMd6CX9TCrl8v8xhmgB$I_JIWQUi}+&lWzMbpRVt5qYmP6bQD-fE zxYVA#H(Z{1V$+%$1o3IC#reGUX*7JMy%*inYGdQN)!s{I z$i#Ztu`jpcz05$oKQ8OBS!=}@;ywAn4hx2B_G2>BEY;hqqjxN<|kIQ^FWeG!x zKp*Q@___!+M)GoxG*4W5Ag|YGv7u~ccvA~ILH#k|*F=1Tt$oN|hemPz@~zy}Za$k( z8Y5T^X+kXmY`!QJ{lP+sB(|6)l}>Av30;CKUa2CO@lFvcGV&5b8*M7iB+@56dhnrUh*Gg^3(9Tg5)(wqlronWKWWcA@V5~Dy*GSiZ6tp`H%9R$j2-HE(wvLW>I&B>zV7zQH}vxos}XDSk@I7%y_*+&HsZpT2#pD5_+Op-5O@pyM_AMCeeVh-l`ekBc@1DVtAW$t*+1jOu^J~kOK0Ca% zi8rRxHmV9+TxV-sQszujvzJUU-LS!zAhgYR1l)K?LsB(MyN#Q;elaO@#Rrtx%3xQJ z^-8j=;|=Q%Wen4Tbb`8E0#%Q#U2nC?;eo5BHe+?w-)g51wNuK}Qf|^$%-$$B*~`-OSLVfN z6-CC7_3;(f{eB5t>6#W6#tB}`?7%cy=_*(JDF<LRHJ&nac%dGZiX<$+HNuhDA$OG-tr+S-rw1HO;=h4H zMc_J35N*2YULCuKG9fU`YsNE9 z?SwNq;Z;UI(MB%QG1>UzA__MQ2{mrRyf!eH;m=H7CM=^@}8HZ*OuV@vxSZ2~0BWg0xZy=s{`HG-B9rrC7+;2ywQS)da&>IH1CD zI`W#TS{Sg=#%u8dX~n7~ohl>5gCT`9M3`6NJre+IyMauM< zW#3*^t0Zt1(m4V3rGCJ;GRbU*LB%zOvy#(u+VnLoU#;}}ql#2jXBwfb?5ugi8}pLw z#yp?R*@=9?=&y)9+Qg!*IC!R;`IR~seXMqQV=Yg;+c#kb@uEy8QR*80opG)hNtV!$ zkOW74xn9k<2uv6(lqKj33A|M&ktH+T!%|@@G5rx|pUnoDR3)!wrKPc*Mq(?=s}|&L z6pkN(y2JW9v^28$Y4Ry7_LPW@!K<_iR>{3+x`1M9u-QwfN?xyr*zfdjQNB8pR0dx5 zMkr;}6s{=S@TUA7)%pPEKlsEl`i9v>v{;~x*%#EPtu~o#`DOAVt|q;mAoS|7%Hmc# zQE@h$WSph6@x%9R8Krv|nBdMhblxUphRf{KBMOrJz?2xWSJpY;`D?*mK^1H;8 zL4~sErtNa=`&zFiLQU3&>q{$!1*fw?0v0{`4T=U4mB6n?f?6R?@1YiXj7cQHPg5I2 zJmR7gd%BXJNq7uf4xK^*Quf$!mm(2Z@o`i^U!#p(D`|qN*`l;bq9||tO7fzbrxbTLM3lZDV;F%0j|Sv!5SlN_(o zI8bSiTvKJtpQS3WcEk~cTv#}e8TEdD@hL9-j}+`H zu~4A*_7uq6h6#ae&$xC77z*unJi?XopoQ3q@tMgM&^6;p#(q7^+R|m(>X@nRbq+Q% zF%AWpOglB{)VonU<_qDk)QjUR1@uTvA=Vdxwz-H&vv(&&f;pQ>);SBZ$wN}Z&W#gw zPAmEuJ0xTSMCYt*iQ)0cPMETaBT}a=g)%t_wM*Sf`j9t}I6Bl!AbMA!I$gw93>r=A z9MVUvW|DOdB2j%(MLZeydg|J`r77y{V5;tLL+I(jSRrI)OAu)Df^x8TVqZ*H)onR+ z5$Gg6ML;h46du4;j&Z;(rsu)51-5ih-a1HcwaL6|XS%_;wa$K@DocGnUdObgH^vBA zHYkwNt)pW#YjB$l2UQ6#lD?_;YS%)o!5WTY#LZ>GLRrbY0`*Te{Gu!fw{u- 
zY-Cne1niq6>N2I(vFm*meZNACiosO9G8`k+zee5{=oIt9kXE6$RJ7jHLYD4EpsvF# z$f=`Et(r2W6?9K}i(eUnb?dFf;mDHY*V}*bIx3i^Pt;{sFpyb7LRuwhHkDjPbVFmb!$PY9RT_~vqF5K_FDN+5=RqA2 zl3YfVF6E=DKwgLhWorkXvF5{)vsRr`PA-y$HhvW$3Va>c9vWZJH{Q7rJhyuaW1qYO9DK&>UCQM23e|; z``SFz{h?i*WE~$_Q?jwTHioIZrY3FD5^%m|p~f#yVJ%0^Az4#;f>>q%jKjKS-W0#s z7`j7x6ZNSSll2gDm@brNx<8;JjrxLLF4X(^%Lr<_zi2&}O%--c>H~dv&RE-F%xej( zhD_JB^1xUPMg|GcmFuG6}7VSjQ_# zIO)&CQRXZodDw3O5BPFJxx=zYDmn}zSM-Agn%?c*ae?Fc-i>6VJ8 zH*wmbIk=N>mh)hOD|6dQasCiIfh4jMq2=B%#Q$+u>Ck25xxodW&8|9OCT$n~dAkw#a}H=ucrX6vS_|BNVl5s)E{%As?eb$CbA673!KoYQ|ui zZSjO{zVWgAiHPlNk@e6bT<^A(2E#7HV@wzYIaKgi)BL1}!@L4c6C??mW_cme2h{{|Ng9j_ z)x;9)XsdlXU7!05Wjacd(0XU49f=WvZNAKuJ=mcG)5`EhRtFaeRvb15B7{9v z3Et*H(i9!jb~PH6QH{;40uVzc4hi zx6_B%1Mb%tR-MHA>2mb=_I^C$G6{bDK?wqxw4pgj!XXaL5W{O*hj0D)5?u|&gisnx z>vyEQ&LKjZxj7iyxtpK2LQ>34WueK9Su zj{1`qid+iG(zm-V*^-R)v%yI<_sf^b9K7QF9AMFX6aDx|^QysNK?u8PB(0U4^0g$^ zs!8H}Yo?JJ1qx@Sl8imn&m3xoXFunIOo9X2NM!4YU1yYSY=FOZ!NM5SR%Bw3db9q4 z^RwyVP7Ph0Uuh(1WjKD;f4gvMD~51X2G(Awv2JIwHYHg z`7*GzovUxh5jV^3ek3y#BUvW_k&KOCCFs&#%eUMCcVs&i{Qs!+SUDjUIXL>X~?#k4SkR-VkNW zp#-|oK9e+7E<#@?gjPBne>pkOGwt4!2#;UCIfh`5BoKW4bg`1Ld-(rau1O2^Fj67w zJ2eu!hl<2|SZ50o$;5vk3DMt*k`u8H{fEg-*gwVIldn3wfAhWvNb=Qi z{!l|v37m(m!zR15-1JcuzpTZ9Sbm;Lgwvgb6kbH(6!pCx|G%t*L$>($7<2Dt5&KnL~+D=7bw=8havvXl&OEwBS^4|5jpKZ-+k-*vsa zvhL!Q%;BFGeCu$J)2JxQe>w*x+VoTM5_?7nX@pv_+rGNP&QZbz-p2j}hf*Vj@6`8) z=Zkut?_{S`F{ENm=q3G~uTSND&gJ3BN1~W5A-cPjm5RfiYPZ(+o25k*$yQ#MJbzWOnALqf}?k0vkkppZorCHdhI>g*C|tNN$Z9R zW=myXFmo2R1zD6^wAG}C?xxM5 zm#N&x%0_?@1cDWDT0xSmyu;+j=S~{KD*v$nY|<* z;TP*TZsv1YdqGN`y1s8H((6^cU8#{U^1M^7N{7#_e@HI3b`(aeE8x#Hr{ z5Lq|g8AGKY6`Nf}8wR|du3;bIML!`BtBs(Lm{yI-s$E7!li0Se5j!G}UBy#I8OZiv zk#l4}7;tg5A)NBgAK`Xgri|p^gBwPCxC&+gk8t*96)Pv(p(a;6%abT8^DXpmWj7_N zC2ZP|jU0RRcg)f2SOs?Z_Vezy`Xa4>e8)4-6eZ=~lvQAUxAxuM1J}p#XXU;J;Dzfo zC_)$*!7Qar_98w%@NxWGlhvAm7+vW11AHlO+ApPZ6EicdUw7iz}0~ ztUc{_m~2bHQ1}_DPX8W!RKrw`Yih7|^E24B4Uf&xAB6axnV(Uh_eWWEm}{Vjj*uo^`V=|}ubtqUf6`Q7<2<0rc##elDQnGR^<5iZ1Wvgvs>lEFW10<(XgP@ssIM?Q zK9wecKc}}+drCI##C+-HsOc>d1l-~5Aw~LnLq!H#OzXZVj{SA3RlGBOGx0!B9&90b z-;hrII}NPi!fH`hR-Z}sQ+!8Fu*R~e6o2ceztvPR<*I3UOR&CTuAEwENB6xh*9MdC zojfkL`*7RWZoCjappa#=gIg^*c=`sFr>9TM-@ipdczaM?=KZN#^7-^hSm=8p>MPiq zrvLHP47o&Vvk(~u30yO+w21X4w=z1H;7(S$f433cu5+PKO-V>NAP;I z3%vQF{)*QiP2g*lXmchZIS=Ccs0WdvnJ;p7^s&{T1Mz6I^Jikozz!e~$rokohYas# zc*zHn8+M^R`H*?g-b8!HM@5s5wu&WonplqvdtKoKVI{LiYN2qJj?Ho3nf+BwdxKFP z*jIjCp~viCKD^2PHq57z$R@CE48zM&6NFrYxTF@?kdq`g$4JfM${5xM(}JE&Ory^$ zSz;O|OhNDsiPvJUk6=+J0r$loHi9j2Skk5CxEQ`Qf%~HGfxY_U2_@@g*DMG@;j*!8 zZJz99!t{CdszH@KJ9EK&MGDwDB>^ z9yKXUVZmDC!Jxq<_p*6b=#CYI6!tARF4AGb+>DQ) zPsAcm<;8=rT$H`jdPSLoyvX zTh#W4_*+&rPq?w{Z&RsxQP_F4E6 z))~}c!jlh||LnOC?GQodW7e!ZVH3Bx3R?feHIP4`v%Z1@Abl}<2kSMx!@L5cZB70} zCK#7}zP52N^!-!*f|egHfi_w?g6p8=oZ=tl3Cm5q{k#=|AE1rAG+uwvu*GZ7cfSst z{MZ=^H6$l|``&V}v-Fo|awDXC9o}Wn*B1VJ7K#5JoBX<+!a+~}I0z!nb$MT!;Do2! 
z{KG$CQH&UMo1{GM?3omfcA){}a~Bj@@V>KGCFE0QZwhX8BcBf=Sg%4+cA+ovTSLa_4yjSt{(hiChM6%A0eFFZ483}lmMZ2pKaf(tb zNfG?ytmx1~C5Fu_Iou5`^p#xRb--0(A+77ra;OV8BI`!k$T^v4+yW~VOmg2Nbc3o4 zpW}V$NQ6tES5LanA)M}acqh4_Rha|)>|2dCXYrl$9VRU)%!VoCX!(gF9b2Y;+T-7Z zDkxLD?EE)-Upi~34vti5T;Ed@T=*jlc9|@x`PA`fN=MW|zeO7ciRgR;!Ecb0+MdE$ z%x`VtAr-z|PIJtxUiX;nmv4wTkg0UIhD9R{mofV%87xVsF7%f}zG8I6de6@Wy6`Im z^5<{ZgKa`D^L8nD%BRz69C2uG6*(-oS&#**E^@_FVqD@1tZ;@_ynIk7J4G||)~oAU zeyGtAIWE=1Bo0%yNj8H~5X=|f9EiT^<%SWddlJyAd`Fgpp0{-Pd&8ibsfGVh&ZI^a z`ksmy?VTV7?ZBI>!N@@hP9Wr=F%m`;UqXM2I=THtqaWow%fIWWFOlmWYTNfT-pOXG zm5Zv9*H6?S`hzt;z;WS=Tw+2^-~u6uC)mCU=aI-^VkXbXaOSOH7cNmN4A=<6x6fXY zP!`HXLa4){KP;1*pT85CUU>B>LJoZX244~)ER@fljfMn1dwvS3hdDm&z4NSb@aVHo zCG;!cdYCojuBZHPzB7ltAoCY<<|Bv$44e2?Y+U--1+aG!Dd$uF7=t9#sQmZBx?Q5Z zNKSqHE0BJ^Lk^!oJNPa?t3hzZ3)(8?cM|Z(vwu{s75Vq6%GiSTx`TH2wQ0_LRuu2b zN3n1Qi|h&$tRLgg=SMT4Kf5Ew$|#q7+M2 z=^Kc2Q5z!ZmSTKsl37?UBUlkqas>2Rhh#Nx#gWWwI2```8bgBDE)skd;xMScI_wVZ z?B7np(8mz$lkgF=FCw&RNYyc^eZ$FH(3ic}r%CWPTH*7fDcBqdeImR-{O714hNMt$ zxD-aI>@{Q+j8xnT!?8deTRSvjIC00>FihS2(|Ms($vS_0V!2B4yFf^I#Q7J5y?{hL ztkM$oO4N8^T;oU>4yQUy8tU^wfymL&cOo1OLr(}3H~xOXB9l9z&o`~5d#Go{`o@>! zzeZ9}c`WRsFQhtrdHFM0Us*I(B!z=Ad>4b|#BO=8IJ{qnp1@C0T`LlCj`Skap0Ci9 z{vOzu(Zr~(hAe7UAmJ2rXf1;CpYO%;QuKQ-in2~bRCGv!NO4SjQllK zG&z{IRmn&?)RobfHY`%ag?8mNQRffK9Uchkq!S4hLg!FM!*r#;3-mU=+!LcNt4Dk{ zB+$v>VHNxo0v7|4OS)vo*+hqjIJ62wczH#h!<{^Lm; zElIfF!#{m#4|3nPE;(?z85WXUU`mGxfhTd;uy>mnOenb`M#qIFM}0{?M!uj(_L8Jr z6>XeJih5r^e~=u~5lBkm%NT0CBRDd88>Bv(KWb@99-Z)FM=^r&|JgjRH>`i$c zLO)Wz+_1Of;}~u4lEzm+lD^~B3SNytQ6ktphwGyLA>&}3MVrJt5nMXbDA8~%jVEBJ z(Io#iYH(OUvSCV8)WcYAg`>Nng;xqGoXjMmu9}o!Iw$52SmOF~EBln77Q@V*#FsHD zepnsE8>Sp6k^UlsXCz1kw(wgFdmWuf_z{L9J-I-Z>tQ+IZ$(>({#F<>{#InqDBDo2 z`e)dzw7el$oF-Z8r-%y6>_J7wrXN{2D!dx6{~qkSh_ z9HT8fBwuj69>dg$Cu6j1Mr;&q#I};ns`CRO)vU69vFQjg@}S{gjg3V9UTAeRuuc;_`@!dS<#Fl;q4=);Z$cf%Bh6QM6PZ;b@6 zBXHRBC+8{o7zW#if#g#~#}`ESLEaxalK!e(Njgv>g`u4gT|^b#sf(co$4@YF#qkb@ zdN|onrw+wnO=E9J^~(ceUuT=m|NJ>&4gIqcYYY~$m1%`Fvb(WA{V`Q;s?1=65RJWK8*?$YoJ|F*iS-%Sn{kK54 zn$OZIy}vi#>`Pfp^o(Moyd*0>{Wq8W`ezp789lG~=yxHDi>`gBdmdkw=qY~{mEeEh zrm{qDoRya(mC8Tw(?_x|QS|F`j=nB>gD->6!{;yAzei8_yrLNY{k|mr8UIHAJL^qR z`J&u?DWi8sb>nmM?~8I!%0;YhW|hO|s}xZVQH&3x-&wU2)k@S7{?7lKF8bnGIqOOP zCZ+b$tbEn))Y`tZ|67OA7y57e-#7VG(fz0<|2JKf|Ezoe`{pRG|6ds8;B%ezuEne+ M{geOu##wp&AM~s?9smFU literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/__pycache__/datasets.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/datasets.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bf77c62bd76024c62dd9f40fb7e14055177c28b GIT binary patch literal 10075 zcmd5?TXP#ncAgssg9|~3q0HPn*I zZ>|5?~#BcBwONQtz_h-}p?O^Y#+`&=o_h;fm}dsZA16L`;w zNil`@yf`kV@qS6nh*`W(+*ib$nEzZ6^WIeb~gymxV4sV3$IdZ)q+|Z`AS>b?G0()d2sL6ql@-~rd##ywHv~d7j3sC?04_pyK0L; zhc9SzRtjpAV8VhH8NoBrjLbl~t*VFbtY~kx8f{lN8?q6Nm)E@Vx>IWfURBnDZL~}i zcMn0}$*xhXcdFe4Qm=}SPK(n{6ZzFI${LSd%!5 zOU36_pxsq=)EyNop?Cl44vp^HwTkV=OMs%v)nRB#Qv`#Iz>DlM~ zK-OB-!m{0QgEc#7_j`yNJ@SGL*&1oL+HNqrrkTZ@(w6C476chE802y@kbz zhqt}0z-#%~10H-hORYc4%QVYbp5-U`p85sm8mfEh2awBWnoux5mG;4+?x}Yyy;fX| zbaad2++ZS3qrK_LUn}3dx3Jc3dJEk(w_RUb*lf#nzvGs@g>t*u*$6!8Jg*Tm3jsP< z;ssi`U8~-|wc>09wT8bi5P@Q6`%^yHnD#}D3Lc}$ZTa4oawx1h=lUUCdcG&H@Z8#O;yhq9sUs$XCYNkV?1U`4e(tR~BjE)2+y9vEM+^biHDd8k(C zOeIzhbZ?-4THB8dlBg(A z^F>WYM!j8Y6*QSbSCJ~P2s(mu}!55`DzKL+;Dy0^GDWt zNPK|f1R62N)5z3_<6KV-d4hAwAMbM*vpL|V25yr5Ar5dPYUP-tb3cHiA#ougY2TtG z@yK)c+cJm}tT`eb2M_BoM-*8e)Ct;7O>7N#I)NIljs+q+x&Wt%584(*omDM$UX^E& zFVkAlkHjd6=MEnKFOgK0Uc6OWsO+gdW$Rb8_kpsTpl{M2ca=}Vp3d88hDxaKX+6|q 
zi<6zJUro;4}ij%$eq4Mkro^u~6AxR$2c!pO^9tYb|8ix)_dBgXlBms|1T0&1q zI)`$G658R&*mN7~d@dr>?R20c&5uN;hEg%wn4zT2 zhPbv{&pjCw^q6n866v*O^)Tn@POSmWNH%C0&3u9yik4N=s=SKafwMzm#XK$|17oZ- zl%Bes4%Jc1Xf@gRq38LX`F(?{~G7t8FjeA>Hlv_60Fj*@zxUpklus7E{>9PHekBIH!)0M2E zgk5W)3J_w`-Cnkrmt!no+jMXiM7z4`HW~|CTU&%<`iSCC)%JH+b-8M^S=kL1E-hU? z+=Pl-zNj2o;ril@%S+CnuTgMAtc<{#NAw9T?3J;d+l+FG{m}`_`M`xu`#dao!(DB7 z7wvY3@S@uQBD!T)Hdpr&zscKYco^^NWwrs@=RJ+nuLR2Q0GUa)v~|F{eXDfz?L%ixncXIqVzr z2?k5k7nA>OT?j7{=@4w%)!K6p$S?5wD@|kG3V<n+cBWz zI-6DVfSXY>X`d%Ui#B)hit#4tY#%E_Ai<%A7wtF?=-_P?2E9N(g6hzS=UM)l*@La@ zV58IU$hBCx*FDBmE!GrOzu4+@%VcNim%~G<^SAN%Md(1Nci73FXg#Z!3e>u`tA{2a z(v`pvYG_VF7eRR@c1ScRlpgg`i2e0KAp2-k=(I0fOQlCco7$bmx`7kzd>6Px8hY%q+A$CW4ulw$E zoRSB$bL?B#ulU7cF*0cPx+m_FE^`AS{uT`zpzi5=?F}5h#rR0_8TB_-E9~j{kg;XE zdpiAiqtjtPao9(J+a$eZ`y2IQ_w;}(sD}N{An;AE2~7f~Ua-ZRXACw1UQBNb;tX8T zfoBy&TB{QobVwK$M0uKE*CC)QJ%CyvPeXE|?4ylVP-}Xztz?9A4CS&w$&V;uTzH0P zQ^O-0<>Q0AuOK3mhVo>Tp*I{-2kb@$ffrP{t6sxrL?vzN^ED)eTx9rOqY@8wkBaY8 z@-8J0kVH8?N9^s;>59ztO{gV*pCe`vNS`(>Ht>X!4jXvO0Xwz8an@1kFCkI%EWk)s z%^PGICkRNXvznrFOHL_tONR-b*6ri+NuD`>#W`(m1O0wV=VPXig76zu8zu}zvY zda<@!DWrJsB*70!x+xYS+SEvg&o74?4d4jY67nGty-&#%N_YnutR2EP8OH%%q;nnT zD^U96>?oFr9n5MuK;X2R*6`@vnS*;ZY&Ar1bbd5Eyj=wK2pZR+VPcO!Xt6>eTmBhct6& zM2liah-}%Bn(n@A7W-OiS%BUOu2pGJLhodG#G@t3?cfF6_jXr))2-@;2aQnl@ zA3VC_P#ALM-V?TEWpW$oM}9r!8DsnbGHl@y4-5!hAR@R2>6?S?hrKq~z9-n$n{4Hi zkkX_K0}B{K?PH8#yRR<5FHe&d7x?=bdO;sR8PwlMNt&$nh+~>qiB!pYr8_sUldq7m z9kq|~$<4$y8`ef-Y~(FUhFS~mJW0yFT*c%>86{u_ zYpuu{1Y)DH=We40OXoSh=ZPqVPsi^xYJps#J}y&oNXSS6{uWQH#0*oN*3I}2#wZ7o zgP`qi%+Usic82xI7s!w{`3{4#6zID^U3F~vp7zzro)H?f%6-EG1lZ5(SZgC*>%fGe zf$`b3_7stL30Q9ySZ^veWGUdllnkM(QbS{u9vUO-|Bo?}t5}%W{*VN(y!daVRXq6p z(_1gd0$`tc_lFxV=}bvMNdXBqo1F0|U6Jjk1KH_pBW?qT6T|{Rr}H~(*iqX~Xa4W; z#M_=?+hydj<%g`jZM(>R@y?sb(9XwpOO+nJk2YDXY5t9drfIXeIxrBM zt3q-GI^Z5FfjF#UvSO1zMHA^za!7UUi}k+NazXbucsPh+anvS#+&U?#{0C?(CK zXE1?lBlxnQ!6JZa`HwH~$4>>dwu zA~^@sk4rEs%&nX9A3{LNdKTr@oN^s_#mTlNhGlSi;<9jTcba&A_Q8irFta;Lz444Q z$N?U#8*s7yj;_=`SAw}<9;I2ha#?cOI1E3&bs=~uF2(Fl0F~x@$HMXOSVDnf$X5YA zbmgY~Mmp>ZgM<;lz#_ zoW&d5!CXL<;CqCZd4;a}zJw~;!FZVKDUA0VL0YA5it%;rYfYY~ws=mlsmKI2W=jYg zZeCoF2yUF#hh{^ctYqf>lubb%quY*_7<7-q;0cUm_ zOLAtJhx+-0e8jbfc@c1F>&aP#=KX#AwEi4)1(urMP{J^S|N0&o zW9cHEW9j-@mltEnLRq$zMjjGH`GGvqI_#bk(?_iuhcahYy6+C)VhgdZ3j4pf?JOhc zy3s+@ylDSkd&4dxRE2=LJusI);#)k}>p=mcHnGN8!0DCiNXTP^uYHNGth`paO-}m0 zZ7yYSHQ_n2Q#nM~q}Figp$egS@<{mN$m-z`k6aP49qwNT$qS%5Y?$X!_7e2#;2!8) zU)u~#@KJX`@-4Kw2gUm_-jOwc1F@;8|LfZ?h_M$ol7bOWm1Ze@N~kI3X}ECS`?l%d z-zpfj{A&Y>fXFB*T-f*aC0z|2h9?y&t5HHdmGaJ^+hO?*)eUW3lx8I3P-u!TH=@ic z(7zK0E#xH-!)M#}qijET2tT^gj?7N2yx#ES9qJ5R!bp}IZCsh$rhEoB1aJp%%Zp1- zn@aSiyA`E87*7hJcu~^zDdM{wnQ#&ic%(3ZCrTNDO>kOkW!#Mz#8YIJ*V_1u0X&&n z!u0-{+wuCbhz<>Viz*Uvu_?TUhcQwJ`XLgIP70GFNRAhXG2eYO;#HDsl;cHXBtVd@ zWLCUZAh#L1MF!oPMjB?G;L_3pcPy44Bxj~2lNZz6V z1CXBy^KYqBa9?qw14lMbUmAiES`I;q6R4SmQ=0`{T73^CIfN=~bpov!IH%Yn2XT*zNSON)A72fn2+Xkv z?nfhVFd6V43PkMV8+OP@x!fb|Kn|IboHjV(5F@wA{e?anwULJ6wI-BuWW<3^=J&w2 zj|X}Q@NIm5v_5ouva%ApZ>(XS6#g&JtqlhL1CLMhw&;I!@Rwe(j!UX3;~D;FHgDtc z7m*CzO};z=WYYoH3|tlg;#9S&it8d=A8`&~j$#dw)t|s(H(#z*T>1K?#l_b<+re79 z)xEwJ1fAuDh0V>)V%aYOg5y7H+C@)nEd2hutaTR-w}_lCG}Rg!8b*_W;@UV1}Z_b!Xd)ymS9rK{J=Lg32h zja6hTHw6A22a2U+R~Eg>wbez}yS8+B>00HocfC>(6?gU8^-I^?a-~dhFbWBbAm>ed zMmqj!M^4aC$)MVGn&PT_PGsc5jp&YjF#StZIJED)BXIxuPfHt;Y6`E)r_^!^NtA5+ zMf%qVxkX<(u7_fAPq5p<+Bpuv(!noHjodO8DNIZM&;+M;kTV1dqUlj8Sc~}H%4f^b z&zI&$X$I|!{9i4S0;AF-e)3Ee7xQ>gA|_#Tlq5P1JWo>YEG0!s=(;$*|D?!Cl=i}Hg#f7!JZjD(< RYXVQovL*ykO)TSS{11xvX<+~W literal 0 HcmV?d00001 diff --git 
a/text2image/BigGAN_utils/__pycache__/layers.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/layers.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f6bf9a766e9feb53d17c5bf29f39466d6e097de GIT binary patch literal 12676 zcmcgy+mj>LS?`-#t;<|?cD$Fd-2<4BuEaeABnO z=Bg!Ed)1b!(<`hNOg^)_PS0I+rEYbLJ#W=Z>!n`7GEAdaUabhlLCfmujMNL=+1}jh zJnF7ryk)E&@jbtU>w^EdU-f5hnX5;iG5lG7?v~-t1%<7}U3>K&Q0Dz3pd1N~;r=-8 z7yP5RKZ+J7g5$f+>PgTR{d+*W2ec)j-3!_=|2SyJLAy_AZv*Xwe-gBlY47*@@AU8W z@57hh?l1fI`)^0BhE_HI0knD`ZS{cvp#Khx^Fcg0<)6lr)9I6Ufbx+4FeneFlvAKQ z;@3f`rX=a)X!``U3;!|Gf+{vhU3y^C#tP?x3+6mE`R zycxDPRX^-(2mXat+}>>TRqw3dw)h$jUR|6|;fzk9h>bmSWR8qIYv0(nM&@EhF*W60 z1A{jkbvr4M-`YHLy|1oBgH}5@ z)9&{MLyWC?trK)4Fio4t-7-vdg-g)+>txnQbU)Gqzt$PeH7=`a7r)ep2CKEc7c^JN*%KcTWA zKO5`3Bcb)4lJT(u;d_2@B|=*TTAuWweMnI~YXuQc?=SVJ}cE@RhtsiebNr z#q$%l*NS_?Zep$^r4Ybtb%8@!J5`j1(bZuPYzOsnVqQ+HVKZ@u1E2kn;UGkbbuCKl z2!+{6%wb|)JIJQFLH|1N)WJ)w7?8RTZ$?Wf3~Sz8u*zoDTC~d6JWJCtm#oJvRm0OZ zn`8%}bslH5h$1jnO=v;Uevo_H=X{;S>-a$!cjB8l`1lqyrsg&wFfs*TGDJv(69wJ% zq|gZg`GdnQLY}X;l>Y=?OOd@jpM7CDec0e|6Vm3Hb%xPU)U+?VF@tJw-5+MU8EHc-p^1mTWuGwbMSZ z=>eJmhl+qo)QF@~tXg3-=tn8c3b@?}k~tM@bRz1ewc+}Dpfr~G&<_)*9|lpqs2)V$ z>Twp2p-3w1vpn6?dEiX1cm#*0^VP`k;mxxO;;V56CC~RQqpb|~K;)4Olj9u$X z)>VgSl{JGb7dL{q6~_v=wU*x{Q>dGIzWF46R6f_m5>bgi*Ch!`gP(TRxfphZWFRG1Yr82#okAIbyOf8)ZI! zMJc(2Gn@EH3^ls;znytvMb|!-^@fQTq_=$LlpPu-k)N-k0vVxoW(noA4-`Q8^WIjc zM&_QgZ|s;imSf1&mMbQvwqwOb%G(yq$`%a9R@t|2JD;$3Y?`C1@1jlIDyd(OD|^*_ z8;`5u8Zc#sap%8qU@=dk zE<#bCOWf;0XJa!?oNJxnIRuKii)?vNAg#;I%OH!&H_T(EYkII4HOmzXUo(Xrmz#~c zwW8{1h%cRHHxIFRn8i<^NIdu_gFwYMWyCTj^(3E;<$sKyNATbr&WODA%<^2-triri zZJhbRr{{4-T!dUJmqiytD{DQx>i8uHa3%3BHlFH-*UtLtQ8bs;pKwj+F7e=%;wES! z(qm|kJ!oX$0~(fw9p=3_^7cz`MeHp;DUHf_R)PCL$DcZuW?!?@NvtqT-1N&_3YkZ} zD*Q+)ov_*7Y=t4b)T9a(Ilil02^7dp@;`B-Si$Ti#X-yW!RW;EJ6$p4iMyft!$Fid zYn@g!R=OGrvfRlbaK0kn!@HA%Z?&-3W6LLiPI1)`EUs67^E?r z<9k?_G9ma}Ck$Fjk8bP>P_lXO%2UwCX-uPU-^npY99kpyaLlO-cr|hLw?zIl=(J%L z(5(lpc&LIDblIH}Cgq0v4KXxC|C%S(f*?qL>Tm?*fd?^h5JCI&8i5!mMNC)LXN;J) zLdQ*xgg7+>8|gT5Y&erRpe_D!aqv9eiE)a=AbvN?wjyl`1(cEVY;-^oF@h%wbOi7C zbU-QzJea3@Yk|md;||BZB48cE;s2_P%r{a-8W^Sg^Ap%e71AtSwICZ_l+5Jyrr+r$ zl~n&UJN}Kt>$Pq)8L;TEr5WEG1j#}K?rxvqk%&|ZUr6cp)Gh!cIthLqLJ^aXG))-7>;E#JOXg6S)! 
z#<1{`;TWqfmcdI(ABtm!X#FrL!Ijn3cGe@WR6#a-{{s{;BJL49VJ0>ZT%g{xB}Ad- z=SeGoFE1HBzDmt^o8by=!W%zD@G|k#J@77FDS6k&6X|`jXD=?r<-Lkuglh-My3KH@ zDrUFbxFc4Cno|t?3Ru3v8!L{sodhh_Q6vr{ zQGB!1Y__|tC~7vf#S>#krzG4$~D_Af!y_3FEgxZ`etx&J?vUvJp9=^_J_~Sa@@*NPktX&Pij@*oQ|va^Ad+ zUsz&0*r&W5pmBF!X*(uqz|g%izFmR2M65b;_AT+*8R1st8GNqn>Av-5{q8%MGvPBc zDlX!y$LWm~wRo!cVU^SxiaT7lJ4Cw3YkgJ`5n4HvsniDjeiwerCaE13?`J_i0v($z zsJUX)UyspL^W9$S^|X(^#fCi;smHC|8uCVeBkvOa8bG9^eVLD?$i1G|K#J(Y8_35T zGmlxP%ty>etoy9@!QPIckf|Ujx8gwu=R9843jP#o$xMp5cF^r+xZ^kBsrK1YKK&wU zzlu|gyJs#`L=1lLADiF*+_(SppPy_qUgkl-as8gc8C^gj>K2KL;|3VuBG56el+^uj zmF-0X_efLVS{yk-e=i(p2kAyIiu6(OM4EE6vTQ8FX5bq^Wrv+9#z?=z+ zEy*^h8Mg6h=}KaI+(uG~kn0EH1h3`vhOk|1wcA5vcqL$i4H<^WpNZ;Osh5?H;i;Ec z&;r1#!j?c74}04CKEnsZXi`8PpdYDsvwk%1F=$3Ha&L$UQ*)D4>B40}rS9Z5-pS;A ztY`SSzsET?IFe{vK<0kI+MeMO=e8zKw=ipwOFnsr6eLm`S;k!MkQdhZbGftfJ4(5w zp$o(OGG10+VDT#`rm{%+7?TPNa}aZmar7@}mhLV9SMc{`gcA*N_EU~?Sw_%AC&EzQ zfFmPOd^#-ab^$pa*>J#fTdkpl-P(c@V=Wst>X6Hn>W--n0$$&K0#S0a7(Od+&}s2b zz)~_}yX7q!clYFv+t3Nh+Kl@}x{-bDNe~Hl%Gt*VL2;)jeAib0Cf_OE3oQjLg7Ki{ z#%0Ws_smxqTh=bzz7(RD;Op*k#i>0vaN_v)4{E2C*0Eku%)5xRYUI4aTXHIKu;b7g z^Ggv^+AdO$W=UUaY&BLZHc}d*feuq05`F&=FG$>Ez}s1Z&jVv;&fBNVQ>J4Rh)gt< z&F!i%JF|C|`W!#>dHv)p`|GR;qF7{NAm;8R5dIEl8r{v#dj*li zeD8ut&(Z!OZU7zgimTqk{i!{s7O~HJ5{ZAL5G~N?3By&|1Lec13cMl;oQj(YTy{d3fs=iZu>!9qPP@Gp-mSn_5^F7N(-iaIHP_;=Y85h{etxV~=s2dvOp@3Q7|O7-fm6+Gd%1h|LN(R#Jnw3Xy^M;PhzdZ}XiQ z5$dvfIi4Lq-Q~At!ez9q_%^uu9hmny--+k(9fo0h&lloZLCc2(!!%{>|6#ygc7Omx4 zsOZxm1%<8d3k^88{HQiNxrtFVZD z1YROy9Q<)&s)mXo4nY;0joF?KuGrOak=o;i@%9X1F*Z6h-#iEccKH;}=wDG#gCoL$ z_#J1iAHk^>MMgN#mtYb+Ej&Hs8TGFdQC}C-$5GNbMO5QEx8ZN$9wGA;3!a&`R}!|D zL90;f<68BpWoP!iokI9AXHGQaPqFwY3t^Qo zO#ncqrK)3%e|!+tF~*@b2jq4#gMVAV(`W~}3H8CMY*&kFsk<65qW>wk6XX0f-cBnV z(adKF3#QBrt7hJ3RiX9?f_Z@g%UPTbRKI{CRsFdmV&OgFLol&~if>|<2f~Z3e_XeW zb#l!7qTcn}J`MSKf9;bKQnd4UtrLPYmZhD?x9_h_iOvT;Id!{(Z9f+S;T{ThD*a9y zJC$kvC^4f0lT6C^+lFdU4AlzfN*EiM5wclxGNRn&Bh{bJss6z13{JIa!sji zwFk;rKp(sNZwP-Gm^(?fk14|AaTbauGM|)iin!SX6U2%;$0f9J~LrnMDK|*lQonDr_ZCl z&f;YjUtn>Lg*c~NH^|k4BpjxfI9)BW71mCn*x(9E`W;(i420&n!BNJFWuB*_zRAYl zLb0;Kq_TT3?=qIsHxW<{|M^TOh=86D3>_|>??UTLJOg{n37C#B*CCSln*-7k5LB0$ zSN_>V#FYC@ZsI<30e?2Q(r6d)J`Y}Vxug-eOD?A!Do5Nax=yah95NfIgQ0X-_4Xg~ zb(B#}A$yU%1BEe&b;Ta|@&?_kiuz@gga|=|L5rjUU|I*@9Y{3lu-d3&Cmp3X1rID2 zcMDf=k%Vr2PRrgOksz!7Vb-2zv5q3~w9m2@CgoN;?p)IwrXF@nTfG7EMoj0vTVKv5 z)h>4C!qysgFQRr*Kc+Dk` zM*dAD&r$RkosjZGyT^d$pbRpIq1-`EL=p?B+-VBsQ%hn+OJa0wR8-r(8NyS>n^f#o z^(spGzmP^wONTvauyRT7O3FUUl+?{?se2uhI3RVabDJF>fBze2pF|921}YfKpMWVT z`~B;o7DElE14z>US5eclmN?h(CkXXd{IW>gU!!)o#C?UlrdOy#4#G@+OeyGpvA_q@ zY#Pg(#r8OOwQ-8PJ%4k9 zIeHGc;3%2M8~+g^@uJPXird5Zn`6=8d?+#qrZUmVW|Xzg^B9~B(JeD-WHVZjZ6J=_ z0@P5-$f-|Q1NfjFaHsE~AI^g!3L4MZLUN^JEg^iNSSOAOy2A;+=FC#x!(&9jW1IY! 
zARMM6Ugtzp;Zc-lCB`#BBow~w4W+TMzO$pN&;3=8~^)(jsSo9x; zYp5me^YRB+`5$ZI_Sd$8b}at@(OUC7TmKA;=UMzbi;FB?Wbr8$pJgH5DC0Z{9JI0( zTPvbJ_MW6h7b@}2>A&7S$w5bFQ83tnfs@|?4V~i^ykdE2!JU6>ab~f)IKS|=@*z) literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/__pycache__/utils.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef37c3341c8da38cfc050b296a5e8e343d97b813 GIT binary patch literal 35369 zcmdUY3z%HTRc7DY&*|wIjimADVOhBzOSUGKMo%lYEL*R!MzZCJY{`#w;C9Z7TP0fS89koXY-Ap{df5FQ~}2oM$sU;<&oV-p@(*bSHLE(G=qBJ2|KWp~-^ zF6{rGx_$f3XxesQzx}?Q`MSF5)Vu1`sdG-9sy^P=mrvmD*7y9?g?GP_Nc;r@?Y{y% zNAYpZWFlcFiV54W&6+W17EOcxWGz`t%6F=mlJ9geBj4F#R=#t^oP75bd*nM`%;P&% z>zykU3t<_3#XgDaFRqdAf#QICubo?0ToT*T;!PoBs5m64Hq31-ZWLI$ zwrOs2aWniGJNs;+c(a|ed+@!*Su=CXiqBSr`meSW?TQaW41f%W^T7{wzoWQi0_uj zCop^K%gN#$;|co>w)JenwwxVTlf@B)Z?%UJ9+vPM5x&jdhVV9LVCK%5oin?TYSf;y z@340~YZQ0eBla6H0(9=Mp2A z{7+erdDU9gt2(YVS*Bl0)0SOTPQ|M&Smnt{cq~t$3Z%2# z>Bf8wXi@d19c4{5tZLn=l@}c4?z8~1<{GwB+nGnjlBG0RcSoBGqxtJx+apHM-&EDQ zS8BI$f}mq+YJSeCduPkabrc}^F*NNATIQ$<^YPEbCyUPseB3J;&;?^L;U!uL+kDQB?Y!L^K?+_{)>kLZ8d+cKu(k%DPlx$#l6*t4d>aI^ zF$UQrkj*j3&8UBieM_wVt$<)9#UQr|WH<)74K3PcZ;wH47swqk$PPfTrecsg1+p^+ z*@b?Lf=cX;^<$5;ac``R`;c$HeISVvxJNlzk6YTW5^!m2wWp%DGP<_s1Yd z1adS6IflMs1;)x56UYNG$Vou3He-;}0(me7c?gg*Sf`!#Oh_-@9D|%inn$p5J88}d zh^nL zkmm*RLJaaE=JX}|y)nplOUdtxmHa)(^}Y7{V`;t*kPq13AA@{Qa{WLo*UKpR2kj5V z()=BK|Tt|$LxO*gZvaAKWP7S4DxY- z{7ek;vp2Nx=LGVJSet(yIQB2?Ux?-UMM?8Zu{57V-}l?U982>n$n~rCD>2Bg0rDyP z*JF_1kTjo;L4FhEjM)Dw2KkIYek%s~ZGrqw4D!1I`PVVXXHoLo?a##^zbBC2k3s$b zy;x`eVXPOQ$N2mk`;TI2{#Zur-^L*SPWt|ZSgt=on!x_kSeh?NntvZl^B)BAA7hX| zL!1BH{!%R0U!a`7d>*`K%fN2@pCJFV0E?mD>BI3n9ocqSu5}5!Flv=noj!Nw@kdWsJ8WmNJYVw;S=-y*5qD%H z8Kk_0rW0ht=pZ%i)S7H=d47spagcYthH^@tns-k6hmL8ebs5{frOKd3b5=z)nyv-E zHCd~|*Dn8tsxoC^M^p|wx9Wl8_wCwMpPwsTZm3x_`8p%~M<(XyE`X&#=Q(8LOArc^ z+)FOUtr9Ztk`1Xju4Zpejh(!!OV$1wRMRMd^C`_aa}Bjn?A4u`EZ1Bo$af;n`$x`# zk72#wXE3ZfJw{eF=BL07MGZRArLK{b>P0Vt6lULx^%NuOPR~zH)*S!HIqi}8lB#wf zX`?;`E(^?7V-DjE9&w~w6&y8=N=f&}-%AF`BKECpch?d&?%D-}Da`>Ps!p>}nJ&5I zxn|8#{_)2`M3T5&5^05zf={nK1b%kcF5&)y%<<~f_^}Cp_e6Ql;eewiGt<$4Rq9pXK?+wq;2+XREE$~CLrP;=Ifa}y&L zBF;TN2LIz@5}0VzU#-qDUFYkmZ|t94Uxt0GOkY1rq~CpFx?Hb2HA{xQS*rrUrCZ%b z2)W0QyK6>tdes%Lht_tQ-t={YH6Q9MH|4k&YK_XQYc)~5>y@V*a`od~(FoHtR+6$_ zM=@hvi^;JVZ@OA9X2wdFvAF%iZ7e)Du?q-fRl?C}mOUVWCT!iy%X4Ji>9TTce;-TI ze#-$}u_@%ZuimW!*D7a9|w?4(011?DF+x?Df|0R#Dms_vR9dw#U@zN1>sQr z+`@&re?P0ATZrY_p+!u0zS(Rj4+HFB6A{_cD|_ydk!4$zFYBdV3sTCd&HHt#MUhAH->yiO}j+cOW^`d$3fdAwZ(`65Y+HKSzXj8fl$q;d~F4QU9ddICD z^PuF)tzZYWT!#eB17Yum3bGDZ&o@yA8vRYx0o>+{m4tu>Cg$bx1j4{zU)DUyYDiB> zD?zg_)UCHhGyT*~qRd;NWx!lj92Nyo4Iq*I!|)WsQ`j+^67JSrOsMN%i9Y$L<|7yg z$MW7#-QYrN0Y(EF2d~N|>iH(C5pBi5m#%9|CXShdG!SgM;g6n1V3p0ITN7nQSb)E4 z^_*i}kiHjvo2@pnMI-Oy=ipt@35=udU4XPmMU?F~v!vr+Tt_$5FN{^B_7IwcW1^poCn=7sF*5Ht#9|L`H zn(lS#>iqqk!pE@o02lsZT>RY%3L851O*VAw^){6JI5st}H0F;zI$y_#d6%md2QrQ5 zGd(aI6g4VXN5FHg&T73mk6o!-gGpE+$_T^kUbort1D%d*Fh^Cip|b3|tZHyC+$d@_ zr{?cIPj#!7wnEMjlX3R3ZbFtX6mPP*_n?0;mUX&Zo7~leY*7`l@Kt#_A$hN(gzoIu zda`qK)hpniqB{LYyElhYq29?s?jiL~Huvwdg!vaf?Nuup|G!vKcRqI=#NqaOae&(3 zPW=0hMG$Ma&AAN6xEo2<6Z&<7;^a#`6JmGAu4$5|L*qPMqpsQEhgk=^a~kv9{T46Kiq26uDUvOKWB zUHeErx_rq&U5;9){sENs{^&UXP!RyulU5bXyrw|i+PXFzCKgj^G;!+TkDh8=w&u$9 z1*}g3xTsjoYZP$-&21NPEGhyBda2r&2ja0em1(Ck+iXDcf6Y?SH{o7O%vdx}pYv;4 znFBFZRV(dH25qw`w_>++^#0y%ZQZ$S4Kwp_tX8OfG}CqZ9A(#&r@InzEK{g7CaaKe zPLyjEQgX*)u6k`@CnR8xnO|*YIDO&FJhz@RbDksc`e$ZH&rB6uV3SO5sX7OF(XF2S zGzxpLgSr}9WIUnT00|}DhIyb~;b7sciRz^8^l-ZDPS>%u`R6D@lr4ta>ICNsgjEV+ 
zp2Z|699sE1ns#^u%)hin&0GfuSHNb40=Ig?Q4H zfMCraF?2tb^1{^CAOEXUqyGp7d;0?q9|OGujC@pAiVWTh+gFKIN9soAls%u+5fKastp`LA7CRQCo+HOb%3VnP+3i`Y0zj*HF|NbxU_wVUiknr^s#9yVfb5fdr zxN9Nj+6qt>)MXK2;>f0sm(D%z`ujR$8bOX4xYNokc>(Ay?A1;T{IGzr*D9dv&P8Q> z*n(2kQ?9=&-U{#pnEqF5K$xEmc(wdp>lfX!RtMbkaKo-nRzYI5)UBK;>=h1(K!u`? zG^5Kfi0V_9a_F^7>Dm~ONe`wNHK#=R%JrW-FCmB{RB~V;4lC#pVT0jHf?;TpJmNFW zGl1vZ=P_l_GH4W^JJNOg8E7*sk{;$5#UxepJHQI<6VO6qUR{9pvD?I>sdDPj-en$& zQAw@HYp*Tf5mQ<_oKest@;Fe7f>~SG)0Gxtv7nDI{_yX3RKBBbbzSyYs@#Bh5a6e` z?;qXo_#YM3+_JL)878EowM$fHo9{UCX5qG=FTHK_w#nV{1*$cP)_c1_Xkp3XT=7PB zHmj9c9#%9Lpqr^%Iykbk-t?^^eGO=@7&S|)774AS#1c+en&jRf{XA4#a&k?kAfyQdp-&TfCY+?bi(&4Wa@ zS*|#{Lv~b^E{RwSf#!<0Pq^6K`rN8SLn?Q-xUgNK$r5$A4o`_t{w*ln-9{CMkuvhh zl#wxV+C>r`~_Up#dG~~vVx=an%wP=4jHm4!H-&8w* zCl#J0@sHoIT+=pe-vdu8F_VDxpJ^v)wI%L+c^mCGO^k z9k#RtC_9%31yH%Er47+D&Eh6P~9UrKD!%2a`Nxf34%&SX| zt2Sfm%5TA{8Hbrn7xh7zHQ>Y=qC@*2)ID;(Z5dOg#Ez}7Itf!&YO|CliXK3rk=Ein z%q~KhFfOY$#o9Bvue4`$Z)xx79vP%~VXPbRJFRv}1R$X_rW=MAV|@k z3NGxaf*hb?Kn35y1L9L8%COKvw2n!_<|;_R*>l}Olv+y%>axUj((sYIx|4w2bav9& zLuVHqmZV1M?4`pFsQqw?xzkjXO*o$4clyjZ*k<3_+9x zC7G&Tf{1eV^chQf45f^KAxKYI4$_Y1EejUk?jd~hPalN2karv)8MQVbi6e)sF=!R3 z27^T$)Q375mJq9ec|vp4Iz6d@)PU5(!dM`(;}fX^qmpaYSupd{Fx42r%CuY5v%IP3Y&u&8e3c;!^b~QI8PTISfbGkfvj^GRu0@@#a-M8Y8PSo`)r^wQ@D)N+r;rTB-C&E<)*f z*1!42X&-H;TgKraj~*S`^^wDwq1B`@0ZzRK8Mw=qN&)H+&nZ~cFLX1l5L9^5J$8B z&(;&y5)nn#(+A+eoetdPxN5Y&C(T8(1$wke-{`>itf}5;8}%)oIg^~h{V%!c!FWlv zkvBrj#Uw(>Fb}?w=j}J-VZ5XY^Q64=OlB$DGFr)2Dy&oDucd*(ssFcYS}f48W|FnW zW!UH9_(841@QmctVSK6k-~?v98Dunv6=eAh+NTaN`d&E26c;_jXjxxJqFhEbFlJ(y zFGznB(VrvI^YUFHJ7^9WTMd5vn{UM2Z+79UV<=Nr=jy9n*5ViA3w_l(moIcR#zG&U zbCS+cI$am_IATAuY*DXKge_atNdS)GY zVH(vwK%?4op-t^qw8yoan!|%C$Ou9LLM4CO<`XO$bGdoFjz!B65q>#3a-pKb&bVU? zheuEZV&5P*G#%GWlSS>m>>QRw!7Ma9*ukyTI%IW3DW2Sl%D2|QzDJ7`*-^s+!xDr~ zrk2%e7oatL(ZMZQ(+duiYwdakS%pKYa6$8r-zj`#VNbw=WtAu;Jh(H7tMp@8!%bqb zB}yqQxI`)KWe{t6*{h~HM4xTU8tOqWXPZks3||9b_zo~A81v+Ww$VJr{_0Vnv8G&4 z!**H*b#x)ffiuP+I8!ve(K6j+LDz@ zUm@xaAVmVqZ~-6uV~J)AfAbBPht;VwMcXpxftwle+x;LMUYsM5eOJOWz2s8rIfm1N zbU6dr!NkUdmwOKLbaR4ZzG%ENF=R-g(uw*=khP~Ikm^!Bije0D@Q-hig^W zqcZURe`ph)49>$>9ZfP&rP0GY8m_tb!_&^HQSjz7PJIfeaqVRpvaDcCx7#CY)FY@* z$z%`H_aK}gRdK4dm6LxPW723ou(e2k5=YTN#%r*?AT3uXgA`XG?i&bEC-wwM)E8uQ zrP4zX*|#G?v7>@Q(*FO}L7}|aZ){Etn^vmd80dyVk|~-CWES+1<-7%fHVR?FO+g$i z=d9EVLn z^*1GemslXJHbR5AVOjJ;w}=0!shHvn0IacjA$w2-%IEUvS$@MbfrXeSoe*Pm;d z>O<55a%td>YYs8*@KQ7B@HVUIrIs?+5>w{l`is8_a+yFl*IIkc2;oVTaXa!CV);KY z#gYfSo=$4!M~-}pZtpdgQfRGd4O}y)%nevyi|bqclO){hTN%6WLigrD5K-SxLVCYONP%Z;ix96g0*8{6F} zOb`D4>`1Y{3PbEGu$KYK(2%5ey94rNX*v+-JN_OGxx;c6?3a>;dA|;y0T_ecq{ZB#NauH#P3wmsDe6_};0zBL zMOG?iPk^*GT&I{j-#Cx^5)D=C;Q_TaIp?xd9L^vEy%qFkBb#s;O>;GPvH_mZek2@F zy*WzXJZ5dunWe*FP%a(WL`i9iL$ovK3~T$AQ;?a3Io*_7?9(<}QGHaYLF-kn7jxou zJ-e8PCT$9~A?Q%CKTKH)^RLl|snN6+Im4J#3}KMv48xY8*rz>V9;Mi@f=;JnV)RG| z^}rNofL#xZ=p!K{=ak3}7zHzDm7S=wtWKj0V4@qkYw2E+Hzd+$ZtdKtFg;gG8Z1mo-v9HbB zW!d5ssU(#QxcXqIZ%1HdhjWZ8sru9zc9i@{0?jKJ>O0_z8i=@#!sR#uAcqk4KzbnP zX}=~vSGL*7vBXnb7V*%BhPRAKbBO8FB+7>q7q>ov5C}&BD`SoI> zLi0u%sYyonJja2jEeTE;#$t=1z7wI<$6^dISMImLhdp(WERm?8K&&@mlH3d~v0`)> z(xV_eR$>Vwqf;1(iUDrx;d1TIxt z+x11v?Jh}XIxrh<$L zyqGGyjSr6&RYW}DzC}@m;t4q%4B<5(Q3-1S1FroUFKdHI#)`=;^;}E5xW`Cb*@5^p zre^CRQ)XJ}O^F#Yz<_E$SJK4?pGa^L9!Wf!cv4S>IpH5X4VGUnzHp>hm`G`5VJH$p zRgP^9jG7F7A;@SrSVqt}5TmqMWeH{Ejb0oLgd|?43j;c_ruW2Ly{xk^G^KTlpnQ;~ zAx8;IUf9l;nCiWBzMGCv-c{yQkpY|LR1U&J$%Rr&#st*@f^(_+1m(K!MA5FB{AlwFy==ANy&2HOr={8M?{|!`CFP(P3agc z5OZTv@VXUJS1XK;+p19XJd{eJp7%SwD) zTj~WkQ>=yKijiNVA^HZ<<|7rZCH_&^>cxj)_Y-iq8DMaK7=!!4R$|G#mca170I4zg 
z9ESQT*xy-9u9p_+5K`W=xt2+}@G4lti-XT4mNHCdr#68GUxQMP?DA5UA1tfnQtn#f znu+7?#7s|%syw8_cizrEFB8*BUc-bfLOPx4O(kXuKmf44Eb*E#2!`8ClnNIIZGHC@V-s%M)}rO{@M|kb`-Q2%|yI1)KrRisWcqRv$(N-8)150R7Sv z%GB>AVOLrf|e6G3)q@DawCP#$g3X&K#Qs$qHl@LH9G91`Vl(JqeqYn97OO( z3HUKOKTbym^e5H)7?FE07-~1!eH7nOo7gtSNSnBh*doyhVI6VRU!(pexmT=5l@5iZb)M+OOabG z#fnmTR!M_*w8D1em&Za6keKa(YdVJaa$X4;ymMth{XEmv>8w~$ zV487tDeUdXiA_w9fZB}`S;|-k9w={YL}+95oe@9evri&rg;>TvZpRVbJ-!m-Fn0== z2HZ-wdSbYRx6oQWp;iWbk`Ck}|WjJZ&70u*{JT zN_dp|MV!4YPZ`Iic&SXeveY8Pb~yfRvOp4zqY5a+%GajuPz z_qQ@DT2hE%W7HDC$-W8o(->C&mNVg%4#+TV%oNp&2s8fXcJwN79o$qOM^S=1+}cT9 zpi}u-!UUDbMQMB@vOrNO7)opm)=M`W<5C9^@rvYOJI^>Sl1~G^N=p-%-fS8#@GhHyi2|M|y(3hqaXvte>InJ{j>DEGODzjv_#=zr-0& zlnkEuQiPoW$F2_J8!m9!00$v3k-<#TQlZnQK2y#I2N`|dk3>cFu#*t^=Q(ue5F-5GT)YxzMYnP8B(v27=^|8Zol(kYa`W#`M zNfMuMz`uqF-dAr+Mma^u>X^-35oYHjMFa7gR#LFo6)UM8jBam`dw3y4?Qh66g~UbURovua@1XZDB0+fL>r?&n|T~&eHTre>%vr@N92TzNGJ)e#*lua z=kOFB*SOMD_pj$b0fjKJ{%|xrnxMou({^(I4Pzk%g1uS%t_O}Vz{-N}=|Zj8vl4bD zB!dOwia8f1L`cyOp(W}M=tOi&{RV>Qj1aBWHU%t0lD~fTq#tP$u>8nCr95cl&5ee? zW<{@}DNiFQnG+#aOYlHTU}6LRZjvEPPhvK!4&d8R+bAMJo+s;oftmr+JY|RkxA z00ugz@DQNM>AgsD4uS^Q+W>=}fiN;HfMW@e(Xlficf*S!I08X5GN1+?#R=1lJRQ|q zr0dL0;GT%w>+x?rtMB*l=mrMI5SwX!8$L@UCpfT|D|`M9xzy8MKw_)Adh!_RsZV1t zpb@rl$;i!vx0DHy`S~NnC|PteKflGG%-(Oq*ES8HA(SSHNoexaXBk|vMoH2(bHa$s zRFmehKB+8)FNqJV1N=fXacDkd#FOgxkv24xTS+6OVO3PHTVOhNp|c7=@$^B;qx0TBxU zrs>S4kiX1?zg7f<-5_#!FK}wqyW%NI;eH=ZQLx9)@kuN<@AqmQyAEhIsGgRzlF{eD zi*a(0YiA6S^~PmP#yc4$yw+9vC>IoaWIN-T;)*E=`Xz?;)3{CQm3MNuxIAdWJOJ)~ zouL92vl-5pKluIp`=uke(X@9qQbn&p6i2awk82wsQ)V;IBWBoOn0nJ+1jDnKdC9l} z131FR5-%A~^}o$r%r54_6g{n;Rt9R567_ujH@5N(KGp-ch)h#A+YPOyqbtlL7~TLu7n!!&bu6NE|X7 zmN5iBzXu@D$n_H5=4+9+M-qJp;1H$RV6@a0M<>`7mF(yTx zdjy%M<|inZGR8yYWtxax|Tv2L)Q{42j*8$9&y~D9>Y$zxW?NsvvFoqs~=`g*cI$PUY5N0 zG~2*^V4{7PpVO?xc3TskXE2T7;eq-a_&Y}ma@2GkD3S1o+Ie<|N9M5GnQ2a)hb2I& zu3*L8qwj92Da$OJI1IxWF$4^NZ9}?GJlNeGM9ysCnxdr^HlTitWqD zfUJw+TDh`;rwFjc;eAzHP}GJEY=Wc(-3IW7Emt>v}>I( z#>3{YIcVN)Y)+C%hLSalH5n4Y|wIbgI+x;$S;sepyr2U1NA%= zl&~LxK|wlQs2>F9c0-;UA>cZf zKe9y(5K@T?(+P5u`)WAhF`G z#4XyHx@o6Di}MIg+gTlgNhlQSLY>yMa+7J$rvYzZX6?*6(5HbA_we&2q)|ejdMC6U zfv3 z67=eIUD`X{@PwcQDLxPgyGgt&xGL2l!BPK{&i|(K4|M(!&d6F#Nd#{rDx&&`jo4m7 z!WGdA1-2Zvazbi&aTzqDmn7;)*TWX!YyLjP?jK@*u2HzIvl{Xqh`C$yDaKS03pu~Q|$ zUg3~Y&6P3zZTeSyXn!F0kykx|U8Z^?36gFM!qc|Dz&eK+v~DB_BF0tM+-{1rZI!0{ zdeoHV6Ols!>+xAORyKy6CU0alLm!wEukDAz=&C_RD{Mx}Un8%|?9?t42)bQDs5hW+ zH1t+}X)6)fi9#^MeGAwbLoCU0NT5)6^$6a9gfe))5w$o84Xl_PJczg;OM6t<1&9y< zJ|S7B)fP13*e2O=RE8fG3oZEvR*w3wvQ({Lpi$MJQAnz-#7<5V|M-nG+Mb721R5{j zm3U|3o#3$rS8>aO_=nwU#mPM}FB;7%vOWdO)A~|{s~dVO$=cHl=s9c{WNXR1l39D7 z+{{2JjFB+17Xf0%O(@(j$T)TF%@P=tmO%m@FvFy1tqj@2@ZClqzPBOu_Sqyh z(QWqaIvkaB2jX@}+);Z3aR^6oZ;bN2&At5E0n}%zrxAltICCdAB>=;3 z*oEO!5`M-$9flv2@JsfCVfY~lFWC=;;WHBc1j>Yuu&HhBjR>%ZCHC{Un1@a19BAb^ z&1_6e$ZB1N`wVhHnyCp!jNU}r}0M9cDLBCE%>!L{bXJFtmjARP=V@aXju-~?b{hhk5{I*+r<=r zcEeAP=>UI4qX>Hv`56mgs{GVAvJ_!gqJP>VHHM&{x&Zb`ZXXmg<0bMx#q_v*;YTZU za7@1BqG0#{wx6z#AzN`h`Sg`H-7;5`SYCQbpIMQMW1>`5kbWZ-bL164bY4j0>LE#S`WXD1PEqmmt z{@!u<*1e~VkA{xLp!x8Q1{uAZP-(Fb*PJ^)R}q*MRm`=22?XuaTAm`VP-9bl0U5+@ zBFMF$m&M-jF+%SEk%#{({E*jS6`*$>;Xh=D7+8s+pCwlf16DD*cxKI7RH3hAbiJw7?Z96<$L0x*=slgMT2w>@XViZB$I;{fR zq0M3}AR(NdAr8l<62if1EDj;W;oJqst7MAI88MCCKwr8AYh00HdDVklzdNl>Q+vx$ z3gg0uKt7Gu1jdEHs6c|2hV%>*IG^!cf{}LBb9OUhE zyI;DEV5^)HY9S7={s&=)!%blxQyQ!F_ZGZKazhzWpSEU^0t@o`C}mTirW&0HDVR5|Q&|`j*eOT93?| zI?aS;n~(}pMt^K7|3u*w8H0$-2O}pjWhkaZwxd}{t_aKc6kZ8H!?QsFKOO;d3_Ny^ z$n`;=q0a znhMWHgN)WU1}WZc3zB#rE9k{7NoNX|_i&Rr=;!{_1~@LP$a`?^fRh6LHXM&2y!P$2 
zLmbw7>D&jWZU1yX0WCCn#XtsYxk=zcNvFGaa9Bm@gxM#lq}`*xH9jJvH)M(y0_u^4 za01`%-5jI#xNO1t(t0ZD=3u}`Ju^nc#E5_egqBLKf-pu#1?Y1RTYZ=giv=FB8v(An z@N!~FMP1b|YT(gi-RD)GsS$nbZuX7%%Tsu&8~wp${R;IcQE#{iIOAkh+)+3oHK*ac z^aV(d-n;->;iazf@(LEg3(%6${0Eo>dYn~vvja!z$SfRTAJQ;hz#i8_D0MNRu}_eq zlqnNc%)-E{)VMGMx#$>D=!xU*Cb{nJ^du(=5r56Ho_Gz^e4MGr=sZA2+VD5+5`w<) ziJywY2ZTC^(XLSHV&k{7t_=5>41o8LPBNayQkHbp8%>1{)Lq!L)l2|2_ zUZ2NV22Oo5v=)H?1`CnrlhD^?phJXFwmN{c>3Rxo2D(CAGf<}xmxb02*A4*9f&FDE zIJsWTwz4P%r%$0FY?fspO%HNr<>UqDPSEJW%1j81#v?SEOhbe;f>=l)%OtO7N$8Mh z+GMOaJwa9;FNHQud-m^B_nN{8wPKnl4ul)$W?y4I*JNEx!w< z)#`tT_H+hazW~tNzFFs&7LIb%smn+}xT_&N;^LsGCxljoUS_4g>Y*ozl{A;0r> znyYX~7|H5)!N0OA8%QInAYn`|1?r4%?jSC&xaJHs!Q7irL1bFNO`6IQ>Cq%Sdf5S8 zWG}@%mQygO0K0_4vBQgLS|CIH`$ohYOX={~62CIfOG(T_G7W__4lAL&maCN7y7J;? zSS5=se@JpnyizT`V{?zw;n7Wy;*WA*tx;(a7`_5hmz^n`G+r@ zJ@OHA>T{pv-`^aGezNJy?KjwkLPxZK6G&Q+#HUR+VIPDuVOYfAi~`T5QiKI^!XzD* zF(ywET3GqfMDPryrh+ug7s@Vd2%2vD-m8Ao9y?%4?!^B&s!%65=#SB%JtSu1etgL@ z!(}M=_%5X?(s?VLx50t^kbXSfZBx73S&ooA7BSnwQhjdhuhFA%qcJ6iy3O^w--Q6_ zmGlYAjMh+BJnj)d)tN=}iVdm-^(I13@^ngb@#GhXxSjCfjvrjavbba{4t)QzxD-t7 zZ+{|jabtK&ifdJGXK%AGe=L z%84Y-r*1<`xJ!wiE=Wtf+K=*Y97C4FR|7&u0+-@}yOf^%p>P~mSkzBMTo(6Q;*Tu; z_oH+h_EaQcIR9P;tbz1|6ZICx3Zn>ns^%ipSXZWTMvPGuNRieDJQaj7_R))@^z*V0 z_7YkpA`TCFz(%o^fjX3yiysE(k+Cs$t@Kg`r_s2qb4@Pi0GeJEngKMW%S4(iLP!aY zE<<&qnV@|a+LPOqfUQ>#Y_U*!8v1j)=S^4=NDVtHgpyD=ehqcvMkT{631{LaTA9T| zA3bOTY_(cFunxnGOc1ghpm{(GfaVDW9OMS@KD4?QrS$Sf`J&`votvN)hEH1+A(BGTDDSWS#Ep1vF0B)vA=EYIq))o*=?yr6R51nv;~#ipM(jL{-a5h(vL8T@fC8qx>k++LYeTo0UD#820E!SoeOk? zCdhFCX@aH>Bo;UYs}S&EVzC^R5&M-**dy>!`F}acV><-lg;jk?a|0#Z0k!e40^YMI zfs?~KsX-2IMTgys%0nFQ+Y1!7MUAh@kdDMPQ829S0 zeY`zhZ_EV)7piCs%=BTaPswHP4Um0&4a;fDV;L)UAvKA*e6Crt?Fl}n2+QIDVUtZ> zf=a^L$Q;6dKiH509;>SGi(QM4H;u5SlNJCBCS>w$rBQjT$pLoqWEyhGf6WmMi73H|AoZwo=CR_M&dK#rCc$c1a}mCTrGz$xCOM_K6LXh;nq=$&T*VNfEZ*62-?4swSz zpM5`ND7np_!s615N86RKnDp_@7``Tx`?kclCC=w=Z>Q*oYE%Ey;O_<>|MQ0_+2!&BR{xs*KKL{FP5Czr+&r)&w 0: + label_indices = [[] for _ in range(max(self.labels)+1)] + for i,l in enumerate(self.labels): + label_indices[l] += [i] + label_indices = np.asarray(label_indices) + + # randomly grab 500 elements of each class + np.random.seed(validate_seed) + self.val_indices = [] + for l_i in label_indices: + self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1) ,replace=False)]) + + if self.train=='validate': + self.data = self.data[self.val_indices] + self.labels = list(np.asarray(self.labels)[self.val_indices]) + + self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32)) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + + elif self.train: + print(np.shape(self.data)) + if self.val_split > 0: + self.data = np.delete(self.data,self.val_indices,axis=0) + self.labels = list(np.delete(np.asarray(self.labels),self.val_indices,axis=0)) + + self.data = self.data.reshape((int(50e3 * (1.-self.val_split)), 3, 32, 32)) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + else: + f = self.test_list[0][0] + file = os.path.join(self.root, self.base_folder, f) + fo = open(file, 'rb') + if sys.version_info[0] == 2: + entry = pickle.load(fo) + else: + entry = pickle.load(fo, encoding='latin1') + self.data = entry['data'] + if 'labels' in entry: + self.labels = entry['labels'] + else: + self.labels = entry['fine_labels'] + fo.close() + self.data = self.data.reshape((10000, 3, 32, 32)) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + + def __getitem__(self, index): + """ + Args: + index (int): Index + 
Returns: + tuple: (image, target) where target is index of the target class. + """ + img, target = self.data[index], self.labels[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.data) + + +class CIFAR100(CIFAR10): + base_folder = 'cifar-100-python' + url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" + filename = "cifar-100-python.tar.gz" + tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' + train_list = [ + ['train', '16019d7e3df5f24257cddd939b257f8d'], + ] + + test_list = [ + ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], + ] diff --git a/text2image/BigGAN_utils/inception_tf13.py b/text2image/BigGAN_utils/inception_tf13.py new file mode 100644 index 0000000..f43ecac --- /dev/null +++ b/text2image/BigGAN_utils/inception_tf13.py @@ -0,0 +1,138 @@ +''' Tensorflow inception score code +Derived from https://github.com/openai/improved-gan +Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py +THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE + +To use this code, run sample.py on your model with --sample_npz, and then +pass the experiment name in the --experiment_name. +This code also saves pool3 stats to an npz file for FID calculation +''' +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import sys +import tarfile +import math +from tqdm import tqdm, trange +from argparse import ArgumentParser + +import numpy as np +from six.moves import urllib +import tensorflow as tf + +MODEL_DIR = '' +DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' +softmax = None + +def prepare_parser(): + usage = 'Parser for TF1.3- Inception Score scripts.' + parser = ArgumentParser(description=usage) + parser.add_argument( + '--experiment_name', type=str, default='', + help='Which experiment''s samples.npz file to pull and evaluate') + parser.add_argument( + '--experiment_root', type=str, default='samples', + help='Default location where samples are stored (default: %(default)s)') + parser.add_argument( + '--batch_size', type=int, default=500, + help='Default overall batchsize (default: %(default)s)') + return parser + + +def run(config): + # Inception with TF1.3 or earlier. + # Call this function with list of images. Each of elements should be a + # numpy array with values ranging from 0 to 255. 
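+ # A usage sketch (illustrative only, not part of the original script): given a
+ # hypothetical list `fake_imgs` of HxWx3 arrays with values in [0, 255],
+ #   is_mean, is_std, pool3_feats = get_inception_score(fake_imgs, splits=10)
+ # returns the Inception Score mean/std and the pool3 activations used for FID.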
+ def get_inception_score(images, splits=10): + assert(type(images) == list) + assert(type(images[0]) == np.ndarray) + assert(len(images[0].shape) == 3) + assert(np.max(images[0]) > 10) + assert(np.min(images[0]) >= 0.0) + inps = [] + for img in images: + img = img.astype(np.float32) + inps.append(np.expand_dims(img, 0)) + bs = config['batch_size'] + with tf.Session() as sess: + preds, pools = [], [] + n_batches = int(math.ceil(float(len(inps)) / float(bs))) + for i in trange(n_batches): + inp = inps[(i * bs):min((i + 1) * bs, len(inps))] + inp = np.concatenate(inp, 0) + pred, pool = sess.run([softmax, pool3], {'ExpandDims:0': inp}) + preds.append(pred) + pools.append(pool) + preds = np.concatenate(preds, 0) + scores = [] + for i in range(splits): + part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return np.mean(scores), np.std(scores), np.squeeze(np.concatenate(pools, 0)) + # Init inception + def _init_inception(): + global softmax, pool3 + if not os.path.exists(MODEL_DIR): + os.makedirs(MODEL_DIR) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(MODEL_DIR, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % ( + filename, float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') + tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) + with tf.gfile.FastGFile(os.path.join( + MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + _ = tf.import_graph_def(graph_def, name='') + # Works with an arbitrary minibatch size. + with tf.Session() as sess: + pool3 = sess.graph.get_tensor_by_name('pool_3:0') + ops = pool3.graph.get_operations() + for op_idx, op in enumerate(ops): + for o in op.outputs: + shape = o.get_shape() + shape = [s.value for s in shape] + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o._shape = tf.TensorShape(new_shape) + w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] + logits = tf.matmul(tf.squeeze(pool3), w) + softmax = tf.nn.softmax(logits) + + # if softmax is None: # No need to functionalize like this. 
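+ # _init_inception() downloads the pretrained Inception graph
+ # (inception-2015-12-05.tgz) on first use and fills the module-level
+ # `softmax` and `pool3` tensors that get_inception_score() reads from.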
+ _init_inception() + + fname = '%s/%s/samples.npz' % (config['experiment_root'], config['experiment_name']) + print('loading %s ...'%fname) + ims = np.load(fname)['x'] + import time + t0 = time.time() + inc_mean, inc_std, pool_activations = get_inception_score(list(ims.swapaxes(1,2).swapaxes(2,3)), splits=10) + t1 = time.time() + print('Saving pool to numpy file for FID calculations...') + np.savez('%s/%s/TF_pool.npz' % (config['experiment_root'], config['experiment_name']), **{'pool_mean': np.mean(pool_activations,axis=0), 'pool_var': np.cov(pool_activations, rowvar=False)}) + print('Inception took %3f seconds, score of %3f +/- %3f.'%(t1-t0, inc_mean, inc_std)) +def main(): + # parse command line and run + parser = prepare_parser() + config = vars(parser.parse_args()) + print(config) + run(config) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/inception_utils.py b/text2image/BigGAN_utils/inception_utils.py new file mode 100644 index 0000000..373d3cd --- /dev/null +++ b/text2image/BigGAN_utils/inception_utils.py @@ -0,0 +1,310 @@ +''' Inception utilities + This file contains methods for calculating IS and FID, using either + the original numpy code or an accelerated fully-pytorch version that + uses a fast newton-schulz approximation for the matrix sqrt. There are also + methods for acquiring a desired number of samples from the Generator, + and parallelizing the inbuilt PyTorch inception network. + + NOTE that Inception Scores and FIDs calculated using these methods will + *not* be directly comparable to values calculated using the original TF + IS/FID code. You *must* use the TF model if you wish to report and compare + numbers. This code tends to produce IS values that are 5-10% lower than + those obtained through TF. +''' +import numpy as np +from scipy import linalg # For numpy FID +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter as P +from torchvision.models.inception import inception_v3 + + +# Module that wraps the inception network to enable use with dataparallel and +# returning pool features and logits. +class WrapInception(nn.Module): + def __init__(self, net): + super(WrapInception,self).__init__() + self.net = net + self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1), + requires_grad=False) + self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1), + requires_grad=False) + def forward(self, x): + # Normalize x + x = (x + 1.) 
/ 2.0 + x = (x - self.mean) / self.std + # Upsample if necessary + if x.shape[2] != 299 or x.shape[3] != 299: + x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) + # 299 x 299 x 3 + x = self.net.Conv2d_1a_3x3(x) + # 149 x 149 x 32 + x = self.net.Conv2d_2a_3x3(x) + # 147 x 147 x 32 + x = self.net.Conv2d_2b_3x3(x) + # 147 x 147 x 64 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # 73 x 73 x 64 + x = self.net.Conv2d_3b_1x1(x) + # 73 x 73 x 80 + x = self.net.Conv2d_4a_3x3(x) + # 71 x 71 x 192 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # 35 x 35 x 192 + x = self.net.Mixed_5b(x) + # 35 x 35 x 256 + x = self.net.Mixed_5c(x) + # 35 x 35 x 288 + x = self.net.Mixed_5d(x) + # 35 x 35 x 288 + x = self.net.Mixed_6a(x) + # 17 x 17 x 768 + x = self.net.Mixed_6b(x) + # 17 x 17 x 768 + x = self.net.Mixed_6c(x) + # 17 x 17 x 768 + x = self.net.Mixed_6d(x) + # 17 x 17 x 768 + x = self.net.Mixed_6e(x) + # 17 x 17 x 768 + # 17 x 17 x 768 + x = self.net.Mixed_7a(x) + # 8 x 8 x 1280 + x = self.net.Mixed_7b(x) + # 8 x 8 x 2048 + x = self.net.Mixed_7c(x) + # 8 x 8 x 2048 + pool = torch.mean(x.view(x.size(0), x.size(1), -1), 2) + # 1 x 1 x 2048 + logits = self.net.fc(F.dropout(pool, training=False).view(pool.size(0), -1)) + # 1000 (num_classes) + return pool, logits + + +# A pytorch implementation of cov, from Modar M. Alfadly +# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 +def torch_cov(m, rowvar=False): + '''Estimate a covariance matrix given data. + + Covariance indicates the level to which two variables vary together. + If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, + then the covariance matrix element `C_{ij}` is the covariance of + `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. + + Args: + m: A 1-D or 2-D array containing multiple variables and observations. + Each row of `m` represents a variable, and each column a single + observation of all those variables. + rowvar: If `rowvar` is True, then each row represents a + variable, with observations in the columns. Otherwise, the + relationship is transposed: each column represents a variable, + while the rows contain observations. + + Returns: + The covariance matrix of the variables. + ''' + if m.dim() > 2: + raise ValueError('m has more than 2 dimensions') + if m.dim() < 2: + m = m.view(1, -1) + if not rowvar and m.size(0) != 1: + m = m.t() + # m = m.type(torch.double) # uncomment this line if desired + fact = 1.0 / (m.size(1) - 1) + m -= torch.mean(m, dim=1, keepdim=True) + mt = m.t() # if complex: mt = m.t().conj() + return fact * m.matmul(mt).squeeze() + + +# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji +# https://github.com/msubhransu/matrix-sqrt +def sqrt_newton_schulz(A, numIters, dtype=None): + with torch.no_grad(): + if dtype is None: + dtype = A.type() + batchSize = A.shape[0] + dim = A.shape[1] + normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt() + Y = A.div(normA.view(batchSize, 1, 1).expand_as(A)); + I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) + Z = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) + for i in range(numIters): + T = 0.5*(3.0*I - Z.bmm(Y)) + Y = Y.bmm(T) + Z = T.bmm(Z) + sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A) + return sA + + +# FID calculator from TTUR--consider replacing this with GPU-accelerated cov +# calculations using torch? 
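+# Illustrative sketch (not from the original code): given pooled Inception
+# activations a1, a2 (hypothetical [N, 2048] arrays) for two image sets,
+#   fid = numpy_calculate_frechet_distance(np.mean(a1, 0), np.cov(a1, rowvar=False),
+#                                          np.mean(a2, 0), np.cov(a2, rowvar=False))
+# mirrors how prepare_inception_metrics() computes FID further below.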
+def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + Taken from https://github.com/bioinf-jku/TTUR + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representive data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representive data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + print('wat') + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + return out + + +def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Pytorch implementation of the Frechet Distance. + Taken from https://github.com/bioinf-jku/TTUR + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representive data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representive data set. + Returns: + -- : The Frechet Distance. 
+ """ + + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2 + covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze() + out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) + - 2 * torch.trace(covmean)) + return out + + +# Calculate Inception Score mean + std given softmax'd logits and number of splits +def calculate_inception_score(pred, num_splits=10): + scores = [] + for index in range(num_splits): + pred_chunk = pred[index * (pred.shape[0] // num_splits): (index + 1) * (pred.shape[0] // num_splits), :] + kl_inception = pred_chunk * (np.log(pred_chunk) - np.log(np.expand_dims(np.mean(pred_chunk, 0), 0))) + kl_inception = np.mean(np.sum(kl_inception, 1)) + scores.append(np.exp(kl_inception)) + return np.mean(scores), np.std(scores) + + +# Loop and run the sampler and the net until it accumulates num_inception_images +# activations. Return the pool, the logits, and the labels (if one wants +# Inception Accuracy the labels of the generated class will be needed) +def accumulate_inception_activations(sample, net, num_inception_images=50000): + pool, logits, labels = [], [], [] + while (torch.cat(logits, 0).shape[0] if len(logits) else 0) < num_inception_images: + with torch.no_grad(): + images, labels_val = sample() + pool_val, logits_val = net(images.float()) + pool += [pool_val] + logits += [F.softmax(logits_val, 1)] + labels += [labels_val] + return torch.cat(pool, 0), torch.cat(logits, 0), torch.cat(labels, 0) + + +# Load and wrap the Inception model +def load_inception_net(parallel=False): + inception_model = inception_v3(pretrained=True, transform_input=False) + inception_model = WrapInception(inception_model.eval()).cuda() + if parallel: + print('Parallelizing Inception module...') + inception_model = nn.DataParallel(inception_model) + return inception_model + + +# This produces a function which takes in an iterator which returns a set number of samples +# and iterates until it accumulates config['num_inception_images'] images. +# The iterator can return samples with a different batch size than used in +# training, using the setting confg['inception_batchsize'] +def prepare_inception_metrics(dataset, parallel, no_fid=False): + # Load metrics; this is intentionally not in a try-except loop so that + # the script will crash here if it cannot find the Inception moments. 
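  # The moments file is expected at '<dataset>_inception_moments.npz' with keys
  # 'mu' and 'sigma'; it is produced ahead of time by
  # calculate_inception_moments.py (see scripts/utils/prepare_data.sh).
  # Note that the strip('_hdf5') call just below removes any of the characters
  # '_', 'h', 'd', 'f', '5' from both ends of the name rather than the literal
  # suffix; this happens to work for names like 'I128_hdf5', but an explicit
  # suffix check (e.g. dataset[:-5] if dataset.endswith('_hdf5') else dataset)
  # would be more robust.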
+ # By default, remove the "hdf5" from dataset + dataset = dataset.strip('_hdf5') + data_mu = np.load(dataset+'_inception_moments.npz')['mu'] + data_sigma = np.load(dataset+'_inception_moments.npz')['sigma'] + # Load network + net = load_inception_net(parallel) + def get_inception_metrics(sample, num_inception_images, num_splits=10, + prints=True, use_torch=True): + if prints: + print('Gathering activations...') + pool, logits, labels = accumulate_inception_activations(sample, net, num_inception_images) + if prints: + print('Calculating Inception Score...') + IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits) + if no_fid: + FID = 9999.0 + else: + if prints: + print('Calculating means and covariances...') + if use_torch: + mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False) + else: + mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False) + if prints: + print('Covariances calculated, getting FID...') + if use_torch: + FID = torch_calculate_frechet_distance(mu, sigma, torch.tensor(data_mu).float().cuda(), torch.tensor(data_sigma).float().cuda()) + FID = float(FID.cpu().numpy()) + else: + FID = numpy_calculate_frechet_distance(mu.cpu().numpy(), sigma.cpu().numpy(), data_mu, data_sigma) + # Delete mu, sigma, pool, logits, and labels, just in case + del mu, sigma, pool, logits, labels + return IS_mean, IS_std, FID + return get_inception_metrics \ No newline at end of file diff --git a/text2image/BigGAN_utils/layers.py b/text2image/BigGAN_utils/layers.py new file mode 100644 index 0000000..55aaab1 --- /dev/null +++ b/text2image/BigGAN_utils/layers.py @@ -0,0 +1,459 @@ +''' Layers + This file contains various layers for the BigGAN models. +''' +import numpy as np +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + +from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d + + +# Projection of x onto y +def proj(x, y): + return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) + + +# Orthogonalize x wrt list of vectors ys +def gram_schmidt(x, ys): + for y in ys: + x = x - proj(x, y) + return x + + +# Apply num_itrs steps of the power method to estimate top N singular values. +def power_iteration(W, u_, update=True, eps=1e-12): + # Lists holding singular vectors and values + us, vs, svs = [], [], [] + for i, u in enumerate(u_): + # Run one step of the power iteration + with torch.no_grad(): + v = torch.matmul(u, W) + # Run Gram-Schmidt to subtract components of all other singular vectors + v = F.normalize(gram_schmidt(v, vs), eps=eps) + # Add to the list + vs += [v] + # Update the other singular vector + u = torch.matmul(v, W.t()) + # Run Gram-Schmidt to subtract components of all other singular vectors + u = F.normalize(gram_schmidt(u, us), eps=eps) + # Add to the list + us += [u] + if update: + u_[i][:] = u + # Compute this singular value and add it to the list + svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] + #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] + return svs, us, vs + + +# Convenience passthrough function +class identity(nn.Module): + def forward(self, input): + return input + + +# Spectral normalization base class +class SN(object): + def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): + # Number of power iterations per step + self.num_itrs = num_itrs + # Number of singular values + self.num_svs = num_svs + # Transposed? 
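    # (When transpose is True, W_() runs the power iteration on the transposed
    # flattened weight matrix instead of the weight itself.)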
+ self.transpose = transpose + # Epsilon value for avoiding divide-by-0 + self.eps = eps + # Register a singular vector for each sv + for i in range(self.num_svs): + self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) + self.register_buffer('sv%d' % i, torch.ones(1)) + + # Singular vectors (u side) + @property + def u(self): + return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] + + # Singular values; + # note that these buffers are just for logging and are not used in training. + @property + def sv(self): + return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] + + # Compute the spectrally-normalized weight + def W_(self): + W_mat = self.weight.view(self.weight.size(0), -1) + if self.transpose: + W_mat = W_mat.t() + # Apply num_itrs power iterations + for _ in range(self.num_itrs): + svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) + # Update the svs + if self.training: + with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! + for i, sv in enumerate(svs): + self.sv[i][:] = sv + return self.weight / svs[0] + + +# 2D Conv layer with spectral norm +class SNConv2d(nn.Conv2d, SN): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) + def forward(self, x): + return F.conv2d(x, self.W_(), self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# Linear layer with spectral norm +class SNLinear(nn.Linear, SN): + def __init__(self, in_features, out_features, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Linear.__init__(self, in_features, out_features, bias) + SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) + def forward(self, x): + return F.linear(x, self.W_(), self.bias) + + +# Embedding layer with spectral norm +# We use num_embeddings as the dim instead of embedding_dim here +# for convenience sake +class SNEmbedding(nn.Embedding, SN): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, + max_norm=None, norm_type=2, scale_grad_by_freq=False, + sparse=False, _weight=None, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, + max_norm, norm_type, scale_grad_by_freq, + sparse, _weight) + SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) + def forward(self, x): + return F.embedding(x, self.W_()) + + +# A non-local block as used in SA-GAN +# Note that the implementation as described in the paper is largely incorrect; +# refer to the released code for the actual implementation. 
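# (How the block below works, read from its forward(): theta, phi and g are
# 1x1 convolutions; phi and g are additionally 2x2 max-pooled, so the
# attention map beta = softmax(theta^T . phi) has shape N x (HW) x (HW/4);
# the beta-weighted g features are projected back to ch channels by self.o,
# and gamma is initialised to 0, so the block starts out as an identity
# mapping and only gradually mixes in attention as gamma is learned.)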
+class Attention(nn.Module): + def __init__(self, ch, which_conv=SNConv2d, name='attention'): + super(Attention, self).__init__() + # Channel multiplier + self.ch = ch + self.which_conv = which_conv + self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) + self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) + # Learnable gain parameter + self.gamma = P(torch.tensor(0.), requires_grad=True) + def forward(self, x, y=None): + # Apply convs + theta = self.theta(x) + phi = F.max_pool2d(self.phi(x), [2,2]) + g = F.max_pool2d(self.g(x), [2,2]) + # Perform reshapes + theta = theta.view(-1, self. ch // 8, x.shape[2] * x.shape[3]) + phi = phi.view(-1, self. ch // 8, x.shape[2] * x.shape[3] // 4) + g = g.view(-1, self. ch // 2, x.shape[2] * x.shape[3] // 4) + # Matmul and softmax to get attention maps + beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) + # Attention map times g path + o = self.o(torch.bmm(g, beta.transpose(1,2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) + return self.gamma * o + x + + +# Fused batchnorm op +def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): + # Apply scale and shift--if gain and bias are provided, fuse them here + # Prepare scale + scale = torch.rsqrt(var + eps) + # If a gain is provided, use it + if gain is not None: + scale = scale * gain + # Prepare shift + shift = mean * scale + # If bias is provided, use it + if bias is not None: + shift = shift - bias + return x * scale - shift + #return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. + + +# Manual BN +# Calculate means and variances using mean-of-squares minus mean-squared +def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): + # Cast x to float32 if necessary + float_x = x.float() + # Calculate expected value of x (m) and expected value of x**2 (m2) + # Mean of x + m = torch.mean(float_x, [0, 2, 3], keepdim=True) + # Mean of x squared + m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True) + # Calculate variance as mean of squared minus mean squared. 
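  # (E[x^2] - E[x]^2 can suffer from catastrophic cancellation in low
  # precision, which is why the statistics are computed on a float32 copy of x
  # above and only cast back to x's dtype afterwards.)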
+ var = (m2 - m **2) + # Cast back to float 16 if necessary + var = var.type(x.type()) + m = m.type(x.type()) + # Return mean and variance for updating stored mean/var if requested + if return_mean_var: + return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() + else: + return fused_bn(x, m, var, gain, bias, eps) + + +# My batchnorm, supports standing stats +class myBN(nn.Module): + def __init__(self, num_channels, eps=1e-5, momentum=0.1): + super(myBN, self).__init__() + # momentum for updating running stats + self.momentum = momentum + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Register buffers + self.register_buffer('stored_mean', torch.zeros(num_channels)) + self.register_buffer('stored_var', torch.ones(num_channels)) + self.register_buffer('accumulation_counter', torch.zeros(1)) + # Accumulate running means and vars + self.accumulate_standing = False + + # reset standing stats + def reset_stats(self): + self.stored_mean[:] = 0 + self.stored_var[:] = 0 + self.accumulation_counter[:] = 0 + + def forward(self, x, gain, bias): + if self.training: + out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) + # If accumulating standing stats, increment them + if self.accumulate_standing: + self.stored_mean[:] = self.stored_mean + mean.data + self.stored_var[:] = self.stored_var + var.data + self.accumulation_counter += 1.0 + # If not accumulating standing stats, take running averages + else: + self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum + self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum + return out + # If not in training mode, use the stored statistics + else: + mean = self.stored_mean.view(1, -1, 1, 1) + var = self.stored_var.view(1, -1, 1, 1) + # If using standing stats, divide them by the accumulation counter + if self.accumulate_standing: + mean = mean / self.accumulation_counter + var = var / self.accumulation_counter + return fused_bn(x, mean, var, gain, bias, self.eps) + + +# Simple function to handle groupnorm norm stylization +def groupnorm(x, norm_style): + # If number of channels specified in norm_style: + if 'ch' in norm_style: + ch = int(norm_style.split('_')[-1]) + groups = max(int(x.shape[1]) // ch, 1) + # If number of groups specified in norm style + elif 'grp' in norm_style: + groups = int(norm_style.split('_')[-1]) + # If neither, default to groups = 16 + else: + groups = 16 + return F.group_norm(x, groups) + + +# Class-conditional bn +# output size is the number of channels, input size is for the linear layers +# Andy's Note: this class feels messy but I'm not really sure how to clean it up +# Suggestions welcome! (By which I mean, refactor this and make a pull request +# if you want to make this more readable/usable). +class ccbn(nn.Module): + def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1, + cross_replica=False, mybn=False, norm_style='bn',): + super(ccbn, self).__init__() + self.output_size, self.input_size = output_size, input_size + # Prepare gain and bias layers + self.gain = which_linear(input_size, output_size) + self.bias = which_linear(input_size, output_size) + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Use cross-replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # Norm style? 
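    # Accepted values (see forward below): 'bn' (batch norm), 'in' (instance
    # norm), 'gn' (group norm via groupnorm() above) and 'nonorm' (identity,
    # gain/bias only). Note that the 'gn' branch in forward() references
    # self.normstyle (no underscore), which is never defined, so selecting
    # group norm as written would raise an AttributeError.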
+ self.norm_style = norm_style + + if self.cross_replica: + self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) + elif self.mybn: + self.bn = myBN(output_size, self.eps, self.momentum) + elif self.norm_style in ['bn', 'in']: + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + + def forward(self, x, y): + # Calculate class-conditional gains and biases + gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) + bias = self.bias(y).view(y.size(0), -1, 1, 1) + # If using my batchnorm + if self.mybn or self.cross_replica: + return self.bn(x, gain=gain, bias=bias) + # else: + else: + if self.norm_style == 'bn': + out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + elif self.norm_style == 'in': + out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + elif self.norm_style == 'gn': + out = groupnorm(x, self.normstyle) + elif self.norm_style == 'nonorm': + out = x + return out * gain + bias + def extra_repr(self): + s = 'out: {output_size}, in: {input_size},' + s +=' cross_replica={cross_replica}' + return s.format(**self.__dict__) + + +# Normal, non-class-conditional BN +class bn(nn.Module): + def __init__(self, output_size, eps=1e-5, momentum=0.1, + cross_replica=False, mybn=False): + super(bn, self).__init__() + self.output_size= output_size + # Prepare gain and bias layers + self.gain = P(torch.ones(output_size), requires_grad=True) + self.bias = P(torch.zeros(output_size), requires_grad=True) + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Use cross-replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + + if self.cross_replica: + self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) + elif mybn: + self.bn = myBN(output_size, self.eps, self.momentum) + # Register buffers if neither of the above + else: + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + def forward(self, x, y=None): + if self.cross_replica or self.mybn: + gain = self.gain.view(1,-1,1,1) + bias = self.bias.view(1,-1,1,1) + return self.bn(x, gain=gain, bias=bias) + else: + return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, + self.bias, self.training, self.momentum, self.eps) + + +# Generator blocks +# Note that this class assumes the kernel size and padding (and any other +# settings) have been selected in the main generator module and passed in +# through the which_conv arg. 
Similar rules apply with which_bn (the input +# size [which is actually the number of channels of the conditional info] must +# be preselected) +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv=nn.Conv2d, which_bn=bn, activation=None, + upsample=None): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.which_conv, self.which_bn = which_conv, which_bn + self.activation = activation + self.upsample = upsample + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.out_channels) + self.conv2 = self.which_conv(self.out_channels, self.out_channels) + self.learnable_sc = in_channels != out_channels or upsample + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(in_channels) + self.bn2 = self.which_bn(out_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + h = self.activation(self.bn1(x, y)) + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + h = self.conv1(h) + h = self.activation(self.bn2(h, y)) + h = self.conv2(h) + if self.learnable_sc: + x = self.conv_sc(x) + return h + x + + +# Residual block for the discriminator +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True, + preactivation=False, activation=None, downsample=None,): + super(DBlock, self).__init__() + self.in_channels, self.out_channels = in_channels, out_channels + # If using wide D (as in SA-GAN and BigGAN), change the channel pattern + self.hidden_channels = self.out_channels if wide else self.in_channels + self.which_conv = which_conv + self.preactivation = preactivation + self.activation = activation + self.downsample = downsample + + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) + self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) + self.learnable_sc = True if (in_channels != out_channels) or downsample else False + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + def shortcut(self, x): + if self.preactivation: + if self.learnable_sc: + x = self.conv_sc(x) + if self.downsample: + x = self.downsample(x) + else: + if self.downsample: + x = self.downsample(x) + if self.learnable_sc: + x = self.conv_sc(x) + return x + + def forward(self, x): + if self.preactivation: + # h = self.activation(x) # NOT TODAY SATAN + # Andy's note: This line *must* be an out-of-place ReLU or it + # will negatively affect the shortcut connection. 
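      # (An in-place ReLU would overwrite x itself, and x is reused below as
      # the input to self.shortcut(x), so the skip path would silently see the
      # rectified tensor rather than the block's original input.)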
+ h = F.relu(x) + else: + h = x + h = self.conv1(h) + h = self.conv2(self.activation(h)) + if self.downsample: + h = self.downsample(h) + + return h + self.shortcut(x) + +# dogball \ No newline at end of file diff --git a/text2image/BigGAN_utils/logs/BigGAN_ch96_bs256x8.jsonl b/text2image/BigGAN_utils/logs/BigGAN_ch96_bs256x8.jsonl new file mode 100644 index 0000000..cfa578f --- /dev/null +++ b/text2image/BigGAN_utils/logs/BigGAN_ch96_bs256x8.jsonl @@ -0,0 +1,68 @@ +{"itr": 2000, "IS_mean": 2.806771755218506, "IS_std": 0.019480662420392036, "FID": 173.76484159711126, "_stamp": 1551403232.0425167} +{"itr": 4000, "IS_mean": 4.962374687194824, "IS_std": 0.07276841998100281, "FID": 113.86730514283107, "_stamp": 1551422228.743057} +{"itr": 6000, "IS_mean": 6.939817905426025, "IS_std": 0.11417163163423538, "FID": 101.63548498447199, "_stamp": 1551457139.3400874} +{"itr": 8000, "IS_mean": 8.142985343933105, "IS_std": 0.11931543797254562, "FID": 92.0014385772705, "_stamp": 1551476217.2409613} +{"itr": 10000, "IS_mean": 10.355518341064453, "IS_std": 0.09094739705324173, "FID": 83.58068997965364, "_stamp": 1551494854.2419689} +{"itr": 12000, "IS_mean": 11.288347244262695, "IS_std": 0.14952820539474487, "FID": 80.98066299357106, "_stamp": 1551513232.5049698} +{"itr": 14000, "IS_mean": 11.755794525146484, "IS_std": 0.17969024181365967, "FID": 76.80603924280956, "_stamp": 1551531425.150371} +{"itr": 18000, "IS_mean": 13.65534496307373, "IS_std": 0.11151058971881866, "FID": 65.95736694335938, "_stamp": 1551588271.9177916} +{"itr": 20000, "IS_mean": 14.817827224731445, "IS_std": 0.23588882386684418, "FID": 61.32061767578125, "_stamp": 1551606713.6567464} +{"itr": 22000, "IS_mean": 17.16551399230957, "IS_std": 0.19506946206092834, "FID": 53.387969970703125, "_stamp": 1551624876.6513028} +{"itr": 24000, "IS_mean": 19.60654067993164, "IS_std": 0.5591856837272644, "FID": 46.5386962890625, "_stamp": 1551642822.6126688} +{"itr": 26000, "IS_mean": 21.74416732788086, "IS_std": 0.2850531041622162, "FID": 41.595001220703125, "_stamp": 1551663522.6019194} +{"itr": 28000, "IS_mean": 23.923612594604492, "IS_std": 0.41587772965431213, "FID": 37.894744873046875, "_stamp": 1551681794.6567173} +{"itr": 30000, "IS_mean": 25.569377899169922, "IS_std": 0.3333457112312317, "FID": 35.49310302734375, "_stamp": 1551699773.7080302} +{"itr": 32000, "IS_mean": 26.867944717407227, "IS_std": 0.5968036651611328, "FID": 33.4849853515625, "_stamp": 1551717623.887933} +{"itr": 34000, "IS_mean": 28.719074249267578, "IS_std": 0.5698027014732361, "FID": 31.375518798828125, "_stamp": 1551735411.1578612} +{"itr": 36000, "IS_mean": 30.587574005126953, "IS_std": 0.5044271349906921, "FID": 29.432281494140625, "_stamp": 1551783380.6357439} +{"itr": 38000, "IS_mean": 32.08299255371094, "IS_std": 0.49342143535614014, "FID": 28.099456787109375, "_stamp": 1551801179.6495197} +{"itr": 40000, "IS_mean": 34.24657440185547, "IS_std": 0.7709177732467651, "FID": 26.53802490234375, "_stamp": 1551818775.171794} +{"itr": 42000, "IS_mean": 35.891212463378906, "IS_std": 0.7036871314048767, "FID": 25.03021240234375, "_stamp": 1551836329.6873965} +{"itr": 44000, "IS_mean": 38.184898376464844, "IS_std": 0.32996198534965515, "FID": 23.4940185546875, "_stamp": 1551897864.911537} +{"itr": 46000, "IS_mean": 40.239479064941406, "IS_std": 0.7761151194572449, "FID": 22.53167724609375, "_stamp": 1551915406.4840703} +{"itr": 48000, "IS_mean": 41.46656036376953, "IS_std": 1.1031498908996582, "FID": 21.5338134765625, "_stamp": 1551932899.6074848} +{"itr": 50000, "IS_mean": 
43.31670379638672, "IS_std": 0.7796809077262878, "FID": 20.53253173828125, "_stamp": 1551950390.345334} +{"itr": 52000, "IS_mean": 45.1517333984375, "IS_std": 1.2925242185592651, "FID": 19.656646728515625, "_stamp": 1551967838.1501615} +{"itr": 54000, "IS_mean": 47.638771057128906, "IS_std": 1.0689665079116821, "FID": 18.898162841796875, "_stamp": 1552044534.5349634} +{"itr": 56000, "IS_mean": 48.87520217895508, "IS_std": 1.1317559480667114, "FID": 18.1248779296875, "_stamp": 1552061763.3080354} +{"itr": 58000, "IS_mean": 49.40987014770508, "IS_std": 1.1866596937179565, "FID": 17.751922607421875, "_stamp": 1552078939.9828825} +{"itr": 60000, "IS_mean": 51.051334381103516, "IS_std": 1.2281248569488525, "FID": 17.19964599609375, "_stamp": 1552096167.889482} +{"itr": 62000, "IS_mean": 52.0235481262207, "IS_std": 0.5391153693199158, "FID": 16.62115478515625, "_stamp": 1552113417.9520617} +{"itr": 64000, "IS_mean": 53.868492126464844, "IS_std": 1.327082633972168, "FID": 16.237335205078125, "_stamp": 1552142961.09602} +{"itr": 66000, "IS_mean": 54.978721618652344, "IS_std": 0.9502049088478088, "FID": 15.81170654296875, "_stamp": 1552162403.2232807} +{"itr": 68000, "IS_mean": 55.73248291015625, "IS_std": 1.0323851108551025, "FID": 15.545623779296875, "_stamp": 1552181112.676657} +{"itr": 70000, "IS_mean": 56.78422927856445, "IS_std": 1.211003303527832, "FID": 15.28369140625, "_stamp": 1552199498.887533} +{"itr": 72000, "IS_mean": 57.972999572753906, "IS_std": 0.8668608665466309, "FID": 14.86395263671875, "_stamp": 1552217782.2738616} +{"itr": 74000, "IS_mean": 58.845054626464844, "IS_std": 1.4297977685928345, "FID": 14.620635986328125, "_stamp": 1552251085.1781816} +{"itr": 76000, "IS_mean": 59.60982131958008, "IS_std": 0.9095696210861206, "FID": 14.360198974609375, "_stamp": 1552270214.9345307} +{"itr": 78000, "IS_mean": 60.71195602416992, "IS_std": 0.960899829864502, "FID": 14.07183837890625, "_stamp": 1552288697.1580262} +{"itr": 80000, "IS_mean": 61.772125244140625, "IS_std": 0.6913255453109741, "FID": 13.781585693359375, "_stamp": 1552307170.0280282} +{"itr": 82000, "IS_mean": 62.98079299926758, "IS_std": 1.4735801219940186, "FID": 13.55389404296875, "_stamp": 1552325252.8553352} +{"itr": 84000, "IS_mean": 64.95240783691406, "IS_std": 0.9018951654434204, "FID": 13.231689453125, "_stamp": 1552344135.3111835} +{"itr": 86000, "IS_mean": 65.13968658447266, "IS_std": 0.8772205114364624, "FID": 13.176849365234375, "_stamp": 1552362429.6782444} +{"itr": 88000, "IS_mean": 65.84476470947266, "IS_std": 1.167534351348877, "FID": 12.87078857421875, "_stamp": 1552380560.7988124} +{"itr": 90000, "IS_mean": 67.41099548339844, "IS_std": 1.6899267435073853, "FID": 12.586517333984375, "_stamp": 1552398550.2060475} +{"itr": 92000, "IS_mean": 68.63685607910156, "IS_std": 1.9431978464126587, "FID": 12.49505615234375, "_stamp": 1552430781.6406457} +{"itr": 94000, "IS_mean": 70.09907531738281, "IS_std": 1.0715738534927368, "FID": 12.047607421875, "_stamp": 1552449001.1950285} +{"itr": 96000, "IS_mean": 70.34623718261719, "IS_std": 1.7962944507598877, "FID": 11.896697998046875, "_stamp": 1552466989.3587568} +{"itr": 98000, "IS_mean": 71.08210754394531, "IS_std": 1.458209753036499, "FID": 11.73046875, "_stamp": 1552484800.7138846} +{"itr": 100000, "IS_mean": 72.24256896972656, "IS_std": 1.3259714841842651, "FID": 11.7386474609375, "_stamp": 1552502538.0269725} +{"itr": 102000, "IS_mean": 73.19488525390625, "IS_std": 1.3439149856567383, "FID": 11.50494384765625, "_stamp": 1552523284.4514356} +{"itr": 104000, 
"IS_mean": 73.38243103027344, "IS_std": 1.4162707328796387, "FID": 11.374542236328125, "_stamp": 1552541012.0651608} +{"itr": 106000, "IS_mean": 74.95563507080078, "IS_std": 1.089124083518982, "FID": 11.10479736328125, "_stamp": 1552558577.7458107} +{"itr": 108000, "IS_mean": 76.42997741699219, "IS_std": 1.9282453060150146, "FID": 10.998870849609375, "_stamp": 1552576111.9480467} +{"itr": 110000, "IS_mean": 76.89225769042969, "IS_std": 1.4771150350570679, "FID": 10.847015380859375, "_stamp": 1552593659.445132} +{"itr": 112000, "IS_mean": 78.04684448242188, "IS_std": 1.4850096702575684, "FID": 10.772552490234375, "_stamp": 1552616479.5201895} +{"itr": 114000, "IS_mean": 79.67677307128906, "IS_std": 2.0147368907928467, "FID": 10.528045654296875, "_stamp": 1552633850.9315467} +{"itr": 116000, "IS_mean": 79.8828125, "IS_std": 0.978247344493866, "FID": 10.626068115234375, "_stamp": 1552651198.9012825} +{"itr": 118000, "IS_mean": 79.95381164550781, "IS_std": 1.8608143329620361, "FID": 10.46771240234375, "_stamp": 1552668560.4420238} +{"itr": 120000, "IS_mean": 82.37217712402344, "IS_std": 1.8909310102462769, "FID": 10.259033203125, "_stamp": 1552749673.4319007} +{"itr": 122000, "IS_mean": 83.49666595458984, "IS_std": 2.38446044921875, "FID": 9.996185302734375, "_stamp": 1552766698.2706933} +{"itr": 124000, "IS_mean": 83.05189514160156, "IS_std": 1.8844469785690308, "FID": 10.164398193359375, "_stamp": 1552783762.891172} +{"itr": 126000, "IS_mean": 84.27763366699219, "IS_std": 0.9329544901847839, "FID": 10.03509521484375, "_stamp": 1552800953.5724175} +{"itr": 128000, "IS_mean": 85.84852600097656, "IS_std": 2.2698562145233154, "FID": 9.91644287109375, "_stamp": 1552818112.227726} +{"itr": 130000, "IS_mean": 87.356689453125, "IS_std": 2.0958640575408936, "FID": 9.771148681640625, "_stamp": 1552837539.995247} +{"itr": 132000, "IS_mean": 88.72562408447266, "IS_std": 1.7551432847976685, "FID": 9.8258056640625, "_stamp": 1552859685.9305944} +{"itr": 134000, "IS_mean": 88.0631103515625, "IS_std": 1.8199039697647095, "FID": 9.957183837890625, "_stamp": 1552880037.5408435} +{"itr": 136000, "IS_mean": 91.50938415527344, "IS_std": 1.9926033020019531, "FID": 9.876556396484375, "_stamp": 1552899854.652669} +{"itr": 138000, "IS_mean": 93.09217834472656, "IS_std": 2.3062736988067627, "FID": 9.908477783203125, "_stamp": 1552921580.958927} \ No newline at end of file diff --git a/text2image/BigGAN_utils/logs/compare_IS.m b/text2image/BigGAN_utils/logs/compare_IS.m new file mode 100644 index 0000000..c72b079 --- /dev/null +++ b/text2image/BigGAN_utils/logs/compare_IS.m @@ -0,0 +1,89 @@ +clc +clear all +close all +fclose all; + + + +%% Get All logs and sort them +s = {}; +d = dir(); +j = 1; +for i = 1:length(d) + if any(strfind(d(i).name,'.jsonl')) + s = [s; d(i).name]; + end +end + + +j = 1; +for i = 1:length(s) + fname = s{i,1}; + % Check if the Inception metrics log exists, and if so, plot it + [itr, IS, FID, t] = process_inception_log(fname(1:end - 10), 'log.jsonl'); + s{i,2} = itr; + s{i,3} = IS; + s{i,4} = FID; + s{i,5} = max(IS); + s{i,6} = min(FID); + s{i,7} = t; +end +% Sort by Inception Score? +[IS_sorted, IS_index] = sort(cell2mat(s(:,5))); +% Cutoff inception scores below a certain value? +threshold = 22; +IS_index = IS_index(IS_sorted > threshold); + +% Sort by FID? +[FID_sorted, FID_index] = sort(cell2mat(s(:,6))); +% Cutoff also based on IS? +% threshold = 0; +FID_index = FID_index(IS_sorted > threshold); + + + +%% Plot things? 
+cc = hsv(length(IS_index)); +legend1 = {}; +legend2 = {}; +make_axis=true;%false % Turn this on to see the axis out to 1e6 iterations +for i=1:length(IS_index) + legend1 = [legend1; s{IS_index(i), 1}]; + figure(1) + plot(s{IS_index(i),2}, s{IS_index(i),3}, 'color', cc(i,:),'linewidth',2) + hold on; + xlabel('itr'); ylabel('IS'); + grid on; + if make_axis + axis([0,1e6,0,80]); % 50% grid on; + end + legend(legend1,'Interpreter','none') + %pause(1) % Turn this on to animate stuff + legend2 = [legend2; s{IS_index(i), 1}]; + figure(2) + plot(s{IS_index(i),2}, s{IS_index(i),4}, 'color', cc(i,:),'linewidth',2) + hold on; + xlabel('itr'); ylabel('FID'); + j = j + 1; + grid on; + if make_axis + axis([0,1e6,0,50]);% grid on; + end + legend(legend2, 'Interpreter','none') + +end + +%% Quick script to plot IS versus timesteps +if 0 + figure(3); + this_index=4; + subplot(2,1,1); + %plot(s{this_index, 2}(2:end), s{this_index, 7}(2:end) - s{this_index, 7}(1:end-1), 'r*'); + % xlabel('Iteration');ylabel('\Delta T') + plot(s{this_index, 2}, s{this_index, 7}, 'r*'); + xlabel('Iteration');ylabel('T') + subplot(2,1,2); + plot(s{this_index, 2}, s{this_index, 3}, 'r', 'linewidth',2); + xlabel('Iteration'), ylabel('Inception score') + title(s{this_index,1}) +end \ No newline at end of file diff --git a/text2image/BigGAN_utils/logs/metalog.txt b/text2image/BigGAN_utils/logs/metalog.txt new file mode 100644 index 0000000..9214b1b --- /dev/null +++ b/text2image/BigGAN_utils/logs/metalog.txt @@ -0,0 +1,3 @@ +datetime: 2019-03-18 13:27:59.181225 +config: {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 400, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': True, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)} +state: {'itr': 137500, 'epoch': 2, 'save_num': 0, 'save_best_num': 1, 'best_IS': 91.509384, 'best_FID': tensor(9.7711, 
'config': {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': False, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'BN_sync': False, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)}} diff --git a/text2image/BigGAN_utils/logs/process_inception_log.m b/text2image/BigGAN_utils/logs/process_inception_log.m new file mode 100644 index 0000000..42e7480 --- /dev/null +++ b/text2image/BigGAN_utils/logs/process_inception_log.m @@ -0,0 +1,19 @@ +function [itr, IS, FID, t] = process_inception_log(fname, which_log) +f = sprintf('%s_%s',fname, which_log);%'G_loss.log'); +fid = fopen(f,'r'); +itr = []; +IS = []; +FID = []; +t = []; +i = 1; +while ~feof(fid); + s = fgets(fid); + parsed = sscanf(s,'{"itr": %d, "IS_mean": %f, "IS_std": %f, "FID": %f, "_stamp": %f}'); + itr(i) = parsed(1); + IS(i) = parsed(2); + FID(i) = parsed(4); + t(i) = parsed(5); + i = i + 1; +end +fclose(fid); +end \ No newline at end of file diff --git a/text2image/BigGAN_utils/logs/process_training.m b/text2image/BigGAN_utils/logs/process_training.m new file mode 100644 index 0000000..bfc0788 --- /dev/null +++ b/text2image/BigGAN_utils/logs/process_training.m @@ -0,0 +1,109 @@ +clc +clear all +close all +fclose all; + + + +%% Get all training logs for a given run +target_dir = '.'; +s = {}; +nm = {}; +d = dir(target_dir); +j = 1; +for i = 1:length(d) + if any(strfind(d(i).name,'.log')) + s = [s; sprintf('%s\\%s', target_dir, d(i).name)]; + nm = [nm; d(i).name]; + end +end +%% Loop over training logs and acquire data +D_count = 0; +G_count = 0; +for i = 1:length(s) + fname = s{i,1}; + fid = fopen(s{i,1},'r'); + % Prepare bookkeeping for sv0 + if any(strfind(s{i,1},'sv')) + if any(strfind(s{i,1},'G_')) + G_count = G_count +1; + else + D_count = D_count + 1; + end + end + itr = []; + val = []; + j = 1; + while ~feof(fid); + line = fgets(fid); + parsed = sscanf(line, '%d: %e'); + itr(j) = parsed(1); + val(j) = 
parsed(2); + j = j + 1; + end + s{i,2} = itr; + s{i,3} = val; + fclose(fid); +end + +%% Plot SVs and losses +close all; +Gcc = hsv(G_count); +Dcc = hsv(D_count); +gi = 1; +di = 1; +li = 1; +legendG = {}; +legendD = {}; +legendL = {}; +thresh=2; % wavelet denoising threshold +losses = {}; +for i=1:length(s) + if any(strfind(s{i,1},'D_loss_real.log')) || any(strfind(s{i,1},'D_loss_fake.log')) || any(strfind(s{i,1},'G_loss.log')) + % Select colors + if any(strfind(s{i,1},'D_loss_real.log')) + color1 = [0.7,0.7,1.0]; + color2 = [0, 0, 1]; + dlr = {s{i,2}, s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2}; + losses = [losses; dlr]; + elseif any(strfind(s{i,1},'D_loss_fake.log')) + color1 = [0.7,1.0,0.7]; + color2 = [0, 1, 0]; + dlf = {s{i,2},s{i,3} wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2}; + losses = [losses; dlf]; + else % g loss + color1 = [1.0, 0.7,0.7]; + color2 = [1, 0, 0]; + gl = {s{i,2},s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1 color2}; + losses = [losses; gl]; + end + figure(1); hold on; + % Plot the unsmoothed losses; we'll plot the smoothed losses later + plot(s{i,2},s{i,3},'color', color1, 'HandleVisibility','off'); + legendL = [legendL; nm{i}]; + continue + end + if any(strfind(s{i,1},'G_')) + legendG = [legendG; nm{i}]; + figure(2); hold on; + plot(s{i,2},s{i,3},'color',Gcc(gi,:),'linewidth',2); + gi = gi+1; + elseif any(strfind(s{i,1},'D_')) + legendD = [legendD; nm{i}]; + figure(3); hold on; + plot(s{i,2},s{i,3},'color',Dcc(di,:),'linewidth',2); + di = di+1; + else + s{i,1} % Debug print to show the name of the log that was not processed. + end +end +figure(1); +% Plot the smoothed losses last +for i = 1:3 +% plot(losses{i,1}, losses{i,2},'color', losses{i,4}, 'HandleVisibility','off'); +plot(losses{i,1},losses{i,3},'color',losses{i,5}); +end +legend(legendL, 'Interpreter', 'none'); title('Losses'); xlabel('Generator itr'); ylabel('loss'); axis([0, max(s{end,2}), -1, 4]); + +figure(2); legend(legendG,'Interpreter','none'); title('Singular Values in G'); xlabel('Generator itr'); ylabel('SV0'); +figure(3); legend(legendD, 'Interpreter', 'none'); title('Singular Values in D'); xlabel('Generator itr'); ylabel('SV0'); diff --git a/text2image/BigGAN_utils/losses.py b/text2image/BigGAN_utils/losses.py new file mode 100644 index 0000000..6a7467f --- /dev/null +++ b/text2image/BigGAN_utils/losses.py @@ -0,0 +1,33 @@ +import torch +import torch.nn.functional as F + +# DCGAN loss +def loss_dcgan_dis(dis_fake, dis_real): + L1 = torch.mean(F.softplus(-dis_real)) + L2 = torch.mean(F.softplus(dis_fake)) + return L1, L2 + + +def loss_dcgan_gen(dis_fake): + loss = torch.mean(F.softplus(-dis_fake)) + return loss + + +# Hinge Loss +def loss_hinge_dis(dis_fake, dis_real): + loss_real = torch.mean(F.relu(1. - dis_real)) + loss_fake = torch.mean(F.relu(1. + dis_fake)) + return loss_real, loss_fake +# def loss_hinge_dis(dis_fake, dis_real): # This version returns a single loss + # loss = torch.mean(F.relu(1. - dis_real)) + # loss += torch.mean(F.relu(1. 
+ dis_fake)) + # return loss + + +def loss_hinge_gen(dis_fake): + loss = -torch.mean(dis_fake) + return loss + +# Default to hinge loss +generator_loss = loss_hinge_gen +discriminator_loss = loss_hinge_dis \ No newline at end of file diff --git a/text2image/BigGAN_utils/make_hdf5.py b/text2image/BigGAN_utils/make_hdf5.py new file mode 100644 index 0000000..b21d323 --- /dev/null +++ b/text2image/BigGAN_utils/make_hdf5.py @@ -0,0 +1,110 @@ +""" Convert dataset to HDF5 + This script preprocesses a dataset and saves it (images and labels) to + an HDF5 file for improved I/O. """ +import os +import sys +from argparse import ArgumentParser +from tqdm import tqdm, trange +import h5py as h5 + +import numpy as np +import torch +import torchvision.datasets as dset +import torchvision.transforms as transforms +from torchvision.utils import save_image +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +import utils + +def prepare_parser(): + usage = 'Parser for ImageNet HDF5 scripts.' + parser = ArgumentParser(description=usage) + parser.add_argument( + '--dataset', type=str, default='I128', + help='Which Dataset to train on, out of I128, I256, C10, C100;' + 'Append "_hdf5" to use the hdf5 version for ISLVRC (default: %(default)s)') + parser.add_argument( + '--data_root', type=str, default='data', + help='Default location where data is stored (default: %(default)s)') + parser.add_argument( + '--batch_size', type=int, default=256, + help='Default overall batchsize (default: %(default)s)') + parser.add_argument( + '--num_workers', type=int, default=16, + help='Number of dataloader workers (default: %(default)s)') + parser.add_argument( + '--chunk_size', type=int, default=500, + help='Default overall batchsize (default: %(default)s)') + parser.add_argument( + '--compression', action='store_true', default=False, + help='Use LZF compression? (default: %(default)s)') + return parser + + +def run(config): + if 'hdf5' in config['dataset']: + raise ValueError('Reading from an HDF5 file which you will probably be ' + 'about to overwrite! Override this error only if you know ' + 'what you''re doing!') + # Get image size + config['image_size'] = utils.imsize_dict[config['dataset']] + + # Update compression entry + config['compression'] = 'lzf' if config['compression'] else None #No compression; can also use 'lzf' + + # Get dataset + kwargs = {'num_workers': config['num_workers'], 'pin_memory': False, 'drop_last': False} + train_loader = utils.get_data_loaders(dataset=config['dataset'], + batch_size=config['batch_size'], + shuffle=False, + data_root=config['data_root'], + use_multiepoch_sampler=False, + **kwargs)[0] + + # HDF5 supports chunking and compression. You may want to experiment + # with different chunk sizes to see how it runs on your machines. + # Chunk Size/compression Read speed @ 256x256 Read speed @ 128x128 Filesize @ 128x128 Time to write @128x128 + # 1 / None 20/s + # 500 / None ramps up to 77/s 102/s 61GB 23min + # 500 / LZF 8/s 56GB 23min + # 1000 / None 78/s + # 5000 / None 81/s + # auto:(125,1,16,32) / None 11/s 61GB + + print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' 
% (config['dataset'], config['chunk_size'], config['compression'])) + # Loop over train loader + for i,(x,y) in enumerate(tqdm(train_loader)): + # Stick X into the range [0, 255] since it's coming from the train loader + x = (255 * ((x + 1) / 2.0)).byte().numpy() + # Numpyify y + y = y.numpy() + # If we're on the first batch, prepare the hdf5 + if i==0: + with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'w') as f: + print('Producing dataset of len %d' % len(train_loader.dataset)) + imgs_dset = f.create_dataset('imgs', x.shape,dtype='uint8', maxshape=(len(train_loader.dataset), 3, config['image_size'], config['image_size']), + chunks=(config['chunk_size'], 3, config['image_size'], config['image_size']), compression=config['compression']) + print('Image chunks chosen as ' + str(imgs_dset.chunks)) + imgs_dset[...] = x + labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(train_loader.dataset),), chunks=(config['chunk_size'],), compression=config['compression']) + print('Label chunks chosen as ' + str(labels_dset.chunks)) + labels_dset[...] = y + # Else append to the hdf5 + else: + with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'a') as f: + f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0) + f['imgs'][-x.shape[0]:] = x + f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0) + f['labels'][-y.shape[0]:] = y + + +def main(): + # parse command line and run + parser = prepare_parser() + config = vars(parser.parse_args()) + print(config) + run(config) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/sample.py b/text2image/BigGAN_utils/sample.py new file mode 100644 index 0000000..663880d --- /dev/null +++ b/text2image/BigGAN_utils/sample.py @@ -0,0 +1,183 @@ +''' Sample + This script loads a pretrained net and a weightsfile and sample ''' +import functools +import math +import numpy as np +from tqdm import tqdm, trange + + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P +import torchvision + +# Import my stuff +import inception_utils +import utils +import losses + + + +def run(config): + # Prepare state dict, which holds things like epoch # and itr # + state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, + 'best_IS': 0, 'best_FID': 999999, 'config': config} + + # Optionally, get the configuration from the state dict. This allows for + # recovery of the config provided only a state dict and experiment name, + # and can be convenient for writing less verbose sample shell scripts. 
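  # When --config_from_name is set, the experiment's saved state dict is
  # loaded from weights_root first and its stored 'config' entry overrides the
  # command-line config, except for a handful of sampling-time options
  # (z_var, base_root, batch sizes, use_ema, G_eval_mode) that are kept from
  # the command line.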
+ if config['config_from_name']: + utils.load_weights(None, None, state_dict, config['weights_root'], + config['experiment_name'], config['load_weights'], None, + strict=False, load_optim=False) + # Ignore items which we might want to overwrite from the command line + for item in state_dict['config']: + if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']: + config[item] = state_dict['config'][item] + + # update config (see train.py for explanation) + config['resolution'] = utils.imsize_dict[config['dataset']] + config['n_classes'] = utils.nclass_dict[config['dataset']] + config['G_activation'] = utils.activation_dict[config['G_nl']] + config['D_activation'] = utils.activation_dict[config['D_nl']] + config = utils.update_config_roots(config) + config['skip_init'] = True + config['no_optim'] = True + device = 'cuda' + + # Seed RNG + utils.seed_rng(config['seed']) + + # Setup cudnn.benchmark for free speed + torch.backends.cudnn.benchmark = True + + # Import the model--this line allows us to dynamically select different files. + model = __import__(config['model']) + experiment_name = (config['experiment_name'] if config['experiment_name'] + else utils.name_from_config(config)) + print('Experiment name is %s' % experiment_name) + + G = model.Generator(**config).cuda() + utils.count_parameters(G) + + # Load weights + print('Loading weights...') + # Here is where we deal with the ema--load ema weights or load normal weights + utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, + config['weights_root'], experiment_name, config['load_weights'], + G if config['ema'] and config['use_ema'] else None, + strict=False, load_optim=False) + # Update batch size setting used for G + G_batch_size = max(config['G_batch_size'], config['batch_size']) + z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], + device=device, fp16=config['G_fp16'], + z_var=config['z_var']) + + if config['G_eval_mode']: + print('Putting G in eval mode..') + G.eval() + else: + print('G is in %s mode...' % ('training' if G.training else 'eval')) + + #Sample function + sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) + if config['accumulate_stats']: + print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations']) + utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], + config['num_standing_accumulations']) + + + # Sample a number of images and save them to an NPZ, for use with TF-Inception + if config['sample_npz']: + # Lists to hold images and labels for images + x, y = [], [] + print('Sampling %d images and saving them to npz...' % config['sample_num_npz']) + for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))): + with torch.no_grad(): + images, labels = sample() + x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)] + y += [labels.cpu().numpy()] + x = np.concatenate(x, 0)[:config['sample_num_npz']] + y = np.concatenate(y, 0)[:config['sample_num_npz']] + print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) + npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name) + print('Saving npz to %s...' 
% npz_filename) + np.savez(npz_filename, **{'x' : x, 'y' : y}) + + # Prepare sample sheets + if config['sample_sheets']: + print('Preparing conditional sample sheets...') + utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], + num_classes=config['n_classes'], + samples_per_class=10, parallel=config['parallel'], + samples_root=config['samples_root'], + experiment_name=experiment_name, + folder_number=config['sample_sheet_folder_num'], + z_=z_,) + # Sample interp sheets + if config['sample_interps']: + print('Preparing interp sheets...') + for fix_z, fix_y in zip([False, False, True], [False, True, False]): + utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8, + num_classes=config['n_classes'], + parallel=config['parallel'], + samples_root=config['samples_root'], + experiment_name=experiment_name, + folder_number=config['sample_sheet_folder_num'], + sheet_number=0, + fix_z=fix_z, fix_y=fix_y, device='cuda') + # Sample random sheet + if config['sample_random']: + print('Preparing random sample sheet...') + images, labels = sample() + torchvision.utils.save_image(images.float(), + '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name), + nrow=int(G_batch_size**0.5), + normalize=True) + + # Get Inception Score and FID + get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid']) + # Prepare a simple function get metrics that we use for trunc curves + def get_metrics(): + sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) + IS_mean, IS_std, FID = get_inception_metrics(sample, config['num_inception_images'], num_splits=10, prints=False) + # Prepare output string + outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema') + outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else 'training') + outstring += 'with noise variance %3.3f, ' % z_.var + outstring += 'over %d images, ' % config['num_inception_images'] + if config['accumulate_stats'] or not config['G_eval_mode']: + outstring += 'with batch size %d, ' % G_batch_size + if config['accumulate_stats']: + outstring += 'using %d standing stat accumulations, ' % config['num_standing_accumulations'] + outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID) + print(outstring) + if config['sample_inception_metrics']: + print('Calculating Inception metrics...') + get_metrics() + + # Sample truncation curve stuff. This is basically the same as the inception metrics code + if config['sample_trunc_curves']: + start, step, end = [float(item) for item in config['sample_trunc_curves'].split('_')] + print('Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' 
% (start, step, end)) + for var in np.arange(start, end + step, step): + z_.var = var + # Optionally comment this out if you want to run with standing stats + # accumulated at one z variance setting + if config['accumulate_stats']: + utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], + config['num_standing_accumulations']) + get_metrics() +def main(): + # parse command line and run + parser = utils.prepare_parser() + parser = utils.add_sample_parser(parser) + config = vars(parser.parse_args()) + print(config) + run(config) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_BigGAN_bs256x8.sh b/text2image/BigGAN_utils/scripts/launch_BigGAN_bs256x8.sh new file mode 100644 index 0000000..0c4831c --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_BigGAN_bs256x8.sh @@ -0,0 +1,17 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \ +--num_G_accumulations 8 --num_D_accumulations 8 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho \ +--hier --dim_z 120 --shared_dim 128 \ +--G_eval_mode \ +--G_ch 96 --D_ch 96 \ +--ema --use_ema --ema_start 20000 \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--use_multiepoch_sampler \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_BigGAN_bs512x4.sh b/text2image/BigGAN_utils/scripts/launch_BigGAN_bs512x4.sh new file mode 100644 index 0000000..114023f --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_BigGAN_bs512x4.sh @@ -0,0 +1,17 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 512 --load_in_mem \ +--num_G_accumulations 4 --num_D_accumulations 4 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho \ +--hier --dim_z 120 --shared_dim 128 \ +--G_eval_mode \ +--G_ch 96 --D_ch 96 \ +--ema --use_ema --ema_start 20000 \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--use_multiepoch_sampler \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_BigGAN_ch64_bs256x8.sh b/text2image/BigGAN_utils/scripts/launch_BigGAN_ch64_bs256x8.sh new file mode 100644 index 0000000..650d2fc --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_BigGAN_ch64_bs256x8.sh @@ -0,0 +1,17 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \ +--num_G_accumulations 8 --num_D_accumulations 8 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho \ +--hier --dim_z 120 --shared_dim 128 \ +--G_eval_mode \ +--G_ch 64 --G_ch 64 \ +--ema --use_ema --ema_start 20000 \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--use_multiepoch_sampler \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_BigGAN_deep.sh 
b/text2image/BigGAN_utils/scripts/launch_BigGAN_deep.sh new file mode 100644 index 0000000..5e83ef1 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_BigGAN_deep.sh @@ -0,0 +1,18 @@ +#!/bin/bash +python train.py \ +--model BigGANdeep \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \ +--num_G_accumulations 8 --num_D_accumulations 8 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_ch 128 --D_ch 128 \ +--G_depth 2 --D_depth 2 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho \ +--hier --dim_z 128 --shared_dim 128 \ +--ema --use_ema --ema_start 20000 --G_eval_mode \ +--test_every 2000 --save_every 500 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--use_multiepoch_sampler \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_SAGAN_bs128x2_ema.sh b/text2image/BigGAN_utils/scripts/launch_SAGAN_bs128x2_ema.sh new file mode 100644 index 0000000..0e99a0e --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_SAGAN_bs128x2_ema.sh @@ -0,0 +1,13 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 128 \ +--num_G_accumulations 2 --num_D_accumulations 2 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_nl relu --D_nl relu \ +--SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \ +--G_ortho 0.0 \ +--G_init xavier --D_init xavier \ +--ema --use_ema --ema_start 2000 --G_eval_mode \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--name_suffix SAGAN_ema \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_SNGAN.sh b/text2image/BigGAN_utils/scripts/launch_SNGAN.sh new file mode 100644 index 0000000..a2c9d66 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_SNGAN.sh @@ -0,0 +1,14 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 64 \ +--num_G_accumulations 1 --num_D_accumulations 1 \ +--num_D_steps 5 --G_lr 2e-4 --D_lr 2e-4 --D_B2 0.900 --G_B2 0.900 \ +--G_attn 0 --D_attn 0 \ +--G_nl relu --D_nl relu \ +--SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \ +--G_ortho 0.0 \ +--D_thin \ +--G_init xavier --D_init xavier \ + --G_eval_mode \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--name_suffix SNGAN \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_cifar_ema.sh b/text2image/BigGAN_utils/scripts/launch_cifar_ema.sh new file mode 100644 index 0000000..07e15c2 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_cifar_ema.sh @@ -0,0 +1,11 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0,1 python train.py \ +--shuffle --batch_size 50 --parallel \ +--num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 \ +--num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 \ +--dataset C10 \ +--G_ortho 0.0 \ +--G_attn 0 --D_attn 0 \ +--G_init N02 --D_init N02 \ +--ema --use_ema --ema_start 1000 \ +--test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/sample_BigGAN_bs256x8.sh b/text2image/BigGAN_utils/scripts/sample_BigGAN_bs256x8.sh new file mode 100644 index 0000000..c228da6 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/sample_BigGAN_bs256x8.sh @@ -0,0 +1,20 @@ +# use z_var to change the variance of z 
for all the sampling +# use --mybn --accumulate_stats --num_standing_accumulations 32 to +# use running stats +python sample.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \ +--num_G_accumulations 8 --num_D_accumulations 8 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_ch 96 --D_ch 96 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho --skip_init \ +--hier --dim_z 120 --shared_dim 128 \ +--ema --ema_start 20000 \ +--use_multiepoch_sampler \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--skip_init --G_batch_size 512 --use_ema --G_eval_mode --sample_trunc_curves 0.05_0.05_1.0 \ +--sample_inception_metrics --sample_npz --sample_random --sample_sheets --sample_interps diff --git a/text2image/BigGAN_utils/scripts/sample_cifar_ema.sh b/text2image/BigGAN_utils/scripts/sample_cifar_ema.sh new file mode 100644 index 0000000..e6db0a1 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/sample_cifar_ema.sh @@ -0,0 +1,11 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0,1 python sample.py \ +--shuffle --batch_size 50 --G_batch_size 256 --parallel \ +--num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 \ +--num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 \ +--dataset C10 \ +--G_ortho 0.0 \ +--G_attn 0 --D_attn 0 \ +--G_init N02 --D_init N02 \ +--ema --use_ema --ema_start 1000 \ +--test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/utils/duplicate.sh b/text2image/BigGAN_utils/scripts/utils/duplicate.sh new file mode 100644 index 0000000..06da14a --- /dev/null +++ b/text2image/BigGAN_utils/scripts/utils/duplicate.sh @@ -0,0 +1,14 @@ +#duplicate.sh +source=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0 +target=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0A +logs_root=logs +weights_root=weights +echo "copying ${source} to ${target}" +cp -r ${logs_root}/${source} ${logs_root}/${target} +cp ${logs_root}/${source}_log.jsonl ${logs_root}/${target}_log.jsonl +cp ${weights_root}/${source}_G.pth ${weights_root}/${target}_G.pth +cp ${weights_root}/${source}_G_ema.pth ${weights_root}/${target}_G_ema.pth +cp ${weights_root}/${source}_D.pth ${weights_root}/${target}_D.pth +cp ${weights_root}/${source}_G_optim.pth ${weights_root}/${target}_G_optim.pth +cp ${weights_root}/${source}_D_optim.pth ${weights_root}/${target}_D_optim.pth +cp ${weights_root}/${source}_state_dict.pth ${weights_root}/${target}_state_dict.pth \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/utils/prepare_data.sh b/text2image/BigGAN_utils/scripts/utils/prepare_data.sh new file mode 100644 index 0000000..e75c660 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/utils/prepare_data.sh @@ -0,0 +1,3 @@ +#!/bin/bash +python make_hdf5.py --dataset I128 --batch_size 256 --data_root data +python calculate_inception_moments.py --dataset I128_hdf5 --data_root data \ No newline at end of file diff --git a/text2image/BigGAN_utils/sync_batchnorm/__init__.py b/text2image/BigGAN_utils/sync_batchnorm/__init__.py new file mode 100644 index 0000000..bc8709d --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ 
+# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3fb1abe1a3d347d8747277f14f4885e10107086 GIT binary patch literal 348 zcmZ`!O-sZu5Ka2ER@s}N_zTQsi+UB&DtIfx9)w;(nnW8+(vqf!^&j~&{1rm3p8N}* zoPfKx9hk?PdAv7Fy<8q5sQKlq`@{%+$6;FukXz{K96%s}H8QxxCTSCswkeLgL}w;% zb5pbh%;_4D><4lXNyEu{V|nietKOxkndf~oICCNP2$%bWD?yTQ`oKJUVvcDo^|SK* zJ+O}Pw{SMbXzJvs6=z-zmhykCA&XFUQMPc|#eZBHAp;2I)*@XVmj%EFa%JQZe#xq; p)$8M3Luaq_R88;?-NjXuDyuxDt8o(!uFW83EpLv%IK@N9_zziiWD5WQ literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-38.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78a61f6a3ab8d38e3822ae6bc8ce706a4eb77a3d GIT binary patch literal 383 zcmZ{f!Ait15QdYsTdlG;LGcOfr9r)l=(?z&ZHnvP!Sfedgi{o6(;RaSXQt9~7}&rByq NEzftscZmPk!{4rUZ=nDH literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aee8b6a18ba04e189d0da32a9754b558481140ac GIT binary patch literal 13007 zcmeHN%X1vZd7sC=u>?U9e9EEQc4T20VEGc06pDgGf|lee8BvK8RbEjw7)4mmrs z?3pD2Xt$CWOxXv^6<2bJ4=y0L`~f-So>LCFrfO~}r^>}+DwmWm$?xl)9V|$KB4aVq z0kbtT-90_sUw_YEe>>L?9jX~P2JL@a`|hG){DvyyqlCgcIOBi8g&W)qjF!HeEmPI4 zo^2VX(JQoyCOxr&Qm@=9tFj$bdev67S8LT!FYsb8)vLGaD3^GdS2}j<5U=vuBcnCV zr+6Lr!@}5};fMG%Kg?%3_Kwv$qFRoiLqau z^;y(U@G|OURX>jUNnSy{qUtBOdDCc|dW`iLD-AQPIoI99>#jty*Rax~9|=D2#5K;B zBR3GSCwN-DE+dzFZk#+e@e;jrb0hSAga#qglT_}?DD*c4&xg9m_jzz>zXzAON2Ak0 zcN^z2&UhA=p>Y?i=hhb{?l#@Agu>?Q@u0_|2ST!>D_Gl?al+h?vjCJa?)OBP_)#b` zqHVN1{vRAzfQ>hCD3oWL#v##>`FhYQRyw z8d_W8`hi-9oWe+;ygCTu`vW01MZ=Vb(56@@cQ27UF8dG^xKVI(o0soPR zW49w@o!Sa9mXua#T``22PCHt}S*Nt@hYy@M;T}5YMQNYSD6<)0b_qg4S*j-)4?XX?{$|Fip>pZ(*1|HsmAfAx2C z{L|aFyka&r-N{Z#Wnvhxkr;_NG`6g}*3cYULwi`@<`)IiNNgxlVGGZkBDZiYaY)4} zt80ZDB$Jhfo!X)wr%$5-6#G<9N{a7(|c;OA@{OHZE`AjFnq6Pg!-VWR}d$Z$F#0&G9-`ysgZ8QLI>tJVHzMIyR4|2#IrFA(7som)bH%zC-wFjk*Db?)6T^A;EoSLgy%b?e)WkvEh1zLbSNj*w<5o0~+%*yRZ zK6E9&fIe{(7sINX1@ovmV^*?X!LC?m>{-<6cFjCvojP7Co--@f_pMob^XxvFpAGAg z=b(e`bS~qJFX6(;R~YaihC|wo5Ck1#D8aKhb~rsW@7mkYaOf@h6zxH7>4;NO@|$4s z)PyH!O-s@7DS&3|C3ZJ#%~0D*)>~yS>Tkd=JjsjKoF^QO6#c;W+(bA|NOWtfhsA>r zHEZ+IK5Ku@FrJEM7>fc+6{q6IzoL-9W+@N@P1s*pBkkhRc7v?Y0;(xK0=KeSS=ID2 zJ281>w~pS0#JX$4D^yP#ywt{-Ee=b=!j6rW@~Dy&;d5-VX?TTd z#~9Vtmy#0sfg8qWm(XWwR3Fy1%fmX2msEz;;S{gIYSoq3P8xq@oi=`9v>_?exc|4G z7{eO8{nY0B8trvIe4D^~-zx5nM25nlQ) zoW|zTeiD6F0BT+hBF_!tB^2~Pn8v;*HgbZASfWS~Ou?lCl0Taz=gSQU2jV8?{nEM* zmkkxAzAJJLS4As8qt(11m&SFy018CEt6dGiFJj@+pQEk-NXgiI=l@Foc(9ywvQ>&% ziu5xmDAJD->ClVH2wL?^`+l7W-{mWoa2;J0#i@CB_koU{r0aiyGov=2i`cZ!T*M|V z-;7D-q=@BPcoX8zCweCmj`>r1hj4@y~F9+cNJJ;+X_ce#`s` z)NN?LMj(d(XJOVjMWCigy-pcuFHsryGKG`7XrllE=Y;aGaMzwSPNM8s6HhUQdRoe# z7KQ~l!4g~`fW=8;$81||z% 
zW`mHRh~VTzpP+1>J?#2kmq~Gd-~;6{0x4{XU2HZg>GZWY7Lpe511Te^vCKR~kQT_> z$ptDFlSq2qv?$~IGEpG8QBhhbKfq(HSOguEcC|`i0zgu$rfG7hTp{YH^rTw7hu8iF zXG}Xd!zR2@$FE|YHP4!xbNi{&lO_Y7a~khU@=>=C0#;yg*gu4B4^bOiaHkjXK|tlY z;~@M5`iX+_lD0JRDqR>|h$L+9H1PLC5RAb;1sm)J3HobSMyBw6oC*?D%;{=vsZy)F zrnWC0+Vg|#osO8{l5gV7maVpN7PlW&S_QR9D?;fCvW#x-Z6Jeda1lO4Ho0eGO8krx z3U1(UYVA|IbeCP?ERyUpXA1s7gRt9$_QgU#Z?X4P3fKeOT~I)`SyK(oCiE=nE-vaS z>F5S~i`{8U*ZX{(EkJwS@YD6rMxV1gqynGcb|tFpkJugToIKV%zb>!M{{{X?VYbU$ zWl6*yO1H152}SZSlJ~M<6;0!z=f|sfBW6C0OnX}`8rn>dpft$Q84)sX;bP6&urWF4 z)C2ATLg*K4XkP$H>ds?(c0DMk3J_dTr)@&d2_K+FgY1>75pKLeOUX%_XKMqU2-!tl ztVf#~h!5I>!3Pw3nrfIIhCtFTpT`4CMG7^MlVb2#F&xOeqGDpWhdxAs(K2Jz6lq1p zgE8k=tRw+p0e(myWI>i-Pj3(;J{UO`W+uFqNTEeHnP=EtprcrL@57Ht>z6pWR&07YX=!~0(S#r-V?B4=CW*m%LEWOya6Im8@M~prKgsznFcc9 z%6&IPNJ4mn3F@7)3V$Htp6Dk!9=ecT+l3_AINL+RUT9w2OPF}z1_QMZ9(+VYO;=7qSud*rrPyr+8th#op@0VO)hpa z^R&39rhO!!hFI@I;#jd3X(sY~1R7;kIdYXyWb1oJ`E(sZ;7Rwh)%*>JdJ(cFb_<5& zD29etAf@3VjIbIP`lb0x;1k>4fVi^e2Y#}leF05M#|-b<$qX+3LE1ak?qb*ZSQ{CX_l<07`YdP@$( z>gqh}_~ZtN4z2vTk>!KZ)Oz4uyNf3X7>}hTdXthbyCPj+KcrO6B3o4PRB6vLvD?U^n}*+eTx|7K_V=xK5vau8S+v%_1rY)sG~GcG zb)?IiKp#Fi#yQJ{$bLIrfuJ!CoMnG2~Th z3`aV~$g#GaZyE#!kUxOHDi?SFfe#??ZYVy0z+Vdlet({Q^0FxSW*^x6Re<0eiBK$8 z$*lb`@H*?Nzv_UE6&B9lJwU?$|44YAX+6q?@NDQi?)Eic-Ak_s>*{g8VyrvQE?=2v zZ(jL5FfaWFJArx60_R#UfOGAY$As=Szmr4UH=hq~<&V)_w&MUyfqso$u$Ulg1z!G2zCs2zo7^g0Q&&J9w69* zB691YgKLkz5sBqT=#U=MiFf>0-X4VNfg7Y&7^bz| zx;%%TtuuK~{ZDM%R9kiW^G;Ioz4fiU!}!DQrZXNSR^LhKub~R`pPT*=9GPhR(}qkm zeKMjyo2}9396TA|?_&LNEqy7aKEqJSbE&o)it%~QT4%4-r~H+xnz$wU_CKaC^eR>b bUp3Xu8GPDZvn$0)5ud`^GsBs;XWsr_mky!u literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86d8d62fab651c238c5997217975b3541a0b8cac GIT binary patch literal 13039 zcmeHO%X1sad7lT~geXd)p6keMJN6PA5Y%H;Nf<3x61m#licN1>T5SwV84jnLz(|7` zsAnLNT%eMeO4*03D^5TH4%q`_qxpYkBlFA{@@9PGJASJ4#*vjsK zU`?aDr|0YL@A>Nw!%vSKsTlY){`DW+Z_XRWFQ_s8DWGu;SNtz1xWUc9XzFjXX{xr> zu`RPw?p{2A>wiExUOF?UTHKc0rs# zdm8Oiyoh#DwNIjbnwQWnsrD&u-ZE-uo?$=6a?MOD&J8#5+RKsb)U5RQ4+QUd;yUMx zksFBE6Fe>7kdezhH%^|Jc!-|4wGn#XLx+&*X{xtn6#AQjXQVFj18!V8=*DI45%tlB z>Nc)LT=6uDzHuMy=hhb{zHR!(HgcQaj(Z&zJra^7ZNb)j87IsQISa5l#{G^66F&;& zaN#z3p8X0RY{JHqPtCNDh%k=iQ$rrbEdTxSTs!KBxpgn_dvhqdPD_MyH+!*IlEUrG zHRf*ktq-p+JH5mY;<*@`cUFmkAs9H9)f(N6wCG?4XDvBHTd+_VmOO^?;C)WvZr|9( zK6gx%)(-8_-mJTvvrs%{TnzUUGdIy22IKXHY5R%m2WqQw0+g`FaxaV@_Jr6J@(6m= z7Ii^WrAM6XKGykL=o+o_gucDl_5=Ps5yx&z$SU>aVr(ic(avH>Je_p39t;PiML&Gx z#0mE>c+V#{Y{l155e8XL)8a>eVyu~D8BfYhn@^EdJTaTOp;hFER#6~}fOQmqV&sfw z2~0Vbo^oEOc)$FQ_Ah?=kN@@W3%~r^zoE}Ry?@)wk1d4iWYnZvF)ZFljKu64Th@K6 zZ}zRe-Oq9Ji=1gBHZ(4`g?mn(TPO=0vT=&4EOCP*vs|-NTXf^J)Pd532{==5*Vg<{ z+^m(;eBA2_DbHe(wB$H`=qHXNY4Y@>!`;Mnx~_DCKm<-l+fvn0`joZ&I2)2yv#R6s zbi#=P_mObAepl4;kctS_(rQThCfuYaMJ&IDi6qHXo<{*qFFTG0vxu<>QgulxAQ_6c zamC6HnrE!4RWJ+Y=GR}$-o|Jj%ic(!c-SD-AOgx-@50!~T4fcb8Q)2)Z5zA=H*w{0 z6>t@CQC|sHd8ahU^(}5ZD9amt)J_{{+p7Hov~AVS-7@;-c4fyHzStK!ILdhjh@~+itZ04`V-i5UU`K4vNm$kWo7Bl$h*`|$#3B+EhW;0u7$1C zyqQ*%M{>wHrB&Grso&`cH%uo+tw*j*3y+20YA0!K)puj9G^9SzE3)z2BVRmLu1o$l zP5Dg}%?h3$>a8UIEWb^?W$IO`DU}N;j@%FFV67^@gReC2hut0~Yhj)8BK6o2+7qhV ziQ6vB=9Ho~?CEuym0^{-GH~)&z-lNeeOq!Zs2=vA{-~Ab!%S$KLB1-tj$XY?EfWXJQeRi zi#`;3%D5Cy2uLNcWQyegCH5EAKnDQyEsp>+h_e8L;5|0A2=s&G!ra+>VW&;-Hk|;fZXrboh#L%NSIW0{Mm|5+1mF;4`O4O24zuceT6(!KYwY#^$yE>n=v*_H2MGmq@D(oI^!z;5cmnI#L=y zzkY5%tkUv#IgE{bfT3#Lsd<0*hK{f|aqmxY#U&K9X1eywN^Cn|R$`l$xn|l^Q!!00 
z;1Rh*g$m`n@tkJ+Gh7*m4xWiztlHn?NI>_8qLClrR?SpS_`q4V#%Iw1i6jZ&QvjY$ z1Xr3ViKS`Y?RJBW^q4a=1{Hz`CQ90OS{6VqLc#$7KXtnzpBG&gOWAb=%Z)?KZSZfci)~iPNVKvV|PJA z-7REybNw8=VFA7nxZ||3W3E|blAC|AIOeV>&T4;_;7o;hI!i7FE-VV!V?SxL=lxxS z-ER9a>xiTs@d_J$;5j1JLAb{WC5)9T>_m@%MHoP=#NmR0A~{G6eR0qsPu5VnzWGsh zsvP;!nDs(JDuRP@fGv5d@?2fpRJ{G=vtb&M^ev6gnuGljX`n zsZx4N^?T^w^C7FJPs~tcr#-W2tMi`4?I)#XPMzlRP`jKgVwigeAmKV3g%6>PAMTh^ z07fYaH}E&L{;4y+%Pw&iNp_htg%6=aI2c3!Vj-ZpI3_C<>;MnXDkR*fsE$ShnwGTZ z=XH~`bc4Oe?ygDK`+S|vLVw-xN9*?npR>E91fSn=C7SGa*j*g={2&Sk1NIJUxt)%S z&Q+1Pnvzz2#e8HR*4BiK*_wi}z_P%z{)Vzv^7jU{OQ zC&Q!Y3QHpPSh`)cnouNQ*Ys~98q*oP=UZD_7~ zMOsmDW5hWYD@j0Dz#!7cvLH*ar_&1(AB-Fcb11y!NP$N;sWTinFi^~X_{pcF^^2c; zI;+@)+#Cd})~O92i)h2M#w2ww0YfpRP#YVn*>kFx@kz0BIQ4Dv$S;Wjq>Mm?UM~RS zIK#1RR;ApCIVvfa33%56u_(&VY15$^8IMxyK)5n=SHS?viy(p^srN{DaK^Es6Rl}i zzQ|^v!_$mE#mT_4wCKv>iq3^t2%7$-u3V~U@^!0|+fA~aB#A4WmcQnm%rL#xR;Zi`MgLN069qiJ_|UAE$R zMVegfczSAnPfPnqfDN(Ug~YLAEz+ULGYUjyMLBYnePru8tN z&>rw~yho{2HsdV&zRKFX(o<}Pd>2z$AqtX8F=~1-ALy7ireX^V3+!9$Gs33d&D;sa zHswyVAE6D?xDF!zILycl=pvuFc;P}VTkkGW936gbx;Xi(6!CjE5vau8S+v%_2N41u zG~8YiwWP}%Kp#Fi#u?o&;WopWOBd>lmbsULKHOkuotYu0XuEsTF{4vDVYHYIl1QN% zJ2;&`u;n|$xX029`lx0dQcgQjo7*7j`Z%`xob=**hFr9L|w(uSxcMc_U5L=M% zAPg$6P@%Bt?$P&;i5{*|DjR9+!mkR}=F9;?Zahc4AU`9ze?-L(iCPgUuE+x*Cfc$| zg_>_4G$!==_qf!dyJY5m*F$5^hj>GuSH;70M0Br>;bEnmFNVPHj=be_;1~F21bme; z!<~+ha;z=rs|JJt=MRCfN)R3b;X@$28;}oy@RtIFKdQ6OUKb1B>H?v^2`HT55sK(4 zsdX?CUT1CfiyjcNLd4m#hj{q^9}m}=)}u@aF9yJ)VP68`z4V3K|75n9jd(m#Gb_q z6hjbOeL0A|#GeDPa)qeA2&K}*|Ay-~hf?WBU3cTvQ0kS}L8&81`xQm00MUmi^$?{V zrfUz=wO@sF?NWvuUMp#PCkl}A#4ixuSlX7bC6adv#748p8_VlvSm>3&EJuF54J>;I zYU#(juSAl!&N8*jWO)#-%|?C+(Dw2hg0`CO8-li!+5OVsExnpD25+yt9=s(?_;Psr z!&ip4mtP0o(!`8N6Ff%}_V;S1(+Mxa`_&%3?-9#S(Vw2sDRlhmz19oW4L3-wFib1E zZCS_2=DBR7{!KM*sN**M3@0gt-uzlNVDuh$!x<5Y)q79+?WP?4dZyp|Mm`z8nIWG{ z{|BJokgd?G89eFXhhP03Exq}qUaC+jaH)<9imAJpUe)NOjI?xBFK&Te_>bw8yOLGH bTS`@P3U6{(>{7mzKWJ&VT;CfBblLwPWEo8~t-jopF!t;Ory`Dk%R)IPX_y0A|h zyuq7e`^*_M^<9g%c?VBiJZWM63U@Kz6)ntf%2OSgXg|Vg$!tF@pEwT2#30lL0F_Pl(h#gDg4nI)Od~8 zajuh!I5&6;;@XnFn?K%W>7Eom-%|cau#uL{!pSre6A`C=8YXd%{ZR?dk~lhMzQS2- zrO6gfYzn=4fsKXeM`stu6CCO;25v1Z&r0o6=ZvmgT-6rVzI$fluFfr}VfVX@tm(PxIQHGs+8JE(8eVu|y|n&}GwgNx^Dn-W;xJ5R ziW#ffp0B=6*{%>V8wsc^c&}TREr^kI4B6VYQbLmTzm+nPlJYof#7P>Cj=PRr(YIzT z?F}=m(qnfzOFzSfqSY+BW#53k^YwGqVoDjpbB6vBpJ+qBS{f=p#}nt&CDX|G6PT|J zMhaKVL|AJ}3yj9KBR@5#N!hVT8DuI7DJk2rJ5J-W!kE_YHYu2~KM|gnbv$p9 z@L5FH*F5j-%#ZRnP0!;=;CXTlE6N)*Q0Oa?NtPsE>3X0U+E*>R=Ctaq#$e4<+)FVI zRVt4qpaT){%CdayE0ku)Q7qvo6for#&axZ}-@sHqo`d zr3e8SrSL2c15=S@X3{4$96B>MaX|q)rTlXi_Q*N2b(nN6o;c}muG#ZXN)PqQS$zn*O{;%O z135!((C`x)bRd?44hdP;!=cDT3Q2%-`J=awaHyvku6FB1;NxuoU=zr^f^$o>fy1jG ztq=B~!`I-C>yiZd3}vt#Gk+Hm_k)xTm&iM0$?m=g(q5;-@Y`ZfXA*D)5EUrrX_NkeWKV|_Q@A|<3OGZri!5%YdjIS}z#GL~x0->1N zOJ)(*Lg_0ulsY4YydU!J5Q_nzcT+!v+A$O^SYyc2q>?)VDx`Z_Gm~At%Vr?;j(o|X zwNjgT5;CI(3**D&0NM;Gre=kq(RLz~^2c~JlF8(P?%M2K%3*jRNCn^2Qb}hL0&anl zKKw^DCM50)#}w*qkMJB50ude-?k*&5R?I2%G~QxkA2E>neHg>8iLM^)EH1F(Bm{{; zcD_`ujaE*F1+R2i&V|D0R8}BCCLK?3sCyVR15{bBc3M9pkZf?`h9;-VbLa0(22mDm z%dvW>rP3Y!jBCBZ4G-a+;+ph@FHIyk^R;L8eC?Lf!G%Rvg~RAiQl9m@t!zDyV<}AI zK~{f3FqSodrgUD(mj(%`Y~>GFsday^bz3#=dWL|q*CU(!U>nP z=QKezh>StUpDYV*NygPSqjWfNA-Q$m3aN>5`toZ$>DrQ+6OrQunM4tw5QXLjgwe+Y zMaxIl~t!kZpw@>RQ|tny6Bf7npvvCd)ac zCN3tIn*2^{vTw=zaD?u*)Sc8d{Rju@d&CMl3vDZHe6UWOw7GCTklSht`x`55QO@q$ zc;@Qd{8yQbhW89wRjVM!Kgy!Z*}G&zByP->x0`IHK+fnR!yFzpx?x6Ty_W+0Z>%Wi zt2kzGmmTeez{vp1q6lU=b}Y&962Rq%P+5Q8)2U4=pod~XB=ll{2f?1T0rXVA-nE9*PLy6&OQX%*SBN^J# zD-?`-Ka7KD#!I@s1cVcB8y6lE*P&sX<#;T6rDEcW~oXJ0( 
z4|T}lUh&SSjcLW;5-a|XSM}>^04r?NZCM|QmP0l*YZfu0tigtX~TEpqs>-J5IcbqlsrX2gWvyNkatBN9JA6M=H z$*^<$ni73rbBn6PzRtcZK(klytNaBP_$v-WX3qR(MPBKlK74o8UVu=%8t*K0=*ajc#^(!vW&d|*S5hCr*HFUPx*ukx|T>RZ={S?ylHs=dkk1!Jd zbnw7L*fBN?vls;L(8iIcGo+c^qPgTt4dZ%iaFy7s^^IxUe?!ht>lj*acgyLx>vz`d I8}54Nzv)H6%m4rY literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-38.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0143006e3de71be017a68ebb26f0f5e4a6f1af5c GIT binary patch literal 4819 zcmai2+in}j8QufWI@t9gvD~znq%LB%64eF>P#8gB*{TB+wHnJxQV1=UtDPaa(sGxc zSxOc`-PlHY*;kOzo4T*kH(2zpS9yhA<@;xL$)zN;6lQjIb`Jk(zV9D>y|U7>@az2R zAHM%1%lZ$!lz%n6JVsH+sJO-L$l5ngd*9aY&c4GP?($0C-FLsVc$GWvEbfTPTX(-I z?DHD0@p|9BaQ5r^t-+hTg&r3@4UAvnF2=i}f$>eeFY^lCD|&v5+b^y5%JHS$a97SR%;?z&WB<`@UmEbIiqf_Q9 zJjF(uY~aC0kgXHgSg80>+eLYbqCP>zt*PZ%seSHTP|HQDGPMrf3mb1$ZoRbL+dJ)A zR`+DCwUH-``&s1=NpO@^y@(#or%B91k_SBj+VAbGqD0htZ^>0`^xy5xK{6DZuY)KY zZ=xD`eGzXy8!NF51&5oR%_m|1%dH)6oQ9FwRIsRb;HSYLPULViNQT4C=rn71UKod| z=iMa@C^#&;X1DCg-Lh{w#i)R+Sp2BnMA6Fs1`Q}`{q3PmPk801q3);f%v<}7Jtg%~ z4&C$0g=JgcTAy1ltX=C*Fc`e!>`Z?4l@!NeGFHqOh7El6Y03_Sh*?iSY{5J2s-*oS ziJ7f#DJ3Ks=yy^kQZ7+vEl$#~ciML3ihea~X@i;S26elmae5yO)j(z04f_^!p4?o} z7(*63P0G^A(#k$TLo55)Tv_=MdYp5Y)FT6&!30`q`&-M&)%`n^ovSUYSf*W>>tuAO zocLi{Xy8JFE8_P>$}m+eqayKJcFU<<@+KWi8JeCVU)Gd)f0|jr}P9 zQujQb1fC~vVnum}st>86NGnxjJ$;aFp`oqUuq#fZ+NkZXn&Y}F`k_kYsl1I5*{U}? z(@%Vba~*P=valBlnN?qn#aILww(ui6KSojaP+fn#5inKh;`8JIp#|$Dk_|OjY$$>d zM!`Lf!@wNpg_xyj6%_5pbu`G0=Mz`tpvC&Vc zMVe}BwOtc+?ojnJI4-?pxvVM2v8P4+HRj~ecMI-YwkNk|mg>w#?O-!9gIuGEB3AyG zs*kC<&NZZ62St&JtB%=`m*-w#sOokMt+B?pHhNIR_-!*7Fak0tO9NGx#JM^WymbwXR}<9_P1 zkuUuy5)nHIqll$4?DvIae#`=NANau$OL|QB!GM{N#%zo;{^vlFKq#gL$vEPAQ<}xP zQb(td4@2JWVlklnLF$JPJBGjo{d8HHlwwFggmj=KGf~#Lss=*v#FrdWE47{nBQs*K zFg{L>AWfJ2YF6kPX@^27zmHEnnG7$4}kBgnXj!g@U>}XMlQ6v&Lu_%+U40!yOG_@eL34q<3d(_L9mw9 zfUfjh%9jRBsch+wl9Q{p!vN_HG5~{>z-646@Ysi!3*5hPNt4oX*C0$dtu3c1tVWa# z(!M=EbaTqSoUhD`PIN}5J+wk(W$&FG*+x&>mY?7uiCix2N~8f&QOJ`(Av+kLXfey4 ze7Laje9RS(6IN)si9x2;IlRv8HH#3IG6DGubVCACuP{Yb?{|$TY9`7H48LBK`IrS! 
zTuhjW@_W3H734$II&)L^Qr9Fd93rmI0bCrQL_%;vA_Y^=c9$VPO+k z(C;(k-ju3KI^$VY3`gmy#LdK#pXv@;KQ?-NR5t{AjgmI=5>fIa;+RYo#b%oUsrblwMM5rsMRVXlVSw6^Jr~z76-2yAc4} z)?eb3?yO-Zl_N*(rgiS>e7J(S5B}_&xo6cWZZ1cT z9K8Alt2CzWnswMbZ%wOut#3+ueSp34K;$;E>+jO~8~fFF`5u)u%A~74$B~@H-*X~2 zq+Ie;XGz4a23QKvr_V&9WdfT6-5!lENa_U8#d@JsaC$*XhPm_#1?fHv;~*OI8RKY= z=G(Qwdmu1~5zj|bg(+RpI&5!%<56fciu?cvEzfSW1!p1>0u)So0<&ma zknOsc2%y_x&yeSWg8eYUYGcYA*0Hf*U60~=Ad@(p2>yibcst~R=8r}&rN#<#1JPVGgA0>v)BU7f)5-RJYwoISC5$OS5j+Gzsmu092 zYP9AFSr4BTCD_(#&1nHEZlS*CEaP(J*ms;&l*!#PddyO}g-aqSM~tYEgXlFS>c9;c z%%(4UPUc1;*2s4q!OG7u=dUP+%$&)c(w;Lb?ZAH5T?_Qx?U$G>H>moIsz<1@#z-b3 z+?7xD`Xw{yJA}?Bhz#ipuV@k699oT~hKs*jjgK80`{_H+&+$#xQQ_Yg9(V}X$eQ65 ygWoN>P4aY%V3WV1u^&-IAg=ES*KM3NzA*XwZ|VM?@7}N*PRqS{4{mU8wtfKg7tM|U literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb083c2d85da8b8b59ee3e5dbbdc522985dba335 GIT binary patch literal 3409 zcmc&$&u`l{6ecBEcHB5EvY}m9bmKYO&1itaKbLDQc@fdvE7A%nL^$Ds z^aeVN)J=5|2VM*Pi4)noS%e}Q{R6YcC0yTvu0jL!2d@e=7utPjY6#s3_Te%aQ%y7- z(L?KqOz4D;EpB1ERe$2g0t2EOJvvyV?Ua;+Nhc(pLWI6{(cxmtR+|& z?MgT<0aH}i$zxE1){Q)l%n2}U5Nn)XQe7~>ncqv|4(kmf9rh!LBk6}iv7`-a zhN{Q|q=kp=wvZ45<6XlfC<4j}zm6>&ZxzgWA@Rz|8iHB_lvsnxiZQg2E( zc@DmMTV0JprGuo`U%!r~(5_2fuAHeKjnycD0Mt6%6kn%T&H(Xdy`0*Sh*JmWr{xR+ zDlI{7J0hJEF+foQtkZJWSH9L#o(D(j6^WrW&t)n05QfcaEORkYO$X4aQW~aiW-+DD zQ`&ncMv1z71)5rbj?^mj0$sEg=|x(lCo~Q7iI-x~O*O4)zb0*J^rF&qUK58jk>8Fi z!8G44NRE=q3PsK-HcQN7I_BWo_-A6!J|PALyb5OfZG>V=Ahf1JkpZg^i=9G)#13l@ z;s9$y!J8R$ziay6ipImC-vc(QuZ%Y+i|Z^?_?WG+EztU|OycmR;P+7SH-U>Qd}-y@ zG828ZwjyqP32M`Bhp||D1lSYx_u4tFur1|FfeaX_05wxpXo{93`eZ<6?hIJjZXELdH| z5ab21zcU*vl+f055$Fcn;M^!~5%6pk&i3&;CEA0C?RQ1Y4w3=u8W|oao%9fz?-~vf z(pW`j%f^3?^uLAfb$0sMt}rmm1ygLH-iqBpxaWg4!#bST@;(J`f@XM+q zzgA99@3df^u&`No(`u-~SZO~Fgv2P)g$<>IL{WYyWg;`;%^9Iik_c+%zV!PF3VMyt zo*Csej36iDXLxUf0R4LeO(CVp!h+>m7v1`t`~seupXx*g-9{WYij!yKq)p=xNB|sG x;N8h+|Re2qs%ps*>&D$2mpX*$3Dwi&N_!q;W$14B; literal 0 HcmV?d00001 diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-38.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f88646f8f4faf00ff1a470fc1a85199182a04028 GIT binary patch literal 3468 zcmc&$U2oh(6rHiXyUsS7Mx~;(1v+Yxh*7#}Dv%IrqNYt#fy4?at%?>Z$MMYWy7k9q zW}J`8d0{KHJo3s@@1P1{5WB=Blx?eX}|y?4$zcOJG{O$1l_ z*I&G!rxE&1JxpE}JluvR9zaJC#Ubjder)UkvYs`T025r(hjN25iBj?QTU>S8BoKLTM()-5aUKp}4KYJJcH$DI9 z-6WQ5SLPqEUFN5fxttBd!1rX3#I9dX68U?RUq1@up!z<)H1woDc6S%IrE~0_QIW@e}ARprZNb{g4m0FcT>3>Cp=oHqGj#jo_67a za?gOy#=dg5x@@?qL%R!24572DvZ_l8?u%)s5LV9TGlwIXJ%Hvh%)OCAh5Nzb(7X$E9>4@ z;8&Btb+6bf;>Dy20Je)%+eJ29m%?tl!*aWib_Y_qL0~5H`~fMFOMFO^WRz^~k%R~R zAojvZtL>&<^1*;_c)qe7XaGARh0h5sRfbb@@48?#>fG8Sc3Obv6p) zokTJs2cDcKo*)3a(l@y0u!=~cgbxRSh~~)-gDE2DmobQR^LiOa;RKl0fYwgWi2)d( z#EX)+PogxG!7zk4l3~CENqVqmAgVk-T4>PgF%B^h%C-Za38!5K&y%!2D5jSS%umcN z4f?PjdPM>rtL*duVwy>15+5YLlftq>jTOkHx#S}!=~U`-Xr$c^@m9J^yS@-y&N-neCvQZiRZ270!5s>rl&g$aZq zmf@!QKEHellsir%H$xWZmYSb83LuEQ1{vYNDko68v{>xQgVJ4Tybo!g6p0s z674E5cAGIT)kbtJz^KS^kk<=~IewDk=(Q-his~{nF#{cH*?1P8GiLF5+{AAPpUSGb z1c7cWd|mSe>S3)IcGr4=EYwv@w|EVvd1gg-rM;w*?viYi6g?tl3A2@dCJVh|vQPo7 z;<(pSpxk5%WMj!FAXdr7Mx{n-r%mu3NHP z5uCpd-EC+>RYZi(VH-cYdQyJcM=LsY-hlaB)lxnMUEY9fBtFlkngz- zMH)(YhFp_E&K%2YRm}T3iJ!8)>-vbJsx8|Ep-~Ljn2Z)m>gE&X%Qjh|R4Z^5@?=#m 
zcJrZNy)+~{0~V9LBqaka%c+n_qyYN5;XnafRn^Hd^55hAZ()3yoO~3l3`{aY7hkBq zY%2}6JfHyey0rw0jm6zZZlu89i z$J7Mlim@fLx~(LnHn;F-}>JX<_D$>Rcg%tFVh=gmL_v5;Qu zGp=IEXI6w@TnY3&&J$h`a7v4Il7vu0cepncP~uy3@+8SWQW0oUUQ>6Plu8BGpb0gO zW@e1KVdI&aGsQoK=k67maDUK_<92l(Z6CL(9RewUBMcn%>jjzg5kn_W3l~?5Oq5{l gf=-uxtEke;>Pei3&Nik^1MA 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + return mean, torch.rsqrt(bias_var + self.eps) + # return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. 
math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. 
The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) \ No newline at end of file diff --git a/text2image/BigGAN_utils/sync_batchnorm/batchnorm_reimpl.py b/text2image/BigGAN_utils/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000..7afcdaf --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability.
+ + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/text2image/BigGAN_utils/sync_batchnorm/comm.py b/text2image/BigGAN_utils/sync_batchnorm/comm.py new file mode 100644 index 0000000..922f8c4 --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. 
+ - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine the message to be passed + back to each slave device. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register a slave device. + + Args: + identifier: an identifier, usually the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages are first collected from each device (including the master device), and then + a callback is invoked to compute the message to be sent back to each device + (including the master device). + + Args: + master_msg: the message that the master wants to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belong to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/text2image/BigGAN_utils/sync_batchnorm/replicate.py b/text2image/BigGAN_utils/sync_batchnorm/replicate.py new file mode 100644 index 0000000..b71c7b8 --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+ + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/text2image/BigGAN_utils/sync_batchnorm/unittest.py b/text2image/BigGAN_utils/sync_batchnorm/unittest.py new file mode 100644 index 0000000..bed56f1 --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y), message) + diff --git a/text2image/BigGAN_utils/train.py b/text2image/BigGAN_utils/train.py new file mode 100644 index 0000000..4fb804c --- /dev/null +++ b/text2image/BigGAN_utils/train.py @@ -0,0 +1,227 @@ +""" BigGAN: The Authorized Unofficial PyTorch release + Code by A. 
Brock and A. Andonian + This code is an unofficial reimplementation of + "Large-Scale GAN Training for High Fidelity Natural Image Synthesis," + by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096). + + Let's go. +""" + +import os +import functools +import math +import numpy as np +from tqdm import tqdm, trange + + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P +import torchvision + +# Import my stuff +import inception_utils +import utils +import losses +import train_fns +from sync_batchnorm import patch_replication_callback + +# The main training file. Config is a dictionary specifying the configuration +# of this training run. +def run(config): + + # Update the config dict as necessary + # This is for convenience, to add settings derived from the user-specified + # configuration into the config-dict (e.g. inferring the number of classes + # and size of the images from the dataset, passing in a pytorch object + # for the activation specified as a string) + config['resolution'] = utils.imsize_dict[config['dataset']] + config['n_classes'] = utils.nclass_dict[config['dataset']] + config['G_activation'] = utils.activation_dict[config['G_nl']] + config['D_activation'] = utils.activation_dict[config['D_nl']] + # By default, skip init if resuming training. + if config['resume']: + print('Skipping initialization for training resumption...') + config['skip_init'] = True + config = utils.update_config_roots(config) + device = 'cuda' + + # Seed RNG + utils.seed_rng(config['seed']) + + # Prepare root folders if necessary + utils.prepare_root(config) + + # Setup cudnn.benchmark for free speed + torch.backends.cudnn.benchmark = True + + # Import the model--this line allows us to dynamically select different files. + model = __import__(config['model']) + experiment_name = (config['experiment_name'] if config['experiment_name'] + else utils.name_from_config(config)) + print('Experiment name is %s' % experiment_name) + + # Next, build the model + G = model.Generator(**config).to(device) + D = model.Discriminator(**config).to(device) + + # If using EMA, prepare it + if config['ema']: + print('Preparing EMA for G with decay of {}'.format(config['ema_decay'])) + G_ema = model.Generator(**{**config, 'skip_init':True, + 'no_optim': True}).to(device) + ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start']) + else: + G_ema, ema = None, None + + # FP16? + if config['G_fp16']: + print('Casting G to float16...') + G = G.half() + if config['ema']: + G_ema = G_ema.half() + if config['D_fp16']: + print('Casting D to fp16...') + D = D.half() + # Consider automatically reducing SN_eps? 
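+  # The G_D wrapper below packs the generator and discriminator into a single
+  # nn.Module so one DataParallel replication parallelizes the whole
+  # G-then-D forward pass across GPUs. A minimal sketch of that wrapper pattern
+  # (assumed, simplified interface; the actual class comes from the model file
+  # selected by config['model'] and additionally handles split_D and fp16):
+  #
+  #   class G_D(nn.Module):
+  #       def __init__(self, G, D):
+  #           super(G_D, self).__init__()
+  #           self.G, self.D = G, D
+  #       def forward(self, z, gy, x=None, dy=None, train_G=False, split_D=False):
+  #           # Generate fakes; only build a graph through G when training G
+  #           with torch.set_grad_enabled(train_G):
+  #               G_z = self.G(z, self.G.shared(gy))
+  #           D_fake = self.D(G_z, gy)
+  #           if x is None:
+  #               return D_fake
+  #           return D_fake, self.D(x, dy)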
+ GD = model.G_D(G, D) + print(G) + print(D) + print('Number of params in G: {} D: {}'.format( + *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]])) + # Prepare state dict, which holds things like epoch # and itr # + state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, + 'best_IS': 0, 'best_FID': 999999, 'config': config} + + # If loading from a pre-trained model, load weights + if config['resume']: + print('Loading weights...') + utils.load_weights(G, D, state_dict, + config['weights_root'], experiment_name, + config['load_weights'] if config['load_weights'] else None, + G_ema if config['ema'] else None) + + # If parallel, parallelize the GD module + if config['parallel']: + GD = nn.DataParallel(GD) + if config['cross_replica']: + patch_replication_callback(GD) + + # Prepare loggers for stats; metrics holds test metrics, + # lmetrics holds any desired training metrics. + test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], + experiment_name) + train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name) + print('Inception Metrics will be saved to {}'.format(test_metrics_fname)) + test_log = utils.MetricsLogger(test_metrics_fname, + reinitialize=(not config['resume'])) + print('Training Metrics will be saved to {}'.format(train_metrics_fname)) + train_log = utils.MyLogger(train_metrics_fname, + reinitialize=(not config['resume']), + logstyle=config['logstyle']) + # Write metadata + utils.write_metadata(config['logs_root'], experiment_name, config, state_dict) + # Prepare data; the Discriminator's batch size is all that needs to be passed + # to the dataloader, as G doesn't require dataloading. + # Note that at every loader iteration we pass in enough data to complete + # a full D iteration (regardless of number of D steps and accumulations) + D_batch_size = (config['batch_size'] * config['num_D_steps'] + * config['num_D_accumulations']) + loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, + 'start_itr': state_dict['itr']}) + + # Prepare inception metrics: FID and IS + get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid']) + + # Prepare noise and randomly sampled label arrays + # Allow for different batch sizes in G + G_batch_size = max(config['G_batch_size'], config['batch_size']) + z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], + device=device, fp16=config['G_fp16']) + # Prepare a fixed z & y to see individual sample evolution throghout training + fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, + config['n_classes'], device=device, + fp16=config['G_fp16']) + fixed_z.sample_() + fixed_y.sample_() + # Loaders are loaded, prepare the training function + if config['which_train_fn'] == 'GAN': + train = train_fns.GAN_training_function(G, D, GD, z_, y_, + ema, state_dict, config) + # Else, assume debugging and use the dummy train fn + else: + train = train_fns.dummy_training_function() + # Prepare Sample function for use with inception metrics + sample = functools.partial(utils.sample, + G=(G_ema if config['ema'] and config['use_ema'] + else G), + z_=z_, y_=y_, config=config) + + print('Beginning training at epoch %d...' % state_dict['epoch']) + # Train for specified number of epochs, although we mostly track G iterations. + for epoch in range(state_dict['epoch'], config['num_epochs']): + # Which progressbar to use? TQDM or my own? 
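+    # Each pass through this loader yields one full D iteration's worth of data:
+    # D_batch_size = batch_size * num_D_steps * num_D_accumulations images.
+    # Worked example with the settings in scripts/launch_BigGAN_bs256x8.sh
+    # (batch_size 256, num_D_steps 1, num_D_accumulations 8):
+    #   256 * 1 * 8 = 2048 images consumed per G iteration,
+    # and with num_G_accumulations 8 the effective G batch is likewise 256 * 8 = 2048.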
+ if config['pbar'] == 'mine': + pbar = utils.progress(loaders[0],displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta') + else: + pbar = tqdm(loaders[0]) + for i, (x, y) in enumerate(pbar): + # Increment the iteration counter + state_dict['itr'] += 1 + # Make sure G and D are in training mode, just in case they got set to eval + # For D, which typically doesn't have BN, this shouldn't matter much. + G.train() + D.train() + if config['ema']: + G_ema.train() + if config['D_fp16']: + x, y = x.to(device).half(), y.to(device) + else: + x, y = x.to(device), y.to(device) + metrics = train(x, y) + train_log.log(itr=int(state_dict['itr']), **metrics) + + # Every sv_log_interval, log singular values + if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])): + train_log.log(itr=int(state_dict['itr']), + **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')}) + + # If using my progbar, print metrics. + if config['pbar'] == 'mine': + print(', '.join(['itr: %d' % state_dict['itr']] + + ['%s : %+4.3f' % (key, metrics[key]) + for key in metrics]), end=' ') + + # Save weights and copies as configured at specified interval + if not (state_dict['itr'] % config['save_every']): + if config['G_eval_mode']: + print('Switchin G to eval mode...') + G.eval() + if config['ema']: + G_ema.eval() + train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, + state_dict, config, experiment_name) + + # Test every specified interval + if not (state_dict['itr'] % config['test_every']): + if config['G_eval_mode']: + print('Switchin G to eval mode...') + G.eval() + train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample, + get_inception_metrics, experiment_name, test_log) + # Increment epoch counter at end of epoch + state_dict['epoch'] += 1 + + +def main(): + # parse command line and run + parser = utils.prepare_parser() + config = vars(parser.parse_args()) + print(config) + run(config) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/train_fns.py b/text2image/BigGAN_utils/train_fns.py new file mode 100644 index 0000000..e03e530 --- /dev/null +++ b/text2image/BigGAN_utils/train_fns.py @@ -0,0 +1,187 @@ +''' train_fns.py +Functions for the main loop of training different conditional image models +''' +import torch +import torch.nn as nn +import torchvision +import os + +import utils +import losses + + +# Dummy training function for debugging +def dummy_training_function(): + def train(x, y): + return {} + return train + + +def GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config): + def train(x, y): + G.optim.zero_grad() + D.optim.zero_grad() + # How many chunks to split x and y into? 
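+    # x and y arrive holding num_D_steps * num_D_accumulations chunks of
+    # config['batch_size'] samples each (see D_batch_size in train.py);
+    # torch.split turns them into tuples of per-accumulation minibatches and
+    # `counter` walks through those chunks. Illustrative numbers from the
+    # bs256x8 script (batch_size 256, 1 D step, 8 accumulations):
+    #
+    #   chunks = torch.split(torch.randn(2048, 3, 128, 128), 256)  # 8 tensors of shape (256, 3, 128, 128)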
+ x = torch.split(x, config['batch_size']) + y = torch.split(y, config['batch_size']) + counter = 0 + + # Optionally toggle D and G's "require_grad" + if config['toggle_grads']: + utils.toggle_grad(D, True) + utils.toggle_grad(G, False) + + for step_index in range(config['num_D_steps']): + # If accumulating gradients, loop multiple times before an optimizer step + D.optim.zero_grad() + for accumulation_index in range(config['num_D_accumulations']): + z_.sample_() + y_.sample_() + D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']], + x[counter], y[counter], train_G=False, + split_D=config['split_D']) + + # Compute components of D's loss, average them, and divide by + # the number of gradient accumulations + D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real) + D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations']) + D_loss.backward() + counter += 1 + + # Optionally apply ortho reg in D + if config['D_ortho'] > 0.0: + # Debug print to indicate we're using ortho reg in D. + print('using modified ortho reg in D') + utils.ortho(D, config['D_ortho']) + + D.optim.step() + + # Optionally toggle "requires_grad" + if config['toggle_grads']: + utils.toggle_grad(D, False) + utils.toggle_grad(G, True) + + # Zero G's gradients by default before training G, for safety + G.optim.zero_grad() + + # If accumulating gradients, loop multiple times + for accumulation_index in range(config['num_G_accumulations']): + z_.sample_() + y_.sample_() + D_fake = GD(z_, y_, train_G=True, split_D=config['split_D']) + G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations']) + G_loss.backward() + + # Optionally apply modified ortho reg in G + if config['G_ortho'] > 0.0: + print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G + # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this + utils.ortho(G, config['G_ortho'], + blacklist=[param for param in G.shared.parameters()]) + G.optim.step() + + # If we have an ema, update it, regardless of if we test with it or not + if config['ema']: + ema.update(state_dict['itr']) + + out = {'G_loss': float(G_loss.item()), + 'D_loss_real': float(D_loss_real.item()), + 'D_loss_fake': float(D_loss_fake.item())} + # Return G's loss and the components of D's loss. + return out + return train + +''' This function takes in the model, saves the weights (multiple copies if + requested), and prepares sample sheets: one consisting of samples given + a fixed noise seed (to show how the model evolves throughout training), + a set of full conditional sample sheets, and a set of interp sheets. ''' +def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, + state_dict, config, experiment_name): + utils.save_weights(G, D, state_dict, config['weights_root'], + experiment_name, None, G_ema if config['ema'] else None) + # Save an additional copy to mitigate accidental corruption if process + # is killed during a save (it's happened to me before -.-) + if config['num_save_copies'] > 0: + utils.save_weights(G, D, state_dict, config['weights_root'], + experiment_name, + 'copy%d' % state_dict['save_num'], + G_ema if config['ema'] else None) + state_dict['save_num'] = (state_dict['save_num'] + 1 ) % config['num_save_copies'] + + # Use EMA G for samples or non-EMA? + which_G = G_ema if config['ema'] and config['use_ema'] else G + + # Accumulate standing statistics? 
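+  # "Standing" stats replace the usual running batchnorm statistics with
+  # statistics averaged over several fresh forward passes taken right before
+  # sampling (cf. --accumulate_stats / --num_standing_accumulations in the
+  # sample scripts). A rough sketch of what utils.accumulate_standing_stats is
+  # assumed to do (names and details are illustrative, not the exact implementation):
+  #
+  #   def accumulate_standing_stats(net, z, y, n_classes, num_accumulations=16):
+  #       net.train()                      # let BN layers accumulate statistics
+  #       for _ in range(num_accumulations):
+  #           with torch.no_grad():
+  #               z.normal_()
+  #               y.random_(0, n_classes)
+  #               net(z, net.shared(y))
+  #       net.eval()                       # sample using the accumulated stats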
+ if config['accumulate_stats']: + utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G, + z_, y_, config['n_classes'], + config['num_standing_accumulations']) + + # Save a random sample sheet with fixed z and y + with torch.no_grad(): + if config['parallel']: + fixed_Gz = nn.parallel.data_parallel(which_G, (fixed_z, which_G.shared(fixed_y))) + else: + fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y)) + if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)): + os.mkdir('%s/%s' % (config['samples_root'], experiment_name)) + image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'], + experiment_name, + state_dict['itr']) + torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename, ## NOTE: xcliu for torchvision 0.8.2 + nrow=int(fixed_Gz.shape[0] **0.5), normalize=True) + #torchvision.utils.save_image(torch.from_numpy(fixed_Gz.float().cpu().numpy()), image_filename, + # nrow=int(fixed_Gz.shape[0] **0.5), normalize=True) + + # For now, every time we save, also save sample sheets + utils.sample_sheet(which_G, + classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], + num_classes=config['n_classes'], + samples_per_class=10, parallel=config['parallel'], + samples_root=config['samples_root'], + experiment_name=experiment_name, + folder_number=state_dict['itr'], + z_=z_) + # Also save interp sheets + for fix_z, fix_y in zip([False, False, True], [False, True, False]): + utils.interp_sheet(which_G, + num_per_sheet=16, + num_midpoints=8, + num_classes=config['n_classes'], + parallel=config['parallel'], + samples_root=config['samples_root'], + experiment_name=experiment_name, + folder_number=state_dict['itr'], + sheet_number=0, + fix_z=fix_z, fix_y=fix_y, device='cuda') + + + +''' This function runs the inception metrics code, checks if the results + are an improvement over the previous best (either in IS or FID, + user-specified), logs the results, and saves a best_ copy if it's an + improvement. ''' +def test(G, D, G_ema, z_, y_, state_dict, config, sample, get_inception_metrics, + experiment_name, test_log): + print('Gathering inception metrics...') + if config['accumulate_stats']: + utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G, + z_, y_, config['n_classes'], + config['num_standing_accumulations']) + IS_mean, IS_std, FID = get_inception_metrics(sample, + config['num_inception_images'], + num_splits=10) + print('Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID)) + # If improved over previous best metric, save approrpiate copy + if ((config['which_best'] == 'IS' and IS_mean > state_dict['best_IS']) + or (config['which_best'] == 'FID' and FID < state_dict['best_FID'])): + print('%s improved over previous best, saving checkpoint...' 
% config['which_best']) + utils.save_weights(G, D, state_dict, config['weights_root'], + experiment_name, 'best%d' % state_dict['save_best_num'], + G_ema if config['ema'] else None) + state_dict['save_best_num'] = (state_dict['save_best_num'] + 1 ) % config['num_best_copies'] + state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean) + state_dict['best_FID'] = min(state_dict['best_FID'], FID) + # Log results to file + test_log.log(itr=int(state_dict['itr']), IS_mean=float(IS_mean), + IS_std=float(IS_std), FID=float(FID)) diff --git a/text2image/BigGAN_utils/utils.py b/text2image/BigGAN_utils/utils.py new file mode 100644 index 0000000..9d79f33 --- /dev/null +++ b/text2image/BigGAN_utils/utils.py @@ -0,0 +1,1195 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +''' Utilities file +This file contains utility functions for bookkeeping, logging, and data loading. +Methods which directly affect training should either go in layers, the model, +or train_fns.py. +''' + +from __future__ import print_function +import sys +import os +import numpy as np +import time +import datetime +import json +import pickle +from argparse import ArgumentParser +import animal_hash +import datasets as dset + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +def prepare_parser(): + usage = 'Parser for all scripts.' + parser = ArgumentParser(description=usage) + + ### Dataset/Dataloader stuff ### + parser.add_argument( + '--dataset', type=str, default='I128_hdf5', + help='Which Dataset to train on, out of I128, I256, C10, C100;' + 'Append "_hdf5" to use the hdf5 version for ISLVRC ' + '(default: %(default)s)') + parser.add_argument( + '--augment', action='store_true', default=False, + help='Augment with random crops and flips (default: %(default)s)') + parser.add_argument( + '--num_workers', type=int, default=8, + help='Number of dataloader workers; consider using less for HDF5 ' + '(default: %(default)s)') + parser.add_argument( + '--no_pin_memory', action='store_false', dest='pin_memory', default=True, + help='Pin data into memory through dataloader? (default: %(default)s)') + parser.add_argument( + '--shuffle', action='store_true', default=False, + help='Shuffle the data (strongly recommended)? (default: %(default)s)') + parser.add_argument( + '--load_in_mem', action='store_true', default=False, + help='Load all data into memory? (default: %(default)s)') + parser.add_argument( + '--use_multiepoch_sampler', action='store_true', default=False, + help='Use the multi-epoch sampler for dataloader? (default: %(default)s)') + + + ### Model stuff ### + parser.add_argument( + '--model', type=str, default='BigGAN', + help='Name of the model module (default: %(default)s)') + parser.add_argument( + '--G_param', type=str, default='SN', + help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)' + ' or None (default: %(default)s)') + parser.add_argument( + '--D_param', type=str, default='SN', + help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)' + ' or None (default: %(default)s)') + parser.add_argument( + '--G_ch', type=int, default=64, + help='Channel multiplier for G (default: %(default)s)') + parser.add_argument( + '--D_ch', type=int, default=64, + help='Channel multiplier for D (default: %(default)s)') + parser.add_argument( + '--G_depth', type=int, default=1, + help='Number of resblocks per stage in G? 
(default: %(default)s)') + parser.add_argument( + '--D_depth', type=int, default=1, + help='Number of resblocks per stage in D? (default: %(default)s)') + parser.add_argument( + '--D_thin', action='store_false', dest='D_wide', default=True, + help='Use the SN-GAN channel pattern for D? (default: %(default)s)') + parser.add_argument( + '--G_shared', action='store_true', default=False, + help='Use shared embeddings in G? (default: %(default)s)') + parser.add_argument( + '--shared_dim', type=int, default=0, + help='G''s shared embedding dimensionality; if 0, will be equal to dim_z. ' + '(default: %(default)s)') + parser.add_argument( + '--dim_z', type=int, default=128, + help='Noise dimensionality: %(default)s)') + parser.add_argument( + '--z_var', type=float, default=1.0, + help='Noise variance: %(default)s)') + parser.add_argument( + '--hier', action='store_true', default=False, + help='Use hierarchical z in G? (default: %(default)s)') + parser.add_argument( + '--cross_replica', action='store_true', default=False, + help='Cross_replica batchnorm in G?(default: %(default)s)') + parser.add_argument( + '--mybn', action='store_true', default=False, + help='Use my batchnorm (which supports standing stats?) %(default)s)') + parser.add_argument( + '--G_nl', type=str, default='relu', + help='Activation function for G (default: %(default)s)') + parser.add_argument( + '--D_nl', type=str, default='relu', + help='Activation function for D (default: %(default)s)') + parser.add_argument( + '--G_attn', type=str, default='64', + help='What resolutions to use attention on for G (underscore separated) ' + '(default: %(default)s)') + parser.add_argument( + '--D_attn', type=str, default='64', + help='What resolutions to use attention on for D (underscore separated) ' + '(default: %(default)s)') + parser.add_argument( + '--norm_style', type=str, default='bn', + help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], ' + 'ln [layernorm], gn [groupnorm] (default: %(default)s)') + + ### Model init stuff ### + parser.add_argument( + '--seed', type=int, default=0, + help='Random seed to use; affects both initialization and ' + ' dataloading. 
(default: %(default)s)') + parser.add_argument( + '--G_init', type=str, default='ortho', + help='Init style to use for G (default: %(default)s)') + parser.add_argument( + '--D_init', type=str, default='ortho', + help='Init style to use for D(default: %(default)s)') + parser.add_argument( + '--skip_init', action='store_true', default=False, + help='Skip initialization, ideal for testing when ortho init was used ' + '(default: %(default)s)') + + ### Optimizer stuff ### + parser.add_argument( + '--G_lr', type=float, default=5e-5, + help='Learning rate to use for Generator (default: %(default)s)') + parser.add_argument( + '--D_lr', type=float, default=2e-4, + help='Learning rate to use for Discriminator (default: %(default)s)') + parser.add_argument( + '--G_B1', type=float, default=0.0, + help='Beta1 to use for Generator (default: %(default)s)') + parser.add_argument( + '--D_B1', type=float, default=0.0, + help='Beta1 to use for Discriminator (default: %(default)s)') + parser.add_argument( + '--G_B2', type=float, default=0.999, + help='Beta2 to use for Generator (default: %(default)s)') + parser.add_argument( + '--D_B2', type=float, default=0.999, + help='Beta2 to use for Discriminator (default: %(default)s)') + + ### Batch size, parallel, and precision stuff ### + parser.add_argument( + '--batch_size', type=int, default=64, + help='Default overall batchsize (default: %(default)s)') + parser.add_argument( + '--G_batch_size', type=int, default=0, + help='Batch size to use for G; if 0, same as D (default: %(default)s)') + parser.add_argument( + '--num_G_accumulations', type=int, default=1, + help='Number of passes to accumulate G''s gradients over ' + '(default: %(default)s)') + parser.add_argument( + '--num_D_steps', type=int, default=2, + help='Number of D steps per G step (default: %(default)s)') + parser.add_argument( + '--num_D_accumulations', type=int, default=1, + help='Number of passes to accumulate D''s gradients over ' + '(default: %(default)s)') + parser.add_argument( + '--split_D', action='store_true', default=False, + help='Run D twice rather than concatenating inputs? (default: %(default)s)') + parser.add_argument( + '--num_epochs', type=int, default=100, + help='Number of epochs to train for (default: %(default)s)') + parser.add_argument( + '--parallel', action='store_true', default=False, + help='Train with multiple GPUs (default: %(default)s)') + parser.add_argument( + '--G_fp16', action='store_true', default=False, + help='Train with half-precision in G? (default: %(default)s)') + parser.add_argument( + '--D_fp16', action='store_true', default=False, + help='Train with half-precision in D? (default: %(default)s)') + parser.add_argument( + '--D_mixed_precision', action='store_true', default=False, + help='Train with half-precision activations but fp32 params in D? ' + '(default: %(default)s)') + parser.add_argument( + '--G_mixed_precision', action='store_true', default=False, + help='Train with half-precision activations but fp32 params in G? ' + '(default: %(default)s)') + parser.add_argument( + '--accumulate_stats', action='store_true', default=False, + help='Accumulate "standing" batchnorm stats? (default: %(default)s)') + parser.add_argument( + '--num_standing_accumulations', type=int, default=16, + help='Number of forward passes to use in accumulating standing stats? ' + '(default: %(default)s)') + + ### Bookkeping stuff ### + parser.add_argument( + '--G_eval_mode', action='store_true', default=False, + help='Run G in eval mode (running/standing stats?) at sample/test time? 
' + '(default: %(default)s)') + parser.add_argument( + '--save_every', type=int, default=2000, + help='Save every X iterations (default: %(default)s)') + parser.add_argument( + '--num_save_copies', type=int, default=2, + help='How many copies to save (default: %(default)s)') + parser.add_argument( + '--num_best_copies', type=int, default=2, + help='How many previous best checkpoints to save (default: %(default)s)') + parser.add_argument( + '--which_best', type=str, default='IS', + help='Which metric to use to determine when to save new "best"' + 'checkpoints, one of IS or FID (default: %(default)s)') + parser.add_argument( + '--no_fid', action='store_true', default=False, + help='Calculate IS only, not FID? (default: %(default)s)') + parser.add_argument( + '--test_every', type=int, default=5000, + help='Test every X iterations (default: %(default)s)') + parser.add_argument( + '--num_inception_images', type=int, default=50000, + help='Number of samples to compute inception metrics with ' + '(default: %(default)s)') + parser.add_argument( + '--hashname', action='store_true', default=False, + help='Use a hash of the experiment name instead of the full config ' + '(default: %(default)s)') + parser.add_argument( + '--base_root', type=str, default='', + help='Default location to store all weights, samples, data, and logs ' + ' (default: %(default)s)') + parser.add_argument( + '--data_root', type=str, default='data', + help='Default location where data is stored (default: %(default)s)') + parser.add_argument( + '--weights_root', type=str, default='weights', + help='Default location to store weights (default: %(default)s)') + parser.add_argument( + '--logs_root', type=str, default='logs', + help='Default location to store logs (default: %(default)s)') + parser.add_argument( + '--samples_root', type=str, default='samples', + help='Default location to store samples (default: %(default)s)') + parser.add_argument( + '--pbar', type=str, default='mine', + help='Type of progressbar to use; one of "mine" or "tqdm" ' + '(default: %(default)s)') + parser.add_argument( + '--name_suffix', type=str, default='', + help='Suffix for experiment name for loading weights for sampling ' + '(consider "best0") (default: %(default)s)') + parser.add_argument( + '--experiment_name', type=str, default='', + help='Optionally override the automatic experiment naming with this arg. ' + '(default: %(default)s)') + parser.add_argument( + '--config_from_name', action='store_true', default=False, + help='Use a hash of the experiment name instead of the full config ' + '(default: %(default)s)') + + ### EMA Stuff ### + parser.add_argument( + '--ema', action='store_true', default=False, + help='Keep an ema of G''s weights? (default: %(default)s)') + parser.add_argument( + '--ema_decay', type=float, default=0.9999, + help='EMA decay rate (default: %(default)s)') + parser.add_argument( + '--use_ema', action='store_true', default=False, + help='Use the EMA parameters of G for evaluation? 
(default: %(default)s)') + parser.add_argument( + '--ema_start', type=int, default=0, + help='When to start updating the EMA weights (default: %(default)s)') + + ### Numerical precision and SV stuff ### + parser.add_argument( + '--adam_eps', type=float, default=1e-8, + help='epsilon value to use for Adam (default: %(default)s)') + parser.add_argument( + '--BN_eps', type=float, default=1e-5, + help='epsilon value to use for BatchNorm (default: %(default)s)') + parser.add_argument( + '--SN_eps', type=float, default=1e-8, + help='epsilon value to use for Spectral Norm(default: %(default)s)') + parser.add_argument( + '--num_G_SVs', type=int, default=1, + help='Number of SVs to track in G (default: %(default)s)') + parser.add_argument( + '--num_D_SVs', type=int, default=1, + help='Number of SVs to track in D (default: %(default)s)') + parser.add_argument( + '--num_G_SV_itrs', type=int, default=1, + help='Number of SV itrs in G (default: %(default)s)') + parser.add_argument( + '--num_D_SV_itrs', type=int, default=1, + help='Number of SV itrs in D (default: %(default)s)') + + ### Ortho reg stuff ### + parser.add_argument( + '--G_ortho', type=float, default=0.0, # 1e-4 is default for BigGAN + help='Modified ortho reg coefficient in G(default: %(default)s)') + parser.add_argument( + '--D_ortho', type=float, default=0.0, + help='Modified ortho reg coefficient in D (default: %(default)s)') + parser.add_argument( + '--toggle_grads', action='store_true', default=True, + help='Toggle D and G''s "requires_grad" settings when not training them? ' + ' (default: %(default)s)') + + ### Which train function ### + parser.add_argument( + '--which_train_fn', type=str, default='GAN', + help='How2trainyourbois (default: %(default)s)') + + ### Resume training stuff + parser.add_argument( + '--load_weights', type=str, default='', + help='Suffix for which weights to load (e.g. best0, copy0) ' + '(default: %(default)s)') + parser.add_argument( + '--resume', action='store_true', default=False, + help='Resume training? (default: %(default)s)') + + ### Log stuff ### + parser.add_argument( + '--logstyle', type=str, default='%3.3e', + help='What style to use when logging training metrics?' + 'One of: %#.#f/ %#.#e (float/exp, text),' + 'pickle (python pickle),' + 'npz (numpy zip),' + 'mat (MATLAB .mat file) (default: %(default)s)') + parser.add_argument( + '--log_G_spectra', action='store_true', default=False, + help='Log the top 3 singular values in each SN layer in G? ' + '(default: %(default)s)') + parser.add_argument( + '--log_D_spectra', action='store_true', default=False, + help='Log the top 3 singular values in each SN layer in D? ' + '(default: %(default)s)') + parser.add_argument( + '--sv_log_interval', type=int, default=10, + help='Iteration interval for logging singular values ' + ' (default: %(default)s)') + + parser.add_argument('--text', type=str) + + return parser + +# Arguments for sample.py; not presently used in train.py +def add_sample_parser(parser): + parser.add_argument( + '--sample_npz', action='store_true', default=False, + help='Sample "sample_num_npz" images and save to npz? ' + '(default: %(default)s)') + parser.add_argument( + '--sample_num_npz', type=int, default=50000, + help='Number of images to sample when sampling NPZs ' + '(default: %(default)s)') + parser.add_argument( + '--sample_sheets', action='store_true', default=False, + help='Produce class-conditional sample sheets and stick them in ' + 'the samples root? 
(default: %(default)s)') + parser.add_argument( + '--sample_interps', action='store_true', default=False, + help='Produce interpolation sheets and stick them in ' + 'the samples root? (default: %(default)s)') + parser.add_argument( + '--sample_sheet_folder_num', type=int, default=-1, + help='Number to use for the folder for these sample sheets ' + '(default: %(default)s)') + parser.add_argument( + '--sample_random', action='store_true', default=False, + help='Produce a single random sheet? (default: %(default)s)') + parser.add_argument( + '--sample_trunc_curves', type=str, default='', + help='Get inception metrics with a range of variances?' + 'To use this, specify a startpoint, step, and endpoint, e.g. ' + '--sample_trunc_curves 0.2_0.1_1.0 for a startpoint of 0.2, ' + 'endpoint of 1.0, and stepsize of 1.0. Note that this is ' + 'not exactly identical to using tf.truncated_normal, but should ' + 'have approximately the same effect. (default: %(default)s)') + parser.add_argument( + '--sample_inception_metrics', action='store_true', default=False, + help='Calculate Inception metrics with sample.py? (default: %(default)s)') + return parser + +# Convenience dicts +dset_dict = {'I32': dset.ImageFolder, 'I64': dset.ImageFolder, + 'I128': dset.ImageFolder, 'I256': dset.ImageFolder, + 'I32_hdf5': dset.ILSVRC_HDF5, 'I64_hdf5': dset.ILSVRC_HDF5, + 'I128_hdf5': dset.ILSVRC_HDF5, 'I256_hdf5': dset.ILSVRC_HDF5, + 'C10': dset.CIFAR10, 'C100': dset.CIFAR100} +imsize_dict = {'I32': 32, 'I32_hdf5': 32, + 'I64': 64, 'I64_hdf5': 64, + 'I128': 128, 'I128_hdf5': 128, + 'I256': 256, 'I256_hdf5': 256, + 'C10': 32, 'C100': 32} +root_dict = {'I32': 'ImageNet', 'I32_hdf5': 'ILSVRC32.hdf5', + 'I64': 'ImageNet', 'I64_hdf5': 'ILSVRC64.hdf5', + 'I128': 'ImageNet', 'I128_hdf5': 'ILSVRC128.hdf5', + 'I256': 'ImageNet', 'I256_hdf5': 'ILSVRC256.hdf5', + 'C10': 'cifar', 'C100': 'cifar'} +nclass_dict = {'I32': 1000, 'I32_hdf5': 1000, + 'I64': 1000, 'I64_hdf5': 1000, + 'I128': 1000, 'I128_hdf5': 1000, + 'I256': 1000, 'I256_hdf5': 1000, + 'C10': 10, 'C100': 100} +# Number of classes to put per sample sheet +classes_per_sheet_dict = {'I32': 50, 'I32_hdf5': 50, + 'I64': 50, 'I64_hdf5': 50, + 'I128': 20, 'I128_hdf5': 20, + 'I256': 20, 'I256_hdf5': 20, + 'C10': 10, 'C100': 100} +activation_dict = {'inplace_relu': nn.ReLU(inplace=True), + 'relu': nn.ReLU(inplace=False), + 'ir': nn.ReLU(inplace=True),} + +class CenterCropLongEdge(object): + """Crops the given PIL Image on the long edge. + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped. + Returns: + PIL Image: Cropped image. + """ + return transforms.functional.center_crop(img, min(img.size)) + + def __repr__(self): + return self.__class__.__name__ + +class RandomCropLongEdge(object): + """Crops the given PIL Image on the long edge with a random start point. + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped. + Returns: + PIL Image: Cropped image. 
+ """ + size = (min(img.size), min(img.size)) + # Only step forward along this edge if it's the long edge + i = (0 if size[0] == img.size[0] + else np.random.randint(low=0,high=img.size[0] - size[0])) + j = (0 if size[1] == img.size[1] + else np.random.randint(low=0,high=img.size[1] - size[1])) + return transforms.functional.crop(img, i, j, size[0], size[1]) + + def __repr__(self): + return self.__class__.__name__ + + +# multi-epoch Dataset sampler to avoid memory leakage and enable resumption of +# training from the same sample regardless of if we stop mid-epoch +class MultiEpochSampler(torch.utils.data.Sampler): + r"""Samples elements randomly over multiple epochs + + Arguments: + data_source (Dataset): dataset to sample from + num_epochs (int) : Number of times to loop over the dataset + start_itr (int) : which iteration to begin from + """ + + def __init__(self, data_source, num_epochs, start_itr=0, batch_size=128): + self.data_source = data_source + self.num_samples = len(self.data_source) + self.num_epochs = num_epochs + self.start_itr = start_itr + self.batch_size = batch_size + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError("num_samples should be a positive integeral " + "value, but got num_samples={}".format(self.num_samples)) + + def __iter__(self): + n = len(self.data_source) + # Determine number of epochs + num_epochs = int(np.ceil((n * self.num_epochs + - (self.start_itr * self.batch_size)) / float(n))) + # Sample all the indices, and then grab the last num_epochs index sets; + # This ensures if we're starting at epoch 4, we're still grabbing epoch 4's + # indices + out = [torch.randperm(n) for epoch in range(self.num_epochs)][-num_epochs:] + # Ignore the first start_itr % n indices of the first epoch + out[0] = out[0][(self.start_itr * self.batch_size % n):] + # if self.replacement: + # return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()) + # return iter(.tolist()) + output = torch.cat(out).tolist() + print('Length dataset output is %d' % len(output)) + return iter(output) + + def __len__(self): + return len(self.data_source) * self.num_epochs - self.start_itr * self.batch_size + + +# Convenience function to centralize all data loaders +def get_data_loaders(dataset, data_root=None, augment=False, batch_size=64, + num_workers=8, shuffle=True, load_in_mem=False, hdf5=False, + pin_memory=True, drop_last=True, start_itr=0, + num_epochs=500, use_multiepoch_sampler=False, + **kwargs): + + # Append /FILENAME.hdf5 to root if using hdf5 + data_root += '/%s' % root_dict[dataset] + print('Using dataset root location %s' % data_root) + + which_dataset = dset_dict[dataset] + norm_mean = [0.5,0.5,0.5] + norm_std = [0.5,0.5,0.5] + image_size = imsize_dict[dataset] + # For image folder datasets, name of the file where we store the precomputed + # image locations to avoid having to walk the dirs every time we load. 
+ dataset_kwargs = {'index_filename': '%s_imgs.npz' % dataset} + + # HDF5 datasets have their own inbuilt transform, no need to train_transform + if 'hdf5' in dataset: + train_transform = None + else: + if augment: + print('Data will be augmented...') + if dataset in ['C10', 'C100']: + train_transform = [transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip()] + else: + train_transform = [RandomCropLongEdge(), + transforms.Resize(image_size), + transforms.RandomHorizontalFlip()] + else: + print('Data will not be augmented...') + if dataset in ['C10', 'C100']: + train_transform = [] + else: + train_transform = [CenterCropLongEdge(), transforms.Resize(image_size)] + # train_transform = [transforms.Resize(image_size), transforms.CenterCrop] + train_transform = transforms.Compose(train_transform + [ + transforms.ToTensor(), + transforms.Normalize(norm_mean, norm_std)]) + train_set = which_dataset(root=data_root, transform=train_transform, + load_in_mem=load_in_mem, **dataset_kwargs) + + # Prepare loader; the loaders list is for forward compatibility with + # using validation / test splits. + loaders = [] + if use_multiepoch_sampler: + print('Using multiepoch sampler from start_itr %d...' % start_itr) + loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory} + sampler = MultiEpochSampler(train_set, num_epochs, start_itr, batch_size) + train_loader = DataLoader(train_set, batch_size=batch_size, + sampler=sampler, **loader_kwargs) + else: + loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory, + 'drop_last': drop_last} # Default, drop last incomplete batch + train_loader = DataLoader(train_set, batch_size=batch_size, + shuffle=shuffle, **loader_kwargs) + loaders.append(train_loader) + return loaders + + +# Utility file to seed rngs +def seed_rng(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + + +# Utility to peg all roots to a base root +# If a base root folder is provided, peg all other root folders to it. +def update_config_roots(config): + if config['base_root']: + print('Pegging all root folders to base root %s' % config['base_root']) + for key in ['data', 'weights', 'logs', 'samples']: + config['%s_root' % key] = '%s/%s' % (config['base_root'], key) + return config + + +# Utility to prepare root folders if they don't exist; parent folder must exist +def prepare_root(config): + for key in ['weights_root', 'logs_root', 'samples_root']: + if not os.path.exists(config[key]): + print('Making directory %s for %s...' % (config[key], key)) + os.mkdir(config[key]) + + +# Simple wrapper that applies EMA to a model. COuld be better done in 1.0 using +# the parameters() and buffers() module functions, but for now this works +# with state_dicts using .copy_ +class ema(object): + def __init__(self, source, target, decay=0.9999, start_itr=0): + self.source = source + self.target = target + self.decay = decay + # Optional parameter indicating what iteration to start the decay at + self.start_itr = start_itr + # Initialize target's params to be source's + self.source_dict = self.source.state_dict() + self.target_dict = self.target.state_dict() + print('Initializing EMA parameters to be source parameters...') + with torch.no_grad(): + for key in self.source_dict: + self.target_dict[key].data.copy_(self.source_dict[key].data) + # target_dict[key].data = source_dict[key].data # Doesn't work! 
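+  # EMA update rule, implemented in update() below:
+  #   target <- decay * target + (1 - decay) * source
+  # If an iteration counter smaller than start_itr is passed to update(), decay is pegged
+  # to 0, so the EMA copy simply tracks the raw source weights until start_itr is reached.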
+ + def update(self, itr=None): + # If an iteration counter is provided and itr is less than the start itr, + # peg the ema weights to the underlying weights. + if itr and itr < self.start_itr: + decay = 0.0 + else: + decay = self.decay + with torch.no_grad(): + for key in self.source_dict: + self.target_dict[key].data.copy_(self.target_dict[key].data * decay + + self.source_dict[key].data * (1 - decay)) + + +# Apply modified ortho reg to a model +# This function is an optimized version that directly computes the gradient, +# instead of computing and then differentiating the loss. +def ortho(model, strength=1e-4, blacklist=[]): + with torch.no_grad(): + for param in model.parameters(): + # Only apply this to parameters with at least 2 axes, and not in the blacklist + if len(param.shape) < 2 or any([param is item for item in blacklist]): + continue + w = param.view(param.shape[0], -1) + grad = (2 * torch.mm(torch.mm(w, w.t()) + * (1. - torch.eye(w.shape[0], device=w.device)), w)) + param.grad.data += strength * grad.view(param.shape) + + +# Default ortho reg +# This function is an optimized version that directly computes the gradient, +# instead of computing and then differentiating the loss. +def default_ortho(model, strength=1e-4, blacklist=[]): + with torch.no_grad(): + for param in model.parameters(): + # Only apply this to parameters with at least 2 axes & not in blacklist + if len(param.shape) < 2 or param in blacklist: + continue + w = param.view(param.shape[0], -1) + grad = (2 * torch.mm(torch.mm(w, w.t()) + - torch.eye(w.shape[0], device=w.device), w)) + param.grad.data += strength * grad.view(param.shape) + + +# Convenience utility to switch off requires_grad +def toggle_grad(model, on_or_off): + for param in model.parameters(): + param.requires_grad = on_or_off + + +# Function to join strings or ignore them +# Base string is the string to link "strings," while strings +# is a list of strings or Nones. +def join_strings(base_string, strings): + return base_string.join([item for item in strings if item]) + + +# Save a model's weights, optimizer, and the state_dict +def save_weights(G, D, state_dict, weights_root, experiment_name, + name_suffix=None, G_ema=None): + root = '/'.join([weights_root, experiment_name]) + if not os.path.exists(root): + os.mkdir(root) + if name_suffix: + print('Saving weights to %s/%s...' % (root, name_suffix)) + else: + print('Saving weights to %s...' % root) + torch.save(G.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['G', name_suffix]))) + torch.save(G.optim.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix]))) + torch.save(D.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['D', name_suffix]))) + torch.save(D.optim.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix]))) + torch.save(state_dict, + '%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix]))) + if G_ema is not None: + torch.save(G_ema.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix]))) + + +# Load a model's weights, optimizer, and the state_dict +def load_weights(G, D, state_dict, weights_root, experiment_name, + name_suffix=None, G_ema=None, strict=True, load_optim=True): + root = '/'.join([weights_root, experiment_name]) + if name_suffix: + print('Loading %s weights from %s...' % (name_suffix, root)) + else: + print('Loading weights from %s...' 
% root) + if G is not None: + G.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['G', name_suffix]))), + strict=strict) + if load_optim: + G.optim.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix])))) + if D is not None: + D.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['D', name_suffix]))), + strict=strict) + if load_optim: + D.optim.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix])))) + # Load state dict + for item in state_dict: + state_dict[item] = torch.load('%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))[item] + if G_ema is not None: + G_ema.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix]))), + strict=strict) + + +''' MetricsLogger originally stolen from VoxNet source code. + Used for logging inception metrics''' +class MetricsLogger(object): + def __init__(self, fname, reinitialize=False): + self.fname = fname + self.reinitialize = reinitialize + if os.path.exists(self.fname): + if self.reinitialize: + print('{} exists, deleting...'.format(self.fname)) + os.remove(self.fname) + + def log(self, record=None, **kwargs): + """ + Assumption: no newlines in the input. + """ + if record is None: + record = {} + record.update(kwargs) + record['_stamp'] = time.time() + with open(self.fname, 'a') as f: + f.write(json.dumps(record, ensure_ascii=True) + '\n') + + +# Logstyle is either: +# '%#.#f' for floating point representation in text +# '%#.#e' for exponent representation in text +# 'npz' for output to npz # NOT YET SUPPORTED +# 'pickle' for output to a python pickle # NOT YET SUPPORTED +# 'mat' for output to a MATLAB .mat file # NOT YET SUPPORTED +class MyLogger(object): + def __init__(self, fname, reinitialize=False, logstyle='%3.3f'): + self.root = fname + if not os.path.exists(self.root): + os.mkdir(self.root) + self.reinitialize = reinitialize + self.metrics = [] + self.logstyle = logstyle # One of '%3.3f' or like '%3.3e' + + # Delete log if re-starting and log already exists + def reinit(self, item): + if os.path.exists('%s/%s.log' % (self.root, item)): + if self.reinitialize: + # Only print the removal mess + if 'sv' in item : + if not any('sv' in item for item in self.metrics): + print('Deleting singular value logs...') + else: + print('{} exists, deleting...'.format('%s_%s.log' % (self.root, item))) + os.remove('%s/%s.log' % (self.root, item)) + + # Log in plaintext; this is designed for being read in MATLAB(sorry not sorry) + def log(self, itr, **kwargs): + for arg in kwargs: + if arg not in self.metrics: + if self.reinitialize: + self.reinit(arg) + self.metrics += [arg] + if self.logstyle == 'pickle': + print('Pickle not currently supported...') + # with open('%s/%s.log' % (self.root, arg), 'a') as f: + # pickle.dump(kwargs[arg], f) + elif self.logstyle == 'mat': + print('.mat logstyle not currently supported...') + else: + with open('%s/%s.log' % (self.root, arg), 'a') as f: + f.write('%d: %s\n' % (itr, self.logstyle % kwargs[arg])) + + +# Write some metadata to the logs directory +def write_metadata(logs_root, experiment_name, config, state_dict): + with open(('%s/%s/metalog.txt' % + (logs_root, experiment_name)), 'w') as writefile: + writefile.write('datetime: %s\n' % str(datetime.datetime.now())) + writefile.write('config: %s\n' % str(config)) + writefile.write('state: %s\n' %str(state_dict)) + + +""" +Very basic progress indicator to wrap an iterable in. 
+ +Author: Jan Schlüter +Andy's adds: time elapsed in addition to ETA, makes it possible to add +estimated time to 1k iters instead of estimated time to completion. +""" +def progress(items, desc='', total=None, min_delay=0.1, displaytype='s1k'): + """ + Returns a generator over `items`, printing the number and percentage of + items processed and the estimated remaining processing time before yielding + the next item. `total` gives the total number of items (required if `items` + has no length), and `min_delay` gives the minimum time in seconds between + subsequent prints. `desc` gives an optional prefix text (end with a space). + """ + total = total or len(items) + t_start = time.time() + t_last = 0 + for n, item in enumerate(items): + t_now = time.time() + if t_now - t_last > min_delay: + print("\r%s%d/%d (%6.2f%%)" % ( + desc, n+1, total, n / float(total) * 100), end=" ") + if n > 0: + + if displaytype == 's1k': # minutes/seconds for 1000 iters + next_1000 = n + (1000 - n%1000) + t_done = t_now - t_start + t_1k = t_done / n * next_1000 + outlist = list(divmod(t_done, 60)) + list(divmod(t_1k - t_done, 60)) + print("(TE/ET1k: %d:%02d / %d:%02d)" % tuple(outlist), end=" ") + else:# displaytype == 'eta': + t_done = t_now - t_start + t_total = t_done / n * total + outlist = list(divmod(t_done, 60)) + list(divmod(t_total - t_done, 60)) + print("(TE/ETA: %d:%02d / %d:%02d)" % tuple(outlist), end=" ") + + sys.stdout.flush() + t_last = t_now + yield item + t_total = time.time() - t_start + print("\r%s%d/%d (100.00%%) (took %d:%02d)" % ((desc, total, total) + + divmod(t_total, 60))) + + +# Sample function for use with inception metrics +def sample(G, z_, y_, config): + with torch.no_grad(): + z_.sample_() + y_.sample_() + if config['parallel']: + G_z = nn.parallel.data_parallel(G, (z_, G.shared(y_))) + else: + G_z = G(z_, G.shared(y_)) + return G_z, y_ + + +# Sample function for sample sheets +def sample_sheet(G, classes_per_sheet, num_classes, samples_per_class, parallel, + samples_root, experiment_name, folder_number, z_=None): + # Prepare sample directory + if not os.path.isdir('%s/%s' % (samples_root, experiment_name)): + os.mkdir('%s/%s' % (samples_root, experiment_name)) + if not os.path.isdir('%s/%s/%d' % (samples_root, experiment_name, folder_number)): + os.mkdir('%s/%s/%d' % (samples_root, experiment_name, folder_number)) + # loop over total number of sheets + for i in range(num_classes // classes_per_sheet): + ims = [] + y = torch.arange(i * classes_per_sheet, (i + 1) * classes_per_sheet, device='cuda') + for j in range(samples_per_class): + if (z_ is not None) and hasattr(z_, 'sample_') and classes_per_sheet <= z_.size(0): + z_.sample_() + else: + z_ = torch.randn(classes_per_sheet, G.dim_z, device='cuda') + with torch.no_grad(): + if parallel: + o = nn.parallel.data_parallel(G, (z_[:classes_per_sheet], G.shared(y))) + else: + o = G(z_[:classes_per_sheet], G.shared(y)) + + ims += [o.data.cpu()] + # This line should properly unroll the images + out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2], + ims[0].shape[3]).data.float().cpu() + #out_ims = torch.from_numpy(out_ims.numpy()) ### NOTE: xcliu for torchvision + # The path for the samples + image_filename = '%s/%s/%d/samples%d.jpg' % (samples_root, experiment_name, + folder_number, i) + torchvision.utils.save_image(out_ims, image_filename, + nrow=samples_per_class, normalize=True) + + +# Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..) 
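+# Given x0 and x1 of shape (N, 1, D), the result has shape (N, num_midpoints + 2, D):
+# the two endpoints plus num_midpoints evenly spaced points between them along dim 1.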
+def interp(x0, x1, num_midpoints): + lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype) + return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1))) + + +# interp sheet function +# Supports full, class-wise and intra-class interpolation +def interp_sheet(G, num_per_sheet, num_midpoints, num_classes, parallel, + samples_root, experiment_name, folder_number, sheet_number=0, + fix_z=False, fix_y=False, device='cuda'): + # Prepare zs and ys + if fix_z: # If fix Z, only sample 1 z per row + zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device) + zs = zs.repeat(1, num_midpoints + 2, 1).view(-1, G.dim_z) + else: + zs = interp(torch.randn(num_per_sheet, 1, G.dim_z, device=device), + torch.randn(num_per_sheet, 1, G.dim_z, device=device), + num_midpoints).view(-1, G.dim_z) + if fix_y: # If fix y, only sample 1 z per row + ys = sample_1hot(num_per_sheet, num_classes) + ys = G.shared(ys).view(num_per_sheet, 1, -1) + ys = ys.repeat(1, num_midpoints + 2, 1).view(num_per_sheet * (num_midpoints + 2), -1) + else: + ys = interp(G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1), + G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1), + num_midpoints).view(num_per_sheet * (num_midpoints + 2), -1) + # Run the net--note that we've already passed y through G.shared. + if G.fp16: + zs = zs.half() + with torch.no_grad(): + if parallel: + out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu() + else: + out_ims = G(zs, ys).data.cpu() + interp_style = '' + ('Z' if not fix_z else '') + ('Y' if not fix_y else '') + image_filename = '%s/%s/%d/interp%s%d.jpg' % (samples_root, experiment_name, + folder_number, interp_style, + sheet_number) + torchvision.utils.save_image(out_ims, image_filename, + nrow=num_midpoints + 2, normalize=True) + + +# Convenience debugging function to print out gradnorms and shape from each layer +# May need to rewrite this so we can actually see which parameter is which +def print_grad_norms(net): + gradsums = [[float(torch.norm(param.grad).item()), + float(torch.norm(param).item()), param.shape] + for param in net.parameters()] + order = np.argsort([item[0] for item in gradsums]) + print(['%3.3e,%3.3e, %s' % (gradsums[item_index][0], + gradsums[item_index][1], + str(gradsums[item_index][2])) + for item_index in order]) + + +# Get singular values to log. This will use the state dict to find them +# and substitute underscores for dots. 
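+# e.g. with prefix 'G', a state-dict entry named like 'blocks.0.conv1.sv0' (name illustrative)
+# would be logged under the key 'G_blocks_0_conv1_sv0'.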
+def get_SVs(net, prefix): + d = net.state_dict() + return {('%s_%s' % (prefix, key)).replace('.', '_') : + float(d[key].item()) + for key in d if 'sv' in key} + + +# Name an experiment based on its config +def name_from_config(config): + name = '_'.join([ + item for item in [ + 'Big%s' % config['which_train_fn'], + config['dataset'], + config['model'] if config['model'] != 'BigGAN' else None, + 'seed%d' % config['seed'], + 'Gch%d' % config['G_ch'], + 'Dch%d' % config['D_ch'], + 'Gd%d' % config['G_depth'] if config['G_depth'] > 1 else None, + 'Dd%d' % config['D_depth'] if config['D_depth'] > 1 else None, + 'bs%d' % config['batch_size'], + 'Gfp16' if config['G_fp16'] else None, + 'Dfp16' if config['D_fp16'] else None, + 'nDs%d' % config['num_D_steps'] if config['num_D_steps'] > 1 else None, + 'nDa%d' % config['num_D_accumulations'] if config['num_D_accumulations'] > 1 else None, + 'nGa%d' % config['num_G_accumulations'] if config['num_G_accumulations'] > 1 else None, + 'Glr%2.1e' % config['G_lr'], + 'Dlr%2.1e' % config['D_lr'], + 'GB%3.3f' % config['G_B1'] if config['G_B1'] !=0.0 else None, + 'GBB%3.3f' % config['G_B2'] if config['G_B2'] !=0.999 else None, + 'DB%3.3f' % config['D_B1'] if config['D_B1'] !=0.0 else None, + 'DBB%3.3f' % config['D_B2'] if config['D_B2'] !=0.999 else None, + 'Gnl%s' % config['G_nl'], + 'Dnl%s' % config['D_nl'], + 'Ginit%s' % config['G_init'], + 'Dinit%s' % config['D_init'], + 'G%s' % config['G_param'] if config['G_param'] != 'SN' else None, + 'D%s' % config['D_param'] if config['D_param'] != 'SN' else None, + 'Gattn%s' % config['G_attn'] if config['G_attn'] != '0' else None, + 'Dattn%s' % config['D_attn'] if config['D_attn'] != '0' else None, + 'Gortho%2.1e' % config['G_ortho'] if config['G_ortho'] > 0.0 else None, + 'Dortho%2.1e' % config['D_ortho'] if config['D_ortho'] > 0.0 else None, + config['norm_style'] if config['norm_style'] != 'bn' else None, + 'cr' if config['cross_replica'] else None, + 'Gshared' if config['G_shared'] else None, + 'hier' if config['hier'] else None, + 'ema' if config['ema'] else None, + config['name_suffix'] if config['name_suffix'] else None, + ] + if item is not None]) + # dogball + if config['hashname']: + return hashname(name) + else: + return name + + +# A simple function to produce a unique experiment name from the animal hashes. +def hashname(name): + h = hash(name) + a = h % len(animal_hash.a) + h = h // len(animal_hash.a) + b = h % len(animal_hash.b) + h = h // len(animal_hash.c) + c = h % len(animal_hash.c) + return animal_hash.a[a] + animal_hash.b[b] + animal_hash.c[c] + + +# Get GPU memory, -i is the index +def query_gpu(indices): + os.system('nvidia-smi -i 0 --query-gpu=memory.free --format=csv') + + +# Convenience function to count the number of parameters in a module +def count_parameters(module): + print('Number of parameters: {}'.format( + sum([p.data.nelement() for p in module.parameters()]))) + + +# Convenience function to sample an index, not actually a 1-hot +def sample_1hot(batch_size, num_classes, device='cuda'): + return torch.randint(low=0, high=num_classes, size=(batch_size,), + device=device, dtype=torch.int64, requires_grad=False) + + +# A highly simplified convenience class for sampling from distributions +# One could also use PyTorch's inbuilt distributions package. 
+# Note that this class requires initialization to proceed as +# x = Distribution(torch.randn(size)) +# x.init_distribution(dist_type, **dist_kwargs) +# x = x.to(device,dtype) +# This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2 +class Distribution(torch.Tensor): + # Init the params of the distribution + def init_distribution(self, dist_type, **kwargs): + self.dist_type = dist_type + self.dist_kwargs = kwargs + if self.dist_type == 'normal': + self.mean, self.var = kwargs['mean'], kwargs['var'] + elif self.dist_type == 'categorical': + self.num_categories = kwargs['num_categories'] + + def sample_(self): + if self.dist_type == 'normal': + self.normal_(self.mean, self.var) + elif self.dist_type == 'categorical': + self.random_(0, self.num_categories) + # return self.variable + + # Silly hack: overwrite the to() method to wrap the new object + # in a distribution as well + def to(self, *args, **kwargs): + new_obj = Distribution(self) + new_obj.init_distribution(self.dist_type, **self.dist_kwargs) + new_obj.data = super().to(*args, **kwargs) + return new_obj + + +# Convenience function to prepare a z and y vector +def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda', + fp16=False,z_var=1.0): + z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False)) + z_.init_distribution('normal', mean=0, var=z_var) + z_ = z_.to(device,torch.float16 if fp16 else torch.float32) + + if fp16: + z_ = z_.half() + + y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False)) + y_.init_distribution('categorical',num_categories=nclasses) + y_ = y_.to(device, torch.int64) + return z_, y_ + + +def initiate_standing_stats(net): + for module in net.modules(): + if hasattr(module, 'accumulate_standing'): + module.reset_stats() + module.accumulate_standing = True + + +def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16): + initiate_standing_stats(net) + net.train() + for i in range(num_accumulations): + with torch.no_grad(): + z.normal_() + y.random_(0, nclasses) + x = net(z, net.shared(y)) # No need to parallelize here unless using syncbn + # Set to eval mode + net.eval() + + +# This version of Adam keeps an fp32 copy of the parameters and +# does all of the parameter updates in fp32, while still doing the +# forwards and backwards passes using fp16 (i.e. fp16 copies of the +# parameters and fp16 activations). +# +# Note that this calls .float().cuda() on the params. +import math +from torch.optim.optimizer import Optimizer +class Adam16(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + params = list(params) + super(Adam16, self).__init__(params, defaults) + + # Safety modification to make sure we floatify our state + def load_state_dict(self, state_dict): + super(Adam16, self).load_state_dict(state_dict) + for group in self.param_groups: + for p in group['params']: + self.state[p]['exp_avg'] = self.state[p]['exp_avg'].float() + self.state[p]['exp_avg_sq'] = self.state[p]['exp_avg_sq'].float() + self.state[p]['fp32_p'] = self.state[p]['fp32_p'].float() + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad.data.float() + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = grad.new().resize_as_(grad).zero_() + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() + # Fp32 copy of the weights + state['fp32_p'] = p.data.float() + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], state['fp32_p']) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + state['fp32_p'].addcdiv_(-step_size, exp_avg, denom) + p.data = state['fp32_p'].half() + + return loss diff --git a/text2image/BigGAN_utils/weights/README.md b/text2image/BigGAN_utils/weights/README.md new file mode 100644 index 0000000..4440e11 --- /dev/null +++ b/text2image/BigGAN_utils/weights/README.md @@ -0,0 +1,2 @@ +Download pre-trained weights from +https://drive.google.com/drive/folders/1nJ3HmgYgeA9NZr-oU-enqbYeO7zBaANs?usp=sharing diff --git a/text2image/DiffAugment_pytorch.py b/text2image/DiffAugment_pytorch.py new file mode 100644 index 0000000..fa90feb --- /dev/null +++ b/text2image/DiffAugment_pytorch.py @@ -0,0 +1,102 @@ +# Differentiable Augmentation for Data-Efficient GAN Training +# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +# https://arxiv.org/pdf/2006.10738 + +import torch +import torch.nn.functional as F +import numpy as np + + +def DiffAugment(x, policy='', channels_first=True): + if policy: + if not channels_first: + x = x.permute(0, 3, 1, 2) + for p in policy.split(','): + for f in AUGMENT_FNS[p]: + x = f(x) + if not channels_first: + x = x.permute(0, 2, 3, 1) + x = x.contiguous() + return x + + +def rand_brightness(x): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) + return x + + +def rand_saturation(x): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + return x + + +def rand_contrast(x): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + return x + + +def rand_translation(x, ratio=0.125): ### ratio: org: 0.125 + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 
1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() + return x + +def rand_resize(x, min_ratio=0.8, max_ratio=1.2): ### ratio: org: 0.125 + resize_ratio = np.random.rand()*(max_ratio-min_ratio) + min_ratio + resized_img = F.interpolate(x, size=int(resize_ratio*x.shape[3]), mode='bilinear') + org_size = x.shape[3] + #print('ORG:', x.shape) + #print('RESIZED:', resized_img.shape) + if int(resize_ratio*x.shape[3]) < x.shape[3]: + left_pad = (x.shape[3]-int(resize_ratio*x.shape[3]))/2. + left_pad = int(left_pad) + right_pad = x.shape[3] - left_pad - resized_img.shape[3] + #print('PAD:', left_pad, right_pad) + x = F.pad(resized_img, (left_pad, right_pad, left_pad, right_pad), "constant", 0.) + #print('SMALL:', x.shape) + else: + left = (int(resize_ratio*x.shape[3])-x.shape[3])/2. + left = int(left) + #print('LEFT:', left) + x = resized_img[:, :, left:(left+x.shape[3]), left:(left+x.shape[3])] + #print('LARGE:', x.shape) + assert x.shape[2] == org_size + assert x.shape[3] == org_size + + return x + + +def rand_cutout(x, ratio=0.5): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'resize': [rand_resize], + 'cutout': [rand_cutout], +} diff --git a/text2image/LICENSE b/text2image/LICENSE new file mode 100644 index 0000000..04205d8 --- /dev/null +++ b/text2image/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 gnobitab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
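For reference, a minimal sketch of how the `DiffAugment` entry point defined in `DiffAugment_pytorch.py` above is typically invoked; the batch shape here is illustrative, and the policy string simply names the four augmentation groups registered in `AUGMENT_FNS`:

```python
import torch
from DiffAugment_pytorch import DiffAugment

# A toy batch of 8 RGB images in [-1, 1], channels-first (N, C, H, W).
x = torch.rand(8, 3, 256, 256) * 2.0 - 1.0

# Each comma-separated key in the policy selects a list of differentiable
# augmentations from AUGMENT_FNS; they are applied in order and preserve shape.
x_aug = DiffAugment(x, policy='color,translation,resize,cutout')
assert x_aug.shape == x.shape
```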
diff --git a/text2image/README.md b/text2image/README.md new file mode 100644 index 0000000..c56644a --- /dev/null +++ b/text2image/README.md @@ -0,0 +1,65 @@ +# FuseDream + +This repo contains code for our paper ([paper link](https://arxiv.org/abs/2112.01573)): + +**FuseDream: Training-Free Text-to-Image Generation with Improved CLIP+GAN Space Optimization** + +by *Xingchao Liu, Chengyue Gong, Lemeng Wu, Shujian Zhang, Hao Su and Qiang Liu* from UCSD and UT Austin. + +![FuseDream](./imgs/header_img.png?raw=true "FuseDream") + +## Introduction +FuseDream uses pre-trained GANs (we support BigGAN-256 and BigGAN-512 for now) and CLIP to achieve high-fidelity text-to-image generation. + +## Requirements +Please use `pip` or `conda` to install the following packages: +`PyTorch==1.7.1, torchvision==0.8.2, lpips==0.1.4` and also the requirements from [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch). + +## Getting Started + +We transformed the pre-trained weights of BigGAN from TFHub to PyTorch. To save your time, you can download the transformed BigGAN checkpoints from: + +https://drive.google.com/drive/folders/1nJ3HmgYgeA9NZr-oU-enqbYeO7zBaANs?usp=sharing + +Put the checkpoints into `./BigGAN_utils/weights/` + +Run the following command to generate images from text query: + +`python fusedream_generator.py --text 'YOUR TEXT' --seed YOUR_SEED` + +For example, to get an image of a blue dog: + +`python fusedream_generator.py --text 'A photo of a blue dog.' --seed 1234` + +The generated image will be stored in `./samples` + +## Colab Notebook + +For a quick test of *FuseDream*, we provide Colab notebooks for [*FuseDream*(Single Image)](https://colab.research.google.com/drive/17qkzkoQQtzDRFaSCJQzIaNj88xjO9Rm9?usp=sharing) and *FuseDream-Composition*(TODO). Have fun! 
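+
+## Programmatic Usage
+
+The command-line entry point above can also be driven from Python. The snippet below is a rough
+sketch that mirrors `fusedream_generator.py`; the prompt, seed and output filename are
+illustrative, and it assumes the BigGAN checkpoints are already in `./BigGAN_utils/weights/`.
+
+```python
+import BigGAN_utils.utils as utils
+from fusedream_utils import FuseDreamBaseGenerator, get_G, save_image
+
+utils.seed_rng(1234)                                # seed torch and numpy RNGs
+G, config = get_G(512)                              # BigGAN-512; 256 is also supported
+generator = FuseDreamBaseGenerator(G, config, 10)
+
+sentence = 'A photo of a blue dog.'
+z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=1000, num_basis=5)
+img, z, y = generator.optimize_clip_score(z_cllt, y_cllt, sentence,
+                                          latent_noise=True, augment=True,
+                                          opt_iters=1000, optimize_y=True)
+score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
+print('AugCLIP score:', score)
+save_image(img, 'samples/fusedream_blue_dog.png')
+```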
+ +## Citations +If you use the code, please cite: + +```BibTex +@inproceedings{ +brock2018large, +title={Large Scale {GAN} Training for High Fidelity Natural Image Synthesis}, +author={Andrew Brock and Jeff Donahue and Karen Simonyan}, +booktitle={International Conference on Learning Representations}, +year={2019}, +url={https://openreview.net/forum?id=B1xsqj09Fm}, +} +``` + +and +```BibTex +@misc{ +liu2021fusedream, +title={FuseDream: Training-Free Text-to-Image Generation with Improved CLIP+GAN Space Optimization}, +author={Xingchao Liu and Chengyue Gong and Lemeng Wu and Shujian Zhang and Hao Su and Qiang Liu}, +year={2021}, +eprint={2112.01573}, +archivePrefix={arXiv}, +primaryClass={cs.CV} +} +``` diff --git a/text2image/__init__.py b/text2image/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text2image/fusedream_generator.py b/text2image/fusedream_generator.py new file mode 100644 index 0000000..8cf6091 --- /dev/null +++ b/text2image/fusedream_generator.py @@ -0,0 +1,35 @@ +import torch +from tqdm import tqdm +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import BigGAN_utils.utils as utils +import torch.nn.functional as F +from DiffAugment_pytorch import DiffAugment +import numpy as np +from fusedream_utils import FuseDreamBaseGenerator, get_G, save_image + +parser = utils.prepare_parser() +parser = utils.add_sample_parser(parser) +args = parser.parse_args() + +INIT_ITERS = 1000 +OPT_ITERS = 1000 + +utils.seed_rng(args.seed) + +sentence = args.text + +print('Generating:', sentence) +G, config = get_G(512) # Choose from 256 and 512 +generator = FuseDreamBaseGenerator(G, config, 10) +z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=INIT_ITERS, num_basis=5) + +z_cllt_save = torch.cat(z_cllt).cpu().numpy() +y_cllt_save = torch.cat(y_cllt).cpu().numpy() +img, z, y = generator.optimize_clip_score(z_cllt, y_cllt, sentence, latent_noise=True, augment=True, opt_iters=OPT_ITERS, optimize_y=True) +score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20) +print('AugCLIP score:', score) +import os +if not os.path.exists('./samples'): + os.mkdir('./samples') +save_image(img, 'samples/fusedream_%s_seed_%d_score_%.4f.png'%(sentence, args.seed, score)) + diff --git a/text2image/fusedream_utils.py b/text2image/fusedream_utils.py new file mode 100644 index 0000000..ade612b --- /dev/null +++ b/text2image/fusedream_utils.py @@ -0,0 +1,308 @@ +import torch +from tqdm import tqdm +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import torchvision +import BigGAN_utils.utils as utils +import clip +import torch.nn.functional as F +from DiffAugment_pytorch import DiffAugment +import numpy as np +import lpips +import os +current_path = os.path.dirname(__file__) + +LATENT_NOISE = 0.01 +Z_THRES = 2.0 +POLICY = 'color,translation,resize,cutout' +TEST_POLICY = 'color,translation,resize,cutout' +mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda() +std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda() + +def AugmentLoss(img, clip_model, text, replicate=10, interp_mode='bilinear', policy=POLICY): + + clip_c = clip_model.logit_scale.exp() + img_aug = DiffAugment(img.repeat(replicate, 1, 1, 1), policy=policy) + img_aug = (img_aug+1.)/2. 
+ img_aug = F.interpolate(img_aug, size=224, mode=interp_mode) + img_aug.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + logits_per_image, logits_per_text = clip_model(img_aug, text) + logits_per_image = logits_per_image / clip_c + concept_loss = (-1.) * logits_per_image + + return concept_loss.mean(dim=0, keepdim=False) + +def NaiveSemanticLoss(img, clip_model, text, interp_mode='bilinear'): + + clip_c = clip_model.logit_scale.exp() + img = (img+1.)/2. + img = F.interpolate(img, size=224, mode=interp_mode) + img.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + logits_per_image, logits_per_text = clip_model(img, text) + logits_per_image = logits_per_image / clip_c + concept_loss = (-1.) * logits_per_image + + return concept_loss.mean(dim=0, keepdim=False) + +def get_gaussian_mask(size=256): + x, y = np.meshgrid(np.linspace(-1,1, size), np.linspace(-1,1,size)) + dst = np.sqrt(x*x+y*y) + + # Intializing sigma and muu + sigma = 1 + muu = 0.000 + + # Calculating Gaussian array + gauss = np.exp(-( (dst-muu)**2 / ( 2.0 * sigma**2 ) ) ) + + return gauss + +def save_image(img, path, n_per_row=1): + with torch.no_grad(): + torchvision.utils.save_image( + torch.from_numpy(img.cpu().numpy()), ##hack, to turn Distribution back to tensor + path, + nrow=n_per_row, + normalize=True, + ) + +def get_G(resolution=256): + if resolution == 256: + parser = utils.prepare_parser() + parser = utils.add_sample_parser(parser) + config = vars(parser.parse_args()) + + # See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs256x8.sh. + config["resolution"] = utils.imsize_dict["I128_hdf5"] + config["n_classes"] = utils.nclass_dict["I128_hdf5"] + config["G_activation"] = utils.activation_dict["inplace_relu"] + config["D_activation"] = utils.activation_dict["inplace_relu"] + config["G_attn"] = "128" + config["D_attn"] = "128" + config["G_ch"] = 96 + config["D_ch"] = 96 + config["hier"] = True + config["dim_z"] = 140 + config["shared_dim"] = 128 + config["G_shared"] = True + config = utils.update_config_roots(config) + config["skip_init"] = True + config["no_optim"] = True + config["device"] = "cuda" + config["resolution"] = 256 + + # Set up cudnn.benchmark for free speed. + torch.backends.cudnn.benchmark = True + + # Import the model. + model = __import__(config["model"]) + G = model.Generator(**config).to(config["device"]) + utils.count_parameters(G) + + # Load weights. + weights_path = f"{current_path}/BigGAN_utils/weights/biggan-256.pth" # Change this. + G.load_state_dict(torch.load(weights_path), strict=False) + elif resolution == 512: + parser = utils.prepare_parser() + parser = utils.add_sample_parser(parser) + config = vars(parser.parse_args()) + + # See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs128x8.sh. + config["resolution"] = 512 + config["n_classes"] = utils.nclass_dict["I128_hdf5"] + config["G_activation"] = utils.activation_dict["inplace_relu"] + config["D_activation"] = utils.activation_dict["inplace_relu"] + config["G_attn"] = "64" + config["D_attn"] = "64" + config["G_ch"] = 96 + config["D_ch"] = 64 + config["hier"] = True + config["dim_z"] = 128 + config["shared_dim"] = 128 + config["G_shared"] = True + config = utils.update_config_roots(config) + config["skip_init"] = True + config["no_optim"] = True + config["device"] = "cuda" + + # Set up cudnn.benchmark for free speed. + torch.backends.cudnn.benchmark = True + + # Import the model. 
+ model = __import__(config["model"]) + #print(config["model"]) + G = model.Generator(**config).to(config["device"]) + utils.count_parameters(G) + #print('G parameters:') + #for p, m in G.named_parameters(): + # print(p) + # Load weights. + weights_path = f"{current_path}/BigGAN_utils/weights/biggan-512.pth" # Change this. + G.load_state_dict(torch.load(weights_path), strict=False) + + return G, config + +class FuseDreamBaseGenerator(): + def __init__(self, G, G_config, G_batch_size=10, clip_mode="ViT-B/32", interp_mode='bilinear'): + + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + self.G = G + self.clip_model, _ = clip.load(clip_mode, device=device) + + (self.z_, self.y_) = utils.prepare_z_y( + G_batch_size, + self.G.dim_z, + G_config["n_classes"], + device=G_config["device"], + fp16=G_config["G_fp16"], + z_var=G_config["z_var"], + ) + + self.G.eval() + + for p in self.G.parameters(): + p.requires_grad = False + for p in self.clip_model.parameters(): + p.requires_grad = False + + self.interp_mode = interp_mode + + def generate_basis(self, text, init_iters=500, num_basis=5): + text_tok = clip.tokenize([text]).to(self.device) + clip_c = self.clip_model.logit_scale.exp() + + z_init_cllt = [] + y_init_cllt = [] + z_init = None + y_init = None + score_init = None + with torch.no_grad(): + for i in tqdm(range(init_iters)): + self.z_.sample_() + self.y_.sample_() + + self.z_.data = torch.clamp(self.z_.data.detach().clone(), min=-Z_THRES, max=Z_THRES) + + image_tensors = self.G(self.z_, self.G.shared(self.y_)) + image_tensors = (image_tensors+1.) / 2. + image_tensors = F.interpolate(image_tensors, size=224, mode=self.interp_mode) + image_tensors.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + logits_per_image, logits_per_text = self.clip_model(image_tensors, text_tok) + logits_per_image = logits_per_image/clip_c + if z_init is None: + z_init = self.z_.data.detach().clone() + y_init = self.y_.data.detach().clone() + score_init = logits_per_image.squeeze() + else: + z_init = torch.cat([z_init, self.z_.data.detach().clone()], dim=0) + y_init = torch.cat([y_init, self.y_.data.detach().clone()], dim=0) + score_init = torch.cat([score_init, logits_per_image.squeeze()]) + + sorted, indices = torch.sort(score_init, descending=True) + z_init = z_init[indices] + y_init = y_init[indices] + score_init = score_init[indices] + z_init = z_init[:num_basis] + y_init = y_init[:num_basis] + score_init = score_init[:num_basis] + + #save_image(self.G(z_init, self.G.shared(y_init)), 'samples/init_%s.png'%text, 1) + + z_init_cllt.append(z_init.detach().clone()) + y_init_cllt.append(self.G.shared(y_init.detach().clone())) + + return z_init_cllt, y_init_cllt + + + def optimize_clip_score(self, z_init_cllt, y_init_cllt, text, latent_noise=False, augment=True, opt_iters=500, optimize_y=False): + + text_tok = clip.tokenize([text]).to(self.device) + clip_c = self.clip_model.logit_scale.exp() + + z_init_ans = torch.stack(z_init_cllt) + y_init_ans = torch.stack(y_init_cllt) + z_init_ans = z_init_ans.view(-1, z_init_ans.shape[-1]) + y_init_ans = y_init_ans.view(-1, y_init_ans.shape[-1]) + + w_z = torch.randn((z_init_ans.shape[0], z_init_ans.shape[1])).to(self.device) + w_y = torch.randn((y_init_ans.shape[0], y_init_ans.shape[1])).to(self.device) + w_z.requires_grad = True + w_y.requires_grad = True + + opt_y = torch.zeros(y_init_ans.shape).to(self.device) + opt_y.data = y_init_ans.data.detach().clone() + opt_z = torch.zeros(z_init_ans.shape).to(self.device) + opt_z.data = 
z_init_ans.data.detach().clone() + opt_z.requires_grad = True + + if not optimize_y: + optimizer = torch.optim.Adam([w_z, w_y, opt_z], lr=5e-3, weight_decay=0.0) + else: + opt_y.requires_grad = True + optimizer = torch.optim.Adam([w_z, w_y,opt_y,opt_z], lr=5e-3, weight_decay=0.0) + + for i in tqdm(range(opt_iters)): + #print(w_z.shape, w_y.shape) + optimizer.zero_grad() + + if not latent_noise: + s_z = torch.softmax(w_z, dim=0) + s_y = torch.softmax(w_y, dim=0) + #print(s_z) + + cur_z = s_z * opt_z + cur_y = s_y * opt_y + cur_z = cur_z.sum(dim=0, keepdim=True) + cur_y = cur_y.sum(dim=0, keepdim=True) + + image_tensors = self.G(cur_z, cur_y) + else: + s_z = torch.softmax(w_z, dim=0) + s_y = torch.softmax(w_y, dim=0) + + cur_z = s_z * opt_z + cur_y = s_y * opt_y + cur_z = cur_z.sum(dim=0, keepdim=True) + cur_y = cur_y.sum(dim=0, keepdim=True) + cur_z_aug = cur_z + torch.randn(cur_z.shape).to(cur_z.device) * LATENT_NOISE + cur_y_aug = cur_y + torch.randn(cur_y.shape).to(cur_y.device) * LATENT_NOISE + + image_tensors = self.G(cur_z_aug, cur_y_aug) + + loss = 0.0 + for j in range(image_tensors.shape[0]): + if augment: + loss = loss + AugmentLoss(image_tensors[j:(j+1)], self.clip_model, text_tok, replicate=50, interp_mode=self.interp_mode) + else: + loss = loss + NaiveSemanticLoss(image_tensors[j:(j+1)], self.clip_model, text_tok) + + loss.backward() + optimizer.step() + + opt_z.data = torch.clamp(opt_z.data.detach().clone(), min=-Z_THRES, max=Z_THRES) + + z_init_ans = cur_z.detach().clone() + y_init_ans = cur_y.detach().clone() + + # save_image(self.G(z_init_ans, y_init_ans), '/home/zhaojh/workspace/computer_vision/opt_%s.png'%text, 1) + return self.G(z_init_ans, y_init_ans), z_init_ans, y_init_ans + + def measureAugCLIP(self, z, y, text, augment=False, num_samples=20): + text_tok = clip.tokenize([text]).to(self.device) + avg_loss = 0.0 + for itr in range(num_samples): + image_tensors = self.G(z, y) + + for j in range(image_tensors.shape[0]): + if augment: + loss = AugmentLoss(image_tensors[j:(j+1)], self.clip_model, text_tok, replicate=50, interp_mode=self.interp_mode, policy=TEST_POLICY) + else: + loss = NaiveSemanticLoss(image_tensors[j:(j+1)], self.clip_model, text_tok) + avg_loss += loss.item() + + avg_loss /= num_samples + return avg_loss * (-1.) + diff --git a/text2image/run_text2img.py b/text2image/run_text2img.py new file mode 100644 index 0000000..3a160b6 --- /dev/null +++ b/text2image/run_text2img.py @@ -0,0 +1,30 @@ +import numpy as np +from logzero import logger +import torch +from torchvision.utils import make_grid +from BigGAN_utils import utils +from fusedream_utils import FuseDreamBaseGenerator, get_G +from PIL import Image + +INIT_ITERS = 1000 +OPT_ITERS = 1000 +NUM_BASIS = 5 +MODEL = "biggan-512" +SEED = 1884 + +def text2image(sentence:str): + utils.seed_rng(SEED) + logger.info(f'Generating: {sentence}') + G, config = get_G(512) + generator = FuseDreamBaseGenerator(G, config, 10) + z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=INIT_ITERS, num_basis=NUM_BASIS) + + z_cllt_save = torch.cat(z_cllt).cpu().numpy() + y_cllt_save = torch.cat(y_cllt).cpu().numpy() + img, z, y = generator.optimize_clip_score(z_cllt, y_cllt, sentence, latent_noise=False, augment=True, opt_iters=OPT_ITERS, optimize_y=True) + ## Set latent_noise = True yields slightly higher AugCLIP score, but slightly lower image quality. We set it to False for dogs. 
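+    # AugCLIP score: the negated CLIP loss averaged over num_samples randomly augmented evaluations of the final sample; higher means a closer text-image match.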
+    score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
+    grid = make_grid(img, nrow=1, normalize=True)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    im = Image.fromarray(ndarr)
+    return im
\ No newline at end of file
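For reference, a minimal driver for the new `run_text2img.py` module might look like the sketch below. This is an illustrative sketch, not part of the patch: it assumes a CUDA-capable machine, the pre-trained BigGAN checkpoints placed under `text2image/BigGAN_utils/weights/`, and that the script is launched from the `text2image/` directory with no extra command-line flags (since `get_G` calls `parse_args()` internally). The prompt and output filename are made up for the example.

```python
# demo_text2img.py -- hypothetical driver, placed inside text2image/ (illustrative only)
from run_text2img import text2image

if __name__ == "__main__":
    prompt = "an oil painting of a lighthouse in a storm"  # example prompt, not from the repo
    im = text2image(prompt)        # PIL.Image built from the CLIP-optimized BigGAN sample
    im.save("fusedream_demo.png")  # illustrative output path
```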