commit c8b8388df52607802e6e3a9e3c5aa07ac9773fbc Author: zhaojinghao Date: Wed Aug 3 10:16:48 2022 +0800 initial commit diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..be9f28c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.7.13-slim + +WORKDIR /app +ADD . /app/ + +RUN pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir +RUN pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 --no-cache-dir +RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir +RUN mim install mmcv-full +RUN pip install -e ./text2image/CLIP/ +RUN pip install -e ./backbone/mmpose/ +RUN pip install -e ./ocr/tr/ +RUN rm -rf ./text2image/CLIP/ +RUN rm -rf ./ocr/tr/ +RUN rm -rf ./backbone/mmpose/ + +CMD ["python3", "run.py"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..f1c03ad --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +## 机器视觉模块 +### backbone: 关键点检测 +### detection: 目标检测 +### mnist: mnist手写体识别 +### ocr: 光学字符识别 +### segmentation: (光伏)图像分割 +### text2image: (文本)图像生成 \ No newline at end of file diff --git a/backbone/associative_embedding_hrnet_w32_coco_512x512.py b/backbone/associative_embedding_hrnet_w32_coco_512x512.py new file mode 100644 index 0000000..4774d01 --- /dev/null +++ b/backbone/associative_embedding_hrnet_w32_coco_512x512.py @@ -0,0 +1,1134 @@ +checkpoint_config = dict(interval=50) +log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) +log_level = 'INFO' +load_from = None +resume_from = None +dist_params = dict(backend='nccl') +workflow = [('train', 1)] +opencv_num_threads = 0 +mp_start_method = 'fork' +dataset_info = dict( + dataset_name='coco', + paper_info=dict( + author= + 'Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/'), + keypoint_info=dict({ + 0: + dict(name='nose', id=0, color=[51, 153, 255], type='upper', swap=''), + 1: + dict( + name='left_eye', + id=1, + color=[51, 153, 255], + type='upper', + swap='right_eye'), + 2: + dict( + name='right_eye', + id=2, + color=[51, 153, 255], + type='upper', + swap='left_eye'), + 3: + dict( + name='left_ear', + id=3, + color=[51, 153, 255], + type='upper', + swap='right_ear'), + 4: + dict( + name='right_ear', + id=4, + color=[51, 153, 255], + type='upper', + swap='left_ear'), + 5: + dict( + name='left_shoulder', + id=5, + color=[0, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + name='right_shoulder', + id=6, + color=[255, 128, 0], + type='upper', + swap='left_shoulder'), + 7: + dict( + name='left_elbow', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_elbow'), + 8: + dict( + name='right_elbow', + id=8, + color=[255, 128, 0], + type='upper', + swap='left_elbow'), + 9: + dict( + name='left_wrist', + id=9, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 10: + dict( + name='right_wrist', + id=10, + color=[255, 128, 0], + type='upper', + swap='left_wrist'), + 11: + dict( + name='left_hip', + id=11, + color=[0, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='right_hip', + id=12, + color=[255, 128, 0], + type='lower', + swap='left_hip'), + 13: + dict( + name='left_knee', + id=13, + color=[0, 255, 0], + type='lower', + swap='right_knee'), + 
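        # Standard COCO 17-keypoint person layout; each joint's 'swap' field names
        # its left/right mirror counterpart, which the flip-based augmentation and
        # flip_test settings later in this config rely on.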
14: + dict( + name='right_knee', + id=14, + color=[255, 128, 0], + type='lower', + swap='left_knee'), + 15: + dict( + name='left_ankle', + id=15, + color=[0, 255, 0], + type='lower', + swap='right_ankle'), + 16: + dict( + name='right_ankle', + id=16, + color=[255, 128, 0], + type='lower', + swap='left_ankle') + }), + skeleton_info=dict({ + 0: + dict(link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]), + 1: + dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]), + 2: + dict(link=('right_ankle', 'right_knee'), id=2, color=[255, 128, 0]), + 3: + dict(link=('right_knee', 'right_hip'), id=3, color=[255, 128, 0]), + 4: + dict(link=('left_hip', 'right_hip'), id=4, color=[51, 153, 255]), + 5: + dict(link=('left_shoulder', 'left_hip'), id=5, color=[51, 153, 255]), + 6: + dict(link=('right_shoulder', 'right_hip'), id=6, color=[51, 153, 255]), + 7: + dict( + link=('left_shoulder', 'right_shoulder'), + id=7, + color=[51, 153, 255]), + 8: + dict(link=('left_shoulder', 'left_elbow'), id=8, color=[0, 255, 0]), + 9: + dict( + link=('right_shoulder', 'right_elbow'), id=9, color=[255, 128, 0]), + 10: + dict(link=('left_elbow', 'left_wrist'), id=10, color=[0, 255, 0]), + 11: + dict(link=('right_elbow', 'right_wrist'), id=11, color=[255, 128, 0]), + 12: + dict(link=('left_eye', 'right_eye'), id=12, color=[51, 153, 255]), + 13: + dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]), + 14: + dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]), + 15: + dict(link=('left_eye', 'left_ear'), id=15, color=[51, 153, 255]), + 16: + dict(link=('right_eye', 'right_ear'), id=16, color=[51, 153, 255]), + 17: + dict(link=('left_ear', 'left_shoulder'), id=17, color=[51, 153, 255]), + 18: + dict( + link=('right_ear', 'right_shoulder'), id=18, color=[51, 153, 255]) + }), + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, 1.0, 1.2, + 1.2, 1.5, 1.5 + ], + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, + 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089 + ]) +evaluation = dict(interval=50, metric='mAP', save_best='AP') +optimizer = dict(type='Adam', lr=0.0015) +optimizer_config = dict(grad_clip=None) +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[200, 260]) +total_epochs = 300 +channel_cfg = dict( + dataset_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]) +data_cfg = dict( + image_size=512, + base_size=256, + base_sigma=2, + heatmap_size=[128], + num_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ], + num_scales=1, + scale_aware_sigma=False) +model = dict( + type='AssociativeEmbedding', + pretrained= + 'https://download.openmmlab.com/mmpose/pretrain_models/hrnet_w32-36af842e.pth', + backbone=dict( + type='HRNet', + in_channels=3, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(32, 64)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(32, 64, 128)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(32, 64, 128, 256)))), 
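    # "w32" in the config name is the width of HRNet's highest-resolution branch:
    # stage2 num_channels starts at 32, and that 32-channel branch is what the
    # keypoint head below consumes (in_channels=32).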
+ keypoint_head=dict( + type='AESimpleHead', + in_channels=32, + num_joints=17, + num_deconv_layers=0, + tag_per_joint=True, + with_ae_loss=[True], + extra=dict(final_conv_kernel=1), + loss_keypoint=dict( + type='MultiLossFactory', + num_joints=17, + num_stages=1, + ae_loss_type='exp', + with_ae_loss=[True], + push_loss_factor=[0.001], + pull_loss_factor=[0.001], + with_heatmaps_loss=[True], + heatmaps_loss_factor=[1.0])), + train_cfg=dict(), + test_cfg=dict( + num_joints=17, + max_num_people=30, + scale_factor=[1], + with_heatmaps=[True], + with_ae=[True], + project2image=True, + align_corners=False, + nms_kernel=5, + nms_padding=2, + tag_per_joint=True, + detection_threshold=0.1, + tag_threshold=1, + use_detection_val=True, + ignore_too_much=False, + adjust=True, + refine=True, + flip_test=True)) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='BottomUpRandomAffine', + rot_factor=30, + scale_factor=[0.75, 1.5], + scale_type='short', + trans_factor=40), + dict(type='BottomUpRandomFlip', flip_prob=0.5), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict(type='BottomUpGenerateTarget', sigma=2, max_num_people=30), + dict( + type='Collect', + keys=['img', 'joints', 'targets', 'masks'], + meta_keys=[]) +] +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='BottomUpGetImgSize', test_scale_factor=[1]), + dict( + type='BottomUpResizeAlign', + transforms=[ + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'aug_data', 'test_scale_factor', 'base_size', + 'center', 'scale', 'flip_index' + ]) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='BottomUpGetImgSize', test_scale_factor=[1]), + dict( + type='BottomUpResizeAlign', + transforms=[ + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'aug_data', 'test_scale_factor', 'base_size', + 'center', 'scale', 'flip_index' + ]) +] +data_root = 'data/coco' +data = dict( + workers_per_gpu=2, + train_dataloader=dict(samples_per_gpu=24), + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + train=dict( + type='BottomUpCocoDataset', + ann_file='data/coco/annotations/person_keypoints_train2017.json', + img_prefix='data/coco/train2017/', + data_cfg=dict( + image_size=512, + base_size=256, + base_sigma=2, + heatmap_size=[128], + num_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ], + num_scales=1, + scale_aware_sigma=False), + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='BottomUpRandomAffine', + rot_factor=30, + scale_factor=[0.75, 1.5], + scale_type='short', + trans_factor=40), + dict(type='BottomUpRandomFlip', flip_prob=0.5), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict(type='BottomUpGenerateTarget', sigma=2, max_num_people=30), + dict( + type='Collect', + keys=['img', 'joints', 'targets', 'masks'], + meta_keys=[]) + ], + dataset_info=dict( + dataset_name='coco', + paper_info=dict( + author= + 'Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, 
Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/'), + keypoint_info=dict({ + 0: + dict( + name='nose', + id=0, + color=[51, 153, 255], + type='upper', + swap=''), + 1: + dict( + name='left_eye', + id=1, + color=[51, 153, 255], + type='upper', + swap='right_eye'), + 2: + dict( + name='right_eye', + id=2, + color=[51, 153, 255], + type='upper', + swap='left_eye'), + 3: + dict( + name='left_ear', + id=3, + color=[51, 153, 255], + type='upper', + swap='right_ear'), + 4: + dict( + name='right_ear', + id=4, + color=[51, 153, 255], + type='upper', + swap='left_ear'), + 5: + dict( + name='left_shoulder', + id=5, + color=[0, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + name='right_shoulder', + id=6, + color=[255, 128, 0], + type='upper', + swap='left_shoulder'), + 7: + dict( + name='left_elbow', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_elbow'), + 8: + dict( + name='right_elbow', + id=8, + color=[255, 128, 0], + type='upper', + swap='left_elbow'), + 9: + dict( + name='left_wrist', + id=9, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 10: + dict( + name='right_wrist', + id=10, + color=[255, 128, 0], + type='upper', + swap='left_wrist'), + 11: + dict( + name='left_hip', + id=11, + color=[0, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='right_hip', + id=12, + color=[255, 128, 0], + type='lower', + swap='left_hip'), + 13: + dict( + name='left_knee', + id=13, + color=[0, 255, 0], + type='lower', + swap='right_knee'), + 14: + dict( + name='right_knee', + id=14, + color=[255, 128, 0], + type='lower', + swap='left_knee'), + 15: + dict( + name='left_ankle', + id=15, + color=[0, 255, 0], + type='lower', + swap='right_ankle'), + 16: + dict( + name='right_ankle', + id=16, + color=[255, 128, 0], + type='lower', + swap='left_ankle') + }), + skeleton_info=dict({ + 0: + dict( + link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]), + 1: + dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]), + 2: + dict( + link=('right_ankle', 'right_knee'), + id=2, + color=[255, 128, 0]), + 3: + dict( + link=('right_knee', 'right_hip'), + id=3, + color=[255, 128, 0]), + 4: + dict( + link=('left_hip', 'right_hip'), id=4, color=[51, 153, + 255]), + 5: + dict( + link=('left_shoulder', 'left_hip'), + id=5, + color=[51, 153, 255]), + 6: + dict( + link=('right_shoulder', 'right_hip'), + id=6, + color=[51, 153, 255]), + 7: + dict( + link=('left_shoulder', 'right_shoulder'), + id=7, + color=[51, 153, 255]), + 8: + dict( + link=('left_shoulder', 'left_elbow'), + id=8, + color=[0, 255, 0]), + 9: + dict( + link=('right_shoulder', 'right_elbow'), + id=9, + color=[255, 128, 0]), + 10: + dict( + link=('left_elbow', 'left_wrist'), + id=10, + color=[0, 255, 0]), + 11: + dict( + link=('right_elbow', 'right_wrist'), + id=11, + color=[255, 128, 0]), + 12: + dict( + link=('left_eye', 'right_eye'), + id=12, + color=[51, 153, 255]), + 13: + dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]), + 14: + dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]), + 15: + dict( + link=('left_eye', 'left_ear'), id=15, color=[51, 153, + 255]), + 16: + dict( + link=('right_eye', 'right_ear'), + id=16, + color=[51, 153, 255]), + 17: + dict( + link=('left_ear', 'left_shoulder'), + id=17, + color=[51, 153, 255]), + 18: + dict( + link=('right_ear', 'right_shoulder'), + id=18, + color=[51, 153, 255]) + 
}), + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, + 1.0, 1.2, 1.2, 1.5, 1.5 + ], + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, + 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089 + ])), + val=dict( + type='BottomUpCocoDataset', + ann_file='data/coco/annotations/person_keypoints_val2017.json', + img_prefix='data/coco/val2017/', + data_cfg=dict( + image_size=512, + base_size=256, + base_sigma=2, + heatmap_size=[128], + num_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ], + num_scales=1, + scale_aware_sigma=False), + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='BottomUpGetImgSize', test_scale_factor=[1]), + dict( + type='BottomUpResizeAlign', + transforms=[ + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'aug_data', 'test_scale_factor', 'base_size', + 'center', 'scale', 'flip_index' + ]) + ], + dataset_info=dict( + dataset_name='coco', + paper_info=dict( + author= + 'Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/'), + keypoint_info=dict({ + 0: + dict( + name='nose', + id=0, + color=[51, 153, 255], + type='upper', + swap=''), + 1: + dict( + name='left_eye', + id=1, + color=[51, 153, 255], + type='upper', + swap='right_eye'), + 2: + dict( + name='right_eye', + id=2, + color=[51, 153, 255], + type='upper', + swap='left_eye'), + 3: + dict( + name='left_ear', + id=3, + color=[51, 153, 255], + type='upper', + swap='right_ear'), + 4: + dict( + name='right_ear', + id=4, + color=[51, 153, 255], + type='upper', + swap='left_ear'), + 5: + dict( + name='left_shoulder', + id=5, + color=[0, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + name='right_shoulder', + id=6, + color=[255, 128, 0], + type='upper', + swap='left_shoulder'), + 7: + dict( + name='left_elbow', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_elbow'), + 8: + dict( + name='right_elbow', + id=8, + color=[255, 128, 0], + type='upper', + swap='left_elbow'), + 9: + dict( + name='left_wrist', + id=9, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 10: + dict( + name='right_wrist', + id=10, + color=[255, 128, 0], + type='upper', + swap='left_wrist'), + 11: + dict( + name='left_hip', + id=11, + color=[0, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='right_hip', + id=12, + color=[255, 128, 0], + type='lower', + swap='left_hip'), + 13: + dict( + name='left_knee', + id=13, + color=[0, 255, 0], + type='lower', + swap='right_knee'), + 14: + dict( + name='right_knee', + id=14, + color=[255, 128, 0], + type='lower', + swap='left_knee'), + 15: + dict( + name='left_ankle', + id=15, + color=[0, 255, 0], + type='lower', + swap='right_ankle'), + 16: + dict( + name='right_ankle', + id=16, + color=[255, 128, 0], + type='lower', + swap='left_ankle') + }), + skeleton_info=dict({ + 0: + dict( + link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]), + 1: + dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]), + 2: + dict( + link=('right_ankle', 'right_knee'), + id=2, + 
color=[255, 128, 0]), + 3: + dict( + link=('right_knee', 'right_hip'), + id=3, + color=[255, 128, 0]), + 4: + dict( + link=('left_hip', 'right_hip'), id=4, color=[51, 153, + 255]), + 5: + dict( + link=('left_shoulder', 'left_hip'), + id=5, + color=[51, 153, 255]), + 6: + dict( + link=('right_shoulder', 'right_hip'), + id=6, + color=[51, 153, 255]), + 7: + dict( + link=('left_shoulder', 'right_shoulder'), + id=7, + color=[51, 153, 255]), + 8: + dict( + link=('left_shoulder', 'left_elbow'), + id=8, + color=[0, 255, 0]), + 9: + dict( + link=('right_shoulder', 'right_elbow'), + id=9, + color=[255, 128, 0]), + 10: + dict( + link=('left_elbow', 'left_wrist'), + id=10, + color=[0, 255, 0]), + 11: + dict( + link=('right_elbow', 'right_wrist'), + id=11, + color=[255, 128, 0]), + 12: + dict( + link=('left_eye', 'right_eye'), + id=12, + color=[51, 153, 255]), + 13: + dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]), + 14: + dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]), + 15: + dict( + link=('left_eye', 'left_ear'), id=15, color=[51, 153, + 255]), + 16: + dict( + link=('right_eye', 'right_ear'), + id=16, + color=[51, 153, 255]), + 17: + dict( + link=('left_ear', 'left_shoulder'), + id=17, + color=[51, 153, 255]), + 18: + dict( + link=('right_ear', 'right_shoulder'), + id=18, + color=[51, 153, 255]) + }), + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, + 1.0, 1.2, 1.2, 1.5, 1.5 + ], + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, + 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089 + ])), + test=dict( + type='BottomUpCocoDataset', + ann_file='data/coco/annotations/person_keypoints_val2017.json', + img_prefix='data/coco/val2017/', + data_cfg=dict( + image_size=512, + base_size=256, + base_sigma=2, + heatmap_size=[128], + num_joints=17, + dataset_channel=[[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ]], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + ], + num_scales=1, + scale_aware_sigma=False), + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='BottomUpGetImgSize', test_scale_factor=[1]), + dict( + type='BottomUpResizeAlign', + transforms=[ + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'aug_data', 'test_scale_factor', 'base_size', + 'center', 'scale', 'flip_index' + ]) + ], + dataset_info=dict( + dataset_name='coco', + paper_info=dict( + author= + 'Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/'), + keypoint_info=dict({ + 0: + dict( + name='nose', + id=0, + color=[51, 153, 255], + type='upper', + swap=''), + 1: + dict( + name='left_eye', + id=1, + color=[51, 153, 255], + type='upper', + swap='right_eye'), + 2: + dict( + name='right_eye', + id=2, + color=[51, 153, 255], + type='upper', + swap='left_eye'), + 3: + dict( + name='left_ear', + id=3, + color=[51, 153, 255], + type='upper', + swap='right_ear'), + 4: + dict( + name='right_ear', + id=4, + color=[51, 153, 255], + type='upper', + swap='left_ear'), + 5: + dict( + name='left_shoulder', + id=5, + color=[0, 255, 0], + type='upper', + swap='right_shoulder'), + 6: + dict( + 
name='right_shoulder', + id=6, + color=[255, 128, 0], + type='upper', + swap='left_shoulder'), + 7: + dict( + name='left_elbow', + id=7, + color=[0, 255, 0], + type='upper', + swap='right_elbow'), + 8: + dict( + name='right_elbow', + id=8, + color=[255, 128, 0], + type='upper', + swap='left_elbow'), + 9: + dict( + name='left_wrist', + id=9, + color=[0, 255, 0], + type='upper', + swap='right_wrist'), + 10: + dict( + name='right_wrist', + id=10, + color=[255, 128, 0], + type='upper', + swap='left_wrist'), + 11: + dict( + name='left_hip', + id=11, + color=[0, 255, 0], + type='lower', + swap='right_hip'), + 12: + dict( + name='right_hip', + id=12, + color=[255, 128, 0], + type='lower', + swap='left_hip'), + 13: + dict( + name='left_knee', + id=13, + color=[0, 255, 0], + type='lower', + swap='right_knee'), + 14: + dict( + name='right_knee', + id=14, + color=[255, 128, 0], + type='lower', + swap='left_knee'), + 15: + dict( + name='left_ankle', + id=15, + color=[0, 255, 0], + type='lower', + swap='right_ankle'), + 16: + dict( + name='right_ankle', + id=16, + color=[255, 128, 0], + type='lower', + swap='left_ankle') + }), + skeleton_info=dict({ + 0: + dict( + link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]), + 1: + dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]), + 2: + dict( + link=('right_ankle', 'right_knee'), + id=2, + color=[255, 128, 0]), + 3: + dict( + link=('right_knee', 'right_hip'), + id=3, + color=[255, 128, 0]), + 4: + dict( + link=('left_hip', 'right_hip'), id=4, color=[51, 153, + 255]), + 5: + dict( + link=('left_shoulder', 'left_hip'), + id=5, + color=[51, 153, 255]), + 6: + dict( + link=('right_shoulder', 'right_hip'), + id=6, + color=[51, 153, 255]), + 7: + dict( + link=('left_shoulder', 'right_shoulder'), + id=7, + color=[51, 153, 255]), + 8: + dict( + link=('left_shoulder', 'left_elbow'), + id=8, + color=[0, 255, 0]), + 9: + dict( + link=('right_shoulder', 'right_elbow'), + id=9, + color=[255, 128, 0]), + 10: + dict( + link=('left_elbow', 'left_wrist'), + id=10, + color=[0, 255, 0]), + 11: + dict( + link=('right_elbow', 'right_wrist'), + id=11, + color=[255, 128, 0]), + 12: + dict( + link=('left_eye', 'right_eye'), + id=12, + color=[51, 153, 255]), + 13: + dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]), + 14: + dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]), + 15: + dict( + link=('left_eye', 'left_ear'), id=15, color=[51, 153, + 255]), + 16: + dict( + link=('right_eye', 'right_ear'), + id=16, + color=[51, 153, 255]), + 17: + dict( + link=('left_ear', 'left_shoulder'), + id=17, + color=[51, 153, 255]), + 18: + dict( + link=('right_ear', 'right_shoulder'), + id=18, + color=[51, 153, 255]) + }), + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, + 1.0, 1.2, 1.2, 1.5, 1.5 + ], + sigmas=[ + 0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, + 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089 + ]))) diff --git a/backbone/backbone_infer.py b/backbone/backbone_infer.py new file mode 100644 index 0000000..d4e2432 --- /dev/null +++ b/backbone/backbone_infer.py @@ -0,0 +1,13 @@ +from mmpose.apis import inference_bottom_up_pose_model, vis_pose_result +import cv2 +import os + +def run_backbone_infer(img_ndarr, pose_model): + # test a single image + pose_results, _ = inference_bottom_up_pose_model(pose_model, img_ndarr) + rst = vis_pose_result(pose_model, img_ndarr, pose_results, dataset='TopDownCocoWholeBodyDataset', thickness=2, radius=8, bbox_color='white') + return pose_results, rst + + +if 
__name__ == '__main__': + pass \ No newline at end of file diff --git a/detection/detection.py b/detection/detection.py new file mode 100644 index 0000000..33750f8 --- /dev/null +++ b/detection/detection.py @@ -0,0 +1,22 @@ +import torch +import matplotlib.pyplot as plt + +def detector(img, model): + """_summary_ + + Args: + img (str or numpy.ndarray): 图片路径或者像素矩阵 + model (_type_): 预加载的模型 + + Returns: + rtn(numpy.ndarray): 渲染后的图片像素点 + pred(numpy.ndarray): 检测而出的目标的坐标点、置信度和类别,shape=[n, 6] + """ + result = model(img) + + return result.render()[0], result.pred[0].cpu().numpy() + +if __name__ == '__main__': + # model = torch.hub.load('/home/zhaojh/workspace/git_space/yolov5/', 'yolov5x', source='local', pretrained=True) + pass + \ No newline at end of file diff --git a/mnist/MNIST_cnn.py b/mnist/MNIST_cnn.py new file mode 100644 index 0000000..dbf0803 --- /dev/null +++ b/mnist/MNIST_cnn.py @@ -0,0 +1,71 @@ +import os.path + +import numpy as np +import tensorflow as tf +physical_devices = tf.config.list_physical_devices('GPU') +tf.config.experimental.set_memory_growth(physical_devices[0], True) + +num_classes = 10 +input_shape = (28, 28, 1) + +def pre_process(): + + # the data, split between train and test sets + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + + # convert class vectors to binary class matrices + y_train = tf.keras.utils.to_categorical(y_train, num_classes) + y_test = tf.keras.utils.to_categorical(y_test, num_classes) + + return x_train, y_train, x_test, y_test + + +def createModel(neure, kernel_size, pool_size, activation): + model = tf.keras.Sequential( + [ + tf.keras.Input(shape=(input_shape)), + tf.keras.layers.Conv2D(neure, kernel_size=(kernel_size, kernel_size), activation=activation), + tf.keras.layers.MaxPooling2D(pool_size=(pool_size, pool_size)), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(num_classes, activation="softmax"), + ] + ) + + model.summary() + + return model + +def trainModel(model, x_train, y_train, epochs, loss): + batch_size = 128 + epochs = epochs + + model.compile(loss=loss, optimizer="adam", metrics=["accuracy"]) + + model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1) + + return model + +# 模型预测 +def predictModel(x_test, model): + predicted_data = model.predict(x_test) + return predicted_data + +def train(neure, kernel_size, pool_size, activation, epochs, loss): + x_train, y_train, x_test, y_test = pre_process() + model = createModel(neure, kernel_size, pool_size, activation) + model = trainModel(model, x_train, y_train, epochs, loss) + score = model.evaluate(x_test, y_test, verbose=0) + accuracy = score[1] + + model.save(os.path.abspath("./appweb/self_model/mnist_cnn.h5")) + + return accuracy \ No newline at end of file diff --git a/ocr/ocr.py b/ocr/ocr.py new file mode 100644 index 0000000..f9ac5b0 --- /dev/null +++ b/ocr/ocr.py @@ -0,0 +1,22 @@ +import tr +import cv2 + +def run_tr(image): + """_summary_ + + Args: + image (np.ndarray): 像素矩阵 + + Returns: + list: 字组成的列表 + """ + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + rst = tr.run(gray_image) + if len(rst) == 0: + return [] + return [x[1] for x in rst] + + +if __name__ == '__main__': + pass + \ No newline at end of file diff --git 
a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ebeb26a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +Flask==2.1.0 +h5py +ftfy==6.1.1 +logzero==1.7.0 +lpips==0.1.4 +numpy==1.21.6 +pandas==1.3.5 +parse==1.19.0 +Pillow==9.2.0 +scipy==1.4.1 +six==1.15.0 +tqdm==4.64.0 +opencv-python==4.6.0.66 +seaborn==0.11.2 +segmentation-models-pytorch==0.2.1 +albumentations==1.2.1 +openmim==0.2.0 \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000..f48a0ab --- /dev/null +++ b/run.py @@ -0,0 +1,143 @@ +# -*-coding:utf-8-*- +import os +import sys +from logzero import logger +current_path = os.path.dirname(__file__) +logger.info(current_path) +sys.path.append(f"{current_path}/text2image/") +sys.path.append(f"{current_path}/text2image/BigGAN_utils/") +import json +import base64 +from flask import Flask, request, make_response +import cv2 +# from io import BytesIO +# import torch +from mmpose.apis import init_pose_model +# from text2image.run_text2img import text2image +# from detection.detection import detector +# from segmentation.segment_pred import run_seg +# from ocr.ocr import run_tr +from backbone.backbone_infer import run_backbone_infer + +DEVICE = 'cpu' + +# model_5x = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5x', source='local', pretrained=True) +# model_5s = torch.hub.load(f'{current_path}/detection/yolov5/','yolov5s', source='local', pretrained=True) +# model_seg = torch.load(f'{current_path}/segmentation/models/best_model_pvgc.pth', map_location=DEVICE) +pose_config_file = f'{current_path}/backbone/associative_embedding_hrnet_w32_coco_512x512.py' +pose_ckpt_file = f'{current_path}/backbone/models/hrnet_w32_coco_512x512-bcb8c247_20200816.pth' +pose_model = init_pose_model(pose_config_file, pose_ckpt_file, device='cpu') # or device='cuda:0' + +app=Flask(__name__) + +# @app.route('/text2image/',methods=["POST"]) +# def run_text2img(): +# if request.method == "POST": +# text = request.form.get('text') +# logger.info(f"{text}") +# img = text2image(text) +# output_buffer = BytesIO() +# img.save(output_buffer, format='png') +# byte_data = output_buffer.getvalue() +# b64_code = base64.b64encode(byte_data).decode('utf-8') +# resp = make_response(b64_code) +# resp.status_code = 200 +# return resp +# else: +# resp = make_response() +# resp.status_code=405 +# return resp + + +# @app.route('/detection/', methods=["POST"]) +# def run_detection(): +# if request.method == "POST": +# img = request.files.get('image') +# model_type = request.form.get('model_type') +# try: +# img = cv2.imread(img) +# except: +# resp = make_response() +# resp.status_code = 406 +# return resp +# if model_type.lower().strip() == 'yolov5x': +# rst, _ = detector(img, model_5x) +# else: +# rst, _ = detector(img, model_5s) +# logger.info(rst.shape) +# img_str = cv2.imencode('.png', rst)[1].tobytes() +# b64_code = base64.b64encode(img_str).decode('utf-8') +# resp = make_response(b64_code) +# resp.status_code = 200 +# return b64_code +# else: +# resp = make_response() +# resp.status_code=405 +# return resp + +# @app.route('/ocr/', methods=["POST"]) +# def run_ocr(): +# resp = make_response() +# if request.method == "POST": +# img = request.files.get('image') +# try: +# img = cv2.imread(img) +# except: +# resp.status_code = 406 +# return resp +# text = run_tr(img) +# resp.status_code = 200 +# resp.data = json.dumps({'result':text}) +# return resp +# else: +# resp.status_code=405 +# return resp + + +# @app.route('/segmentation/', methods=["POST"]) +# def 
run_segmentation(): +# if request.method == "POST": +# img_upload = request.files.get('image') +# try: +# img = cv2.imread(img_upload) +# except: +# resp = make_response() +# resp.status_code = 406 +# return resp +# result = run_seg(img, model_seg) +# img_str = cv2.imencode('.png', result)[1].tobytes() +# b64_code = base64.b64encode(img_str).decode('utf-8') +# resp = make_response(b64_code) +# resp.status_code = 200 +# return resp +# else: +# resp = make_response() +# resp.status_code=405 +# return resp + +@app.route('/backbone/', methods=["POST"]) +def run_backbone(): + if request.method == "POST": + img_upload = request.files.get('image') + try: + img = cv2.imread(img_upload) + except: + resp = make_response() + resp.status_code = 406 + return resp + pose, result = run_backbone_infer(img, pose_model) + img_str = cv2.imencode('.png', result)[1].tobytes() + b64_code = base64.b64encode(img_str).decode('utf-8') + resp = make_response(b64_code) + resp.status_code = 200 + return resp + else: + resp = make_response() + resp.status_code=405 + return resp + + +if __name__ == '__main__': + img = cv2.imread('./1.jpg') + pose, rst = run_backbone_infer(img, pose_model) + cv2.imwrite('./1_bb.jpg', rst) \ No newline at end of file diff --git a/segmentation/segment_pred.py b/segmentation/segment_pred.py new file mode 100644 index 0000000..075b9f6 --- /dev/null +++ b/segmentation/segment_pred.py @@ -0,0 +1,66 @@ +from asyncio.log import logger +import os +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import albumentations as albu +import torch +import segmentation_models_pytorch as smp +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as BaseDataset +from PIL import Image + +DEVICE = 'cpu' +ENCODER = 'se_resnext50_32x4d' +ENCODER_WEIGHTS = 'imagenet' + +# --------------------------------------------------------------- +### 加载数据 + +def get_validation_augmentation(): + """调整图像使得图片的分辨率长宽能被32整除""" + test_transform = [ + albu.PadIfNeeded(256, 256) + ] + return albu.Compose(test_transform) + + +def to_tensor(x, **kwargs): + return x.transpose(2, 0, 1).astype('float32') + + +def get_preprocessing(preprocessing_fn): + """进行图像预处理操作 + + Args: + preprocessing_fn (callbale): 数据规范化的函数 + (针对每种预训练的神经网络) + Return: + transform: albumentations.Compose + """ + + _transform = [ + albu.Lambda(image=preprocessing_fn), + albu.Lambda(image=to_tensor), + ] + return albu.Compose(_transform) + + +def run_seg(img, best_model): + # 测试集 + preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS) + augmentator = get_validation_augmentation() + preprocessor = get_preprocessing(preprocessing_fn) + # --------------------------------------------------------------- + img = augmentator(image=img)['image'] + img = preprocessor(image=img)['image'] + # 加载最佳模型 + x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0) + pr_mask = best_model.predict(x_tensor) + pr_mask = (pr_mask.squeeze().cpu().numpy().round()) + return (pr_mask - 1) * (-220) + + +if __name__ == '__main__': + best_model = torch.load('/home/zhaojh/workspace/computer_vision/segmentation/models/best_model_pvgc.pth', map_location=DEVICE) + input_img = cv2.imread('/home/zhaojh/datasets/photovoltaic/PV03/PV03_Ground_Cropland/test/PV03_316626_1211836.bmp') \ No newline at end of file diff --git a/text2image/BigGAN_utils/BigGAN.py b/text2image/BigGAN_utils/BigGAN.py new file mode 100644 index 0000000..3367ba0 --- /dev/null +++ b/text2image/BigGAN_utils/BigGAN.py @@ -0,0 +1,484 @@ +import numpy as np +import math 
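# Note: the bare 'import layers' and 'from sync_batchnorm import ...' below only
# resolve because run.py appends text2image/ and text2image/BigGAN_utils/ to
# sys.path at startup. A minimal sketch of that setup (mirrors run.py; assumes
# current_path is the repo root containing run.py):
#
#     import os, sys
#     current_path = os.path.dirname(os.path.abspath(__file__))
#     sys.path.append(os.path.join(current_path, "text2image"))
#     sys.path.append(os.path.join(current_path, "text2image", "BigGAN_utils"))
#     import BigGAN  # its 'layers' and 'sync_batchnorm' imports now resolve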
+import functools + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + +import layers +from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d + + +# Architectures for G +# Attention is passed in in the format '32_64' to mean applying an attention +# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64. +def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): + arch = {} + arch[512] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2, 1]], + 'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1, 1]], + 'upsample' : [True] * 7, + 'resolution' : [8, 16, 32, 64, 128, 256, 512], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,10)}} + arch[256] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2]], + 'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1]], + 'upsample' : [True] * 6, + 'resolution' : [8, 16, 32, 64, 128, 256], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,9)}} + arch[128] = {'in_channels' : [ch * item for item in [16, 16, 8, 4, 2]], + 'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]], + 'upsample' : [True] * 5, + 'resolution' : [8, 16, 32, 64, 128], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,8)}} + arch[64] = {'in_channels' : [ch * item for item in [16, 16, 8, 4]], + 'out_channels' : [ch * item for item in [16, 8, 4, 2]], + 'upsample' : [True] * 4, + 'resolution' : [8, 16, 32, 64], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,7)}} + arch[32] = {'in_channels' : [ch * item for item in [4, 4, 4]], + 'out_channels' : [ch * item for item in [4, 4, 4]], + 'upsample' : [True] * 3, + 'resolution' : [8, 16, 32], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,6)}} + + return arch + +class Generator(nn.Module): + def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128, + G_kernel_size=3, G_attn='64', n_classes=1000, + num_G_SVs=1, num_G_SV_itrs=1, + G_shared=True, shared_dim=0, hier=False, + cross_replica=False, mybn=False, + G_activation=nn.ReLU(inplace=False), + G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, + BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False, + G_init='ortho', skip_init=False, no_optim=False, + G_param='SN', norm_style='bn', + **kwargs): + super(Generator, self).__init__() + # Channel width mulitplier + self.ch = G_ch + # Dimensionality of the latent space + self.dim_z = dim_z + # The initial spatial dimensions + self.bottom_width = bottom_width + # Resolution of the output + self.resolution = resolution + # Kernel size? + self.kernel_size = G_kernel_size + # Attention? + self.attention = G_attn + # number of classes, for use in categorical conditional generation + self.n_classes = n_classes + # Use shared embeddings? + self.G_shared = G_shared + # Dimensionality of the shared embedding? Unused if not using G_shared + self.shared_dim = shared_dim if shared_dim > 0 else dim_z + # Hierarchical latent space? + self.hier = hier + # Cross replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? 
+ self.mybn = mybn + # nonlinearity for residual blocks + self.activation = G_activation + # Initialization style + self.init = G_init + # Parameterization style + self.G_param = G_param + # Normalization style + self.norm_style = norm_style + # Epsilon for BatchNorm? + self.BN_eps = BN_eps + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # fp16? + self.fp16 = G_fp16 + # Architecture dict + self.arch = G_arch(self.ch, self.attention)[resolution] + + # If using hierarchical latents, adjust z + if self.hier: + # Number of places z slots into + self.num_slots = len(self.arch['in_channels']) + 1 + self.z_chunk_size = (self.dim_z // self.num_slots) + # Recalculate latent dimensionality for even splitting into chunks + self.dim_z = self.z_chunk_size * self.num_slots + else: + self.num_slots = 1 + self.z_chunk_size = 0 + + # Which convs, batchnorms, and linear layers to use + if self.G_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + self.which_embedding = nn.Embedding + bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared + else self.which_embedding) + self.which_bn = functools.partial(layers.ccbn, + which_linear=bn_linear, + cross_replica=self.cross_replica, + mybn=self.mybn, + input_size=(self.shared_dim + self.z_chunk_size if self.G_shared + else self.n_classes), + norm_style=self.norm_style, + eps=self.BN_eps) + + + # Prepare model + # If not using shared embeddings, self.shared is just a passthrough + self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared + else layers.identity()) + # First linear layer + self.linear = self.which_linear(self.dim_z // self.num_slots, + self.arch['in_channels'][0] * (self.bottom_width **2)) + + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + # while the inner loop is over a given block + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + which_bn=self.which_bn, + activation=self.activation, + upsample=(functools.partial(F.interpolate, scale_factor=2) + if self.arch['upsample'][index] else None))]] + + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print(self.arch['resolution'], self.arch['attention']) + print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] + + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + # output layer: batchnorm-relu-conv. 
+ # Consider using a non-spectral conv here + self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], + cross_replica=self.cross_replica, + mybn=self.mybn), + self.activation, + self.which_conv(self.arch['out_channels'][-1], 3)) + + # Initialize weights. Optionally skip init for testing. + if not skip_init: + self.init_weights() + + # Set up optimizer + # If this is an EMA copy, no need for an optim, so just return now + if no_optim: + return + self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps + if G_mixed_precision: + print('Using fp16 adam in G...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for G''s initialized parameters: %d' % self.param_count) + + # Note on this forward function: we pass in a y vector which has + # already been passed through G.shared to enable easy class-wise + # interpolation later. If we passed in the one-hot and then ran it through + # G.shared in this forward function, it would be harder to handle. + def forward(self, z, y, w_y=None): + if w_y is not None: + s_y = torch.softmax(w_y, dim=1) + + cur_y = s_y * y + y = cur_y.sum(dim=1, keepdim=False) + + # If hierarchical, concatenate zs and ys + if self.hier: + zs = torch.split(z, self.z_chunk_size, 1) + z = zs[0] + ys = [torch.cat([y, item], 1) for item in zs[1:]] + else: + ys = [y] * len(self.blocks) + + # First linear layer + h = self.linear(z) + # Reshape + h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) + + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + # Second inner loop in case block has multiple layers + for block in blocklist: + h = block(h, ys[index]) + + # Apply batchnorm-relu-conv-tanh at output + return torch.tanh(self.output_layer(h)) + + # Note on this forward function: we pass in a y vector which has + # already been passed through G.shared to enable easy class-wise + # interpolation later. If we passed in the one-hot and then ran it through + # G.shared in this forward function, it would be harder to handle. 
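    # Worked example of the hierarchical-latent bookkeeping from __init__ above
    # (a sketch only, assuming resolution=128, dim_z=128, hier=True):
    #
    #   arch[128]['in_channels'] has 5 entries, so num_slots = 5 + 1 = 6,
    #   z_chunk_size = 128 // 6 = 21, and dim_z is recomputed to 21 * 6 = 126.
    #
    #   z = torch.randn(4, 126)
    #   zs = torch.split(z, 21, 1)      # 6 chunks of width 21
    #   # zs[0] feeds self.linear; each of zs[1:] is concatenated with y and
    #   # conditions one group of blocks, matching len(self.blocks) == 5.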
+ def forward_org(self, z, y): + # If hierarchical, concatenate zs and ys + if self.hier: + zs = torch.split(z, self.z_chunk_size, 1) + z = zs[0] + ys = [torch.cat([y, item], 1) for item in zs[1:]] + else: + ys = [y] * len(self.blocks) + + # First linear layer + h = self.linear(z) + # Reshape + h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) + + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + # Second inner loop in case block has multiple layers + for block in blocklist: + h = block(h, ys[index]) + + # Apply batchnorm-relu-conv-tanh at output + return torch.tanh(self.output_layer(h)) + + +# Discriminator architecture, same paradigm as G's above +def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], + 'downsample' : [True] * 6 + [False], + 'resolution' : [128, 64, 32, 16, 8, 4, 4 ], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]], + 'downsample' : [True] * 5 + [False], + 'resolution' : [64, 32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]], + 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]], + 'downsample' : [True] * 4 + [False], + 'resolution' : [32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,7)}} + arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]], + 'out_channels' : [item * ch for item in [4, 4, 4, 4]], + 'downsample' : [True, True, False, False], + 'resolution' : [16, 16, 16, 16], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,6)}} + return arch + +class Discriminator(nn.Module): + + def __init__(self, D_ch=64, D_wide=True, resolution=128, + D_kernel_size=3, D_attn='64', n_classes=1000, + num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), + D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8, + SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False, + D_init='ortho', skip_init=False, D_param='SN', **kwargs): + super(Discriminator, self).__init__() + # Width multiplier + self.ch = D_ch + # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? + self.D_wide = D_wide + # Resolution + self.resolution = resolution + # Kernel size + self.kernel_size = D_kernel_size + # Attention? + self.attention = D_attn + # Number of classes + self.n_classes = n_classes + # Activation + self.activation = D_activation + # Initialization style + self.init = D_init + # Parameterization style + self.D_param = D_param + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # Fp16? 
+ self.fp16 = D_fp16 + # Architecture + self.arch = D_arch(self.ch, self.attention)[resolution] + + # Which convs, batchnorms, and linear layers to use + # No option to turn off SN in D right now + if self.D_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_embedding = functools.partial(layers.SNEmbedding, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + # Prepare model + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + wide=self.D_wide, + activation=self.activation, + preactivation=(index > 0), + downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], + self.which_conv)] + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + # Linear output layer. The output dimension is typically 1, but may be + # larger if we're e.g. turning this into a VAE with an inference output + self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) + # Embedding for projection discrimination + self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]) + + # Initialize weights + if not skip_init: + self.init_weights() + + # Set up optimizer + self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps + if D_mixed_precision: + print('Using fp16 adam in D...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for D''s initialized parameters: %d' % self.param_count) + + def forward(self, x, y=None): + # Stick x into h for cleaner for loops without flow control + h = x + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + # Apply global sum pooling as in SN-GAN + h = torch.sum(self.activation(h), [2, 3]) + # Get initial 
class-unconditional output + out = self.linear(h) + # Get projection of final featureset onto class vectors and add to evidence + out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) + return out + +# Parallelized G_D to minimize cross-gpu communication +# Without this, Generator outputs would get all-gathered and then rebroadcast. +class G_D(nn.Module): + def __init__(self, G, D): + super(G_D, self).__init__() + self.G = G + self.D = D + + def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, + split_D=False): + # If training G, enable grad tape + with torch.set_grad_enabled(train_G): + # Get Generator output given noise + G_z = self.G(z, self.G.shared(gy)) + # Cast as necessary + if self.G.fp16 and not self.D.fp16: + G_z = G_z.float() + if self.D.fp16 and not self.G.fp16: + G_z = G_z.half() + # Split_D means to run D once with real data and once with fake, + # rather than concatenating along the batch dimension. + if split_D: + D_fake = self.D(G_z, gy) + if x is not None: + D_real = self.D(x, dy) + return D_fake, D_real + else: + if return_G_z: + return D_fake, G_z + else: + return D_fake + # If real data is provided, concatenate it with the Generator's output + # along the batch dimension for improved efficiency. + else: + D_input = torch.cat([G_z, x], 0) if x is not None else G_z + D_class = torch.cat([gy, dy], 0) if dy is not None else gy + # Get Discriminator output + D_out = self.D(D_input, D_class) + if x is not None: + return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real + else: + if return_G_z: + return D_out, G_z + else: + return D_out diff --git a/text2image/BigGAN_utils/BigGANdeep.py b/text2image/BigGAN_utils/BigGANdeep.py new file mode 100644 index 0000000..95763c3 --- /dev/null +++ b/text2image/BigGAN_utils/BigGANdeep.py @@ -0,0 +1,534 @@ +import numpy as np +import math +import functools + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + +import layers +from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d + +# BigGAN-deep: uses a different resblock and pattern + + +# Architectures for G +# Attention is passed in in the format '32_64' to mean applying an attention +# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64. 
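# A quick illustration of the attention-string convention described above, using
# the same comprehension as the G_arch/D_arch tables (example values only):
#
#     attention = '32_64'
#     attn = {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
#             for i in range(3, 8)}
#     # attn == {8: False, 16: False, 32: True, 64: True, 128: False},
#     # i.e. attention blocks are attached only at the 32x32 and 64x64 stages.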
+ +# Channel ratio is the ratio of +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv=nn.Conv2d, which_bn=layers.bn, activation=None, + upsample=None, channel_ratio=4): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.hidden_channels = self.in_channels // channel_ratio + self.which_conv, self.which_bn = which_conv, which_bn + self.activation = activation + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels, + kernel_size=1, padding=0) + self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv4 = self.which_conv(self.hidden_channels, self.out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(self.in_channels) + self.bn2 = self.which_bn(self.hidden_channels) + self.bn3 = self.which_bn(self.hidden_channels) + self.bn4 = self.which_bn(self.hidden_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + # Project down to channel ratio + h = self.conv1(self.activation(self.bn1(x, y))) + # Apply next BN-ReLU + h = self.activation(self.bn2(h, y)) + # Drop channels in x if necessary + if self.in_channels != self.out_channels: + x = x[:, :self.out_channels] + # Upsample both h and x at this point + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + # 3x3 convs + h = self.conv2(h) + h = self.conv3(self.activation(self.bn3(h, y))) + # Final 1x1 conv + h = self.conv4(self.activation(self.bn4(h, y))) + return h + x + +def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2]], + 'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1]], + 'upsample' : [True] * 6, + 'resolution' : [8, 16, 32, 64, 128, 256], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,9)}} + arch[128] = {'in_channels' : [ch * item for item in [16, 16, 8, 4, 2]], + 'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]], + 'upsample' : [True] * 5, + 'resolution' : [8, 16, 32, 64, 128], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,8)}} + arch[64] = {'in_channels' : [ch * item for item in [16, 16, 8, 4]], + 'out_channels' : [ch * item for item in [16, 8, 4, 2]], + 'upsample' : [True] * 4, + 'resolution' : [8, 16, 32, 64], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,7)}} + arch[32] = {'in_channels' : [ch * item for item in [4, 4, 4]], + 'out_channels' : [ch * item for item in [4, 4, 4]], + 'upsample' : [True] * 3, + 'resolution' : [8, 16, 32], + 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) + for i in range(3,6)}} + + return arch + +class Generator(nn.Module): + def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128, + G_kernel_size=3, G_attn='64', n_classes=1000, + num_G_SVs=1, num_G_SV_itrs=1, + G_shared=True, shared_dim=0, hier=False, + cross_replica=False, mybn=False, + G_activation=nn.ReLU(inplace=False), + G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, + BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False, + G_init='ortho', skip_init=False, no_optim=False, + G_param='SN', norm_style='bn', + **kwargs): + super(Generator, self).__init__() + # Channel width mulitplier + self.ch = G_ch + # Number of 
resblocks per stage + self.G_depth = G_depth + # Dimensionality of the latent space + self.dim_z = dim_z + # The initial spatial dimensions + self.bottom_width = bottom_width + # Resolution of the output + self.resolution = resolution + # Kernel size? + self.kernel_size = G_kernel_size + # Attention? + self.attention = G_attn + # number of classes, for use in categorical conditional generation + self.n_classes = n_classes + # Use shared embeddings? + self.G_shared = G_shared + # Dimensionality of the shared embedding? Unused if not using G_shared + self.shared_dim = shared_dim if shared_dim > 0 else dim_z + # Hierarchical latent space? + self.hier = hier + # Cross replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # nonlinearity for residual blocks + self.activation = G_activation + # Initialization style + self.init = G_init + # Parameterization style + self.G_param = G_param + # Normalization style + self.norm_style = norm_style + # Epsilon for BatchNorm? + self.BN_eps = BN_eps + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # fp16? + self.fp16 = G_fp16 + # Architecture dict + self.arch = G_arch(self.ch, self.attention)[resolution] + + + # Which convs, batchnorms, and linear layers to use + if self.G_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + self.which_embedding = nn.Embedding + bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared + else self.which_embedding) + self.which_bn = functools.partial(layers.ccbn, + which_linear=bn_linear, + cross_replica=self.cross_replica, + mybn=self.mybn, + input_size=(self.shared_dim + self.dim_z if self.G_shared + else self.n_classes), + norm_style=self.norm_style, + eps=self.BN_eps) + + + # Prepare model + # If not using shared embeddings, self.shared is just a passthrough + self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared + else layers.identity()) + # First linear layer + self.linear = self.which_linear(self.dim_z + self.shared_dim, self.arch['in_channels'][0] * (self.bottom_width **2)) + + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + # while the inner loop is over a given block + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['in_channels'][index] if g_index==0 else self.arch['out_channels'][index], + which_conv=self.which_conv, + which_bn=self.which_bn, + activation=self.activation, + upsample=(functools.partial(F.interpolate, scale_factor=2) + if self.arch['upsample'][index] and g_index == (self.G_depth-1) else None))] + for g_index in range(self.G_depth)] + + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], 
self.which_conv)] + + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + # output layer: batchnorm-relu-conv. + # Consider using a non-spectral conv here + self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], + cross_replica=self.cross_replica, + mybn=self.mybn), + self.activation, + self.which_conv(self.arch['out_channels'][-1], 3)) + + # Initialize weights. Optionally skip init for testing. + if not skip_init: + self.init_weights() + + # Set up optimizer + # If this is an EMA copy, no need for an optim, so just return now + if no_optim: + return + self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps + if G_mixed_precision: + print('Using fp16 adam in G...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, + eps=self.adam_eps) + + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for G''s initialized parameters: %d' % self.param_count) + + # Note on this forward function: we pass in a y vector which has + # already been passed through G.shared to enable easy class-wise + # interpolation later. If we passed in the one-hot and then ran it through + # G.shared in this forward function, it would be harder to handle. 
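+  # (Callers therefore pass pre-embedded labels, i.e. G(z, G.shared(y)), as done in
+  #  G_D.forward below.)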
+ # NOTE: The z vs y dichotomy here is for compatibility with not-y + def forward(self, z, y): + # If hierarchical, concatenate zs and ys + if self.hier: + z = torch.cat([y, z], 1) + y = z + # First linear layer + h = self.linear(z) + # Reshape + h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + # Second inner loop in case block has multiple layers + for block in blocklist: + h = block(h, y) + + # Apply batchnorm-relu-conv-tanh at output + return torch.tanh(self.output_layer(h)) + +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, which_conv=layers.SNConv2d, wide=True, + preactivation=True, activation=None, downsample=None, + channel_ratio=4): + super(DBlock, self).__init__() + self.in_channels, self.out_channels = in_channels, out_channels + # If using wide D (as in SA-GAN and BigGAN), change the channel pattern + self.hidden_channels = self.out_channels // channel_ratio + self.which_conv = which_conv + self.preactivation = preactivation + self.activation = activation + self.downsample = downsample + + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels, + kernel_size=1, padding=0) + self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels) + self.conv4 = self.which_conv(self.hidden_channels, self.out_channels, + kernel_size=1, padding=0) + + self.learnable_sc = True if (in_channels != out_channels) else False + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels - in_channels, + kernel_size=1, padding=0) + def shortcut(self, x): + if self.downsample: + x = self.downsample(x) + if self.learnable_sc: + x = torch.cat([x, self.conv_sc(x)], 1) + return x + + def forward(self, x): + # 1x1 bottleneck conv + h = self.conv1(F.relu(x)) + # 3x3 convs + h = self.conv2(self.activation(h)) + h = self.conv3(self.activation(h)) + # relu before downsample + h = self.activation(h) + # downsample + if self.downsample: + h = self.downsample(h) + # final 1x1 conv + h = self.conv4(h) + return h + self.shortcut(x) + +# Discriminator architecture, same paradigm as G's above +def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16]], + 'out_channels' : [item * ch for item in [2, 4, 8, 8, 16, 16]], + 'downsample' : [True] * 6 + [False], + 'resolution' : [128, 64, 32, 16, 8, 4, 4 ], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[128] = {'in_channels' : [item * ch for item in [1, 2, 4, 8, 16]], + 'out_channels' : [item * ch for item in [2, 4, 8, 16, 16]], + 'downsample' : [True] * 5 + [False], + 'resolution' : [64, 32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,8)}} + arch[64] = {'in_channels' : [item * ch for item in [1, 2, 4, 8]], + 'out_channels' : [item * ch for item in [2, 4, 8, 16]], + 'downsample' : [True] * 4 + [False], + 'resolution' : [32, 16, 8, 4, 4], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i in range(2,7)}} + arch[32] = {'in_channels' : [item * ch for item in [4, 4, 4]], + 'out_channels' : [item * ch for item in [4, 4, 4]], + 'downsample' : [True, True, False, False], + 'resolution' : [16, 16, 16, 16], + 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')] + for i 
in range(2,6)}} + return arch + +class Discriminator(nn.Module): + + def __init__(self, D_ch=64, D_wide=True, D_depth=2, resolution=128, + D_kernel_size=3, D_attn='64', n_classes=1000, + num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), + D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8, + SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False, + D_init='ortho', skip_init=False, D_param='SN', **kwargs): + super(Discriminator, self).__init__() + # Width multiplier + self.ch = D_ch + # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? + self.D_wide = D_wide + # How many resblocks per stage? + self.D_depth = D_depth + # Resolution + self.resolution = resolution + # Kernel size + self.kernel_size = D_kernel_size + # Attention? + self.attention = D_attn + # Number of classes + self.n_classes = n_classes + # Activation + self.activation = D_activation + # Initialization style + self.init = D_init + # Parameterization style + self.D_param = D_param + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # Fp16? + self.fp16 = D_fp16 + # Architecture + self.arch = D_arch(self.ch, self.attention)[resolution] + + + # Which convs, batchnorms, and linear layers to use + # No option to turn off SN in D right now + if self.D_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_embedding = functools.partial(layers.SNEmbedding, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + + + # Prepare model + # Stem convolution + self.input_conv = self.which_conv(3, self.arch['in_channels'][0]) + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[DBlock(in_channels=self.arch['in_channels'][index] if d_index==0 else self.arch['out_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + wide=self.D_wide, + activation=self.activation, + preactivation=True, + downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] and d_index==0 else None)) + for d_index in range(self.D_depth)]] + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], + self.which_conv)] + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + # Linear output layer. The output dimension is typically 1, but may be + # larger if we're e.g. 
turning this into a VAE with an inference output + self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) + # Embedding for projection discrimination + self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]) + + # Initialize weights + if not skip_init: + self.init_weights() + + # Set up optimizer + self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps + if D_mixed_precision: + print('Using fp16 adam in D...') + import utils + self.optim = utils.Adam16(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + else: + self.optim = optim.Adam(params=self.parameters(), lr=self.lr, + betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps) + # LR scheduling, left here for forward compatibility + # self.lr_sched = {'itr' : 0}# if self.progressive else {} + # self.j = 0 + + # Initialize + def init_weights(self): + self.param_count = 0 + for module in self.modules(): + if (isinstance(module, nn.Conv2d) + or isinstance(module, nn.Linear) + or isinstance(module, nn.Embedding)): + if self.init == 'ortho': + init.orthogonal_(module.weight) + elif self.init == 'N02': + init.normal_(module.weight, 0, 0.02) + elif self.init in ['glorot', 'xavier']: + init.xavier_uniform_(module.weight) + else: + print('Init style not recognized...') + self.param_count += sum([p.data.nelement() for p in module.parameters()]) + print('Param count for D''s initialized parameters: %d' % self.param_count) + + def forward(self, x, y=None): + # Run input conv + h = self.input_conv(x) + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + # Apply global sum pooling as in SN-GAN + h = torch.sum(self.activation(h), [2, 3]) + # Get initial class-unconditional output + out = self.linear(h) + # Get projection of final featureset onto class vectors and add to evidence + out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) + return out + +# Parallelized G_D to minimize cross-gpu communication +# Without this, Generator outputs would get all-gathered and then rebroadcast. +class G_D(nn.Module): + def __init__(self, G, D): + super(G_D, self).__init__() + self.G = G + self.D = D + + def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, + split_D=False): + # If training G, enable grad tape + with torch.set_grad_enabled(train_G): + # Get Generator output given noise + G_z = self.G(z, self.G.shared(gy)) + # Cast as necessary + if self.G.fp16 and not self.D.fp16: + G_z = G_z.float() + if self.D.fp16 and not self.G.fp16: + G_z = G_z.half() + # Split_D means to run D once with real data and once with fake, + # rather than concatenating along the batch dimension. + if split_D: + D_fake = self.D(G_z, gy) + if x is not None: + D_real = self.D(x, dy) + return D_fake, D_real + else: + if return_G_z: + return D_fake, G_z + else: + return D_fake + # If real data is provided, concatenate it with the Generator's output + # along the batch dimension for improved efficiency. 
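+      # When real data is provided, the single D pass is split back into
+      # D_fake / D_real scores below.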
+ else: + D_input = torch.cat([G_z, x], 0) if x is not None else G_z + D_class = torch.cat([gy, dy], 0) if dy is not None else gy + # Get Discriminator output + D_out = self.D(D_input, D_class) + if x is not None: + return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real + else: + if return_G_z: + return D_out, G_z + else: + return D_out diff --git a/text2image/BigGAN_utils/LICENSE b/text2image/BigGAN_utils/LICENSE new file mode 100644 index 0000000..c2dd721 --- /dev/null +++ b/text2image/BigGAN_utils/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Andy Brock + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/text2image/BigGAN_utils/README.md b/text2image/BigGAN_utils/README.md new file mode 100644 index 0000000..3830c36 --- /dev/null +++ b/text2image/BigGAN_utils/README.md @@ -0,0 +1 @@ +Download pre-trained weights from (https://drive.google.com/drive/folders/1nJ3HmgYgeA9NZr-oU-enqbYeO7zBaANs?usp=sharing) and put them in `./weights/` diff --git a/text2image/BigGAN_utils/TFHub/__pycache__/biggan_v1.cpython-38.pyc b/text2image/BigGAN_utils/TFHub/__pycache__/biggan_v1.cpython-38.pyc new file mode 100644 index 0000000..1101bce Binary files /dev/null and b/text2image/BigGAN_utils/TFHub/__pycache__/biggan_v1.cpython-38.pyc differ diff --git a/text2image/BigGAN_utils/TFHub/biggan_v1.py b/text2image/BigGAN_utils/TFHub/biggan_v1.py new file mode 100644 index 0000000..56e99b8 --- /dev/null +++ b/text2image/BigGAN_utils/TFHub/biggan_v1.py @@ -0,0 +1,389 @@ +# BigGAN V1: +# This is now deprecated code used for porting the TFHub modules to pytorch, +# included here for reference only. 
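+# It defines the v1 Generator128/256/512 (and a matching Discriminator); converter.py
+# builds a v1 generator state dict as the intermediate format when porting TFHub
+# weights onto this repo's BigGAN Generator.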
+import numpy as np +import torch +from scipy.stats import truncnorm +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F + + +def l2normalize(v, eps=1e-4): + return v / (v.norm() + eps) + + +def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None): + state = None if seed is None else np.random.RandomState(seed) + values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state) + return truncation * values + + +def denorm(x): + out = (x + 1) / 2 + return out.clamp_(0, 1) + + +class SpectralNorm(nn.Module): + def __init__(self, module, name='weight', power_iterations=1): + super(SpectralNorm, self).__init__() + self.module = module + self.name = name + self.power_iterations = power_iterations + if not self._made_params(): + self._make_params() + + def _update_u_v(self): + u = getattr(self.module, self.name + "_u") + v = getattr(self.module, self.name + "_v") + w = getattr(self.module, self.name + "_bar") + + height = w.data.shape[0] + _w = w.view(height, -1) + for _ in range(self.power_iterations): + v = l2normalize(torch.matmul(_w.t(), u)) + u = l2normalize(torch.matmul(_w, v)) + + sigma = u.dot((_w).mv(v)) + setattr(self.module, self.name, w / sigma.expand_as(w)) + + def _made_params(self): + try: + getattr(self.module, self.name + "_u") + getattr(self.module, self.name + "_v") + getattr(self.module, self.name + "_bar") + return True + except AttributeError: + return False + + def _make_params(self): + w = getattr(self.module, self.name) + + height = w.data.shape[0] + width = w.view(height, -1).data.shape[1] + + u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) + v = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) + u.data = l2normalize(u.data) + v.data = l2normalize(v.data) + w_bar = Parameter(w.data) + + del self.module._parameters[self.name] + self.module.register_parameter(self.name + "_u", u) + self.module.register_parameter(self.name + "_v", v) + self.module.register_parameter(self.name + "_bar", w_bar) + + def forward(self, *args): + self._update_u_v() + return self.module.forward(*args) + + +class SelfAttention(nn.Module): + """ Self Attention Layer""" + + def __init__(self, in_dim, activation=F.relu): + super().__init__() + self.chanel_in = in_dim + self.activation = activation + + self.theta = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False)) + self.phi = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False)) + self.pool = nn.MaxPool2d(2, 2) + self.g = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1, bias=False)) + self.o_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim // 2, out_channels=in_dim, kernel_size=1, bias=False)) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + m_batchsize, C, width, height = x.size() + N = height * width + + theta = self.theta(x) + phi = self.phi(x) + phi = self.pool(phi) + phi = phi.view(m_batchsize, -1, N // 4) + theta = theta.view(m_batchsize, -1, N) + theta = theta.permute(0, 2, 1) + attention = self.softmax(torch.bmm(theta, phi)) + g = self.pool(self.g(x)).view(m_batchsize, -1, N // 4) + attn_g = torch.bmm(g, attention.permute(0, 2, 1)).view(m_batchsize, -1, width, height) + out = self.o_conv(attn_g) + return self.gamma * out + x + + +class ConditionalBatchNorm2d(nn.Module): + def __init__(self, num_features, num_classes, eps=1e-4, momentum=0.1): + super().__init__() 
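+        # Conditional BN: x is normalized with affine=False, then scaled and shifted by
+        # gamma/beta predicted from the condition vector y via spectrally normalized linears.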
+ self.num_features = num_features + self.bn = nn.BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum) + self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False)) + self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False)) + + def forward(self, x, y): + out = self.bn(x) + gamma = self.gamma_embed(y) + 1 + beta = self.beta_embed(y) + out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) + return out + + +class GBlock(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size=[3, 3], + padding=1, + stride=1, + n_class=None, + bn=True, + activation=F.relu, + upsample=True, + downsample=False, + z_dim=148, + ): + super().__init__() + + self.conv0 = SpectralNorm( + nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=True if bn else True) + ) + self.conv1 = SpectralNorm( + nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding, bias=True if bn else True) + ) + + self.skip_proj = False + if in_channel != out_channel or upsample or downsample: + self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0)) + self.skip_proj = True + + self.upsample = upsample + self.downsample = downsample + self.activation = activation + self.bn = bn + if bn: + self.HyperBN = ConditionalBatchNorm2d(in_channel, z_dim) + self.HyperBN_1 = ConditionalBatchNorm2d(out_channel, z_dim) + + def forward(self, input, condition=None): + out = input + + if self.bn: + out = self.HyperBN(out, condition) + out = self.activation(out) + if self.upsample: + out = F.interpolate(out, scale_factor=2) + out = self.conv0(out) + if self.bn: + out = self.HyperBN_1(out, condition) + out = self.activation(out) + out = self.conv1(out) + + if self.downsample: + out = F.avg_pool2d(out, 2) + + if self.skip_proj: + skip = input + if self.upsample: + skip = F.interpolate(skip, scale_factor=2) + skip = self.conv_sc(skip) + if self.downsample: + skip = F.avg_pool2d(skip, 2) + else: + skip = input + return out + skip + + +class Generator128(nn.Module): + def __init__(self, code_dim=120, n_class=1000, chn=96, debug=False): + super().__init__() + + self.linear = nn.Linear(n_class, 128, bias=False) + + if debug: + chn = 8 + + self.first_view = 16 * chn + + self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn)) + + z_dim = code_dim + 28 + + self.GBlock = nn.ModuleList([ + GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim), + GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), + GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim), + GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim), + GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), + ]) + + self.sa_id = 4 + self.num_split = len(self.GBlock) + 1 + self.attention = SelfAttention(2 * chn) + self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4) + self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) + + def forward(self, input, class_id): + codes = torch.chunk(input, self.num_split, 1) + class_emb = self.linear(class_id) # 128 + + out = self.G_linear(codes[0]) + out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) + for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): + if i == self.sa_id: + out = self.attention(out) + condition = torch.cat([code, class_emb], 1) + out = GBlock(out, condition) + + out = self.ScaledCrossReplicaBN(out) + out = F.relu(out) + out = self.colorize(out) + return torch.tanh(out) + + +class Generator256(nn.Module): + def __init__(self, code_dim=140, 
n_class=1000, chn=96, debug=False): + super().__init__() + + self.linear = nn.Linear(n_class, 128, bias=False) + + if debug: + chn = 8 + + self.first_view = 16 * chn + + self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn)) + + self.GBlock = nn.ModuleList([ + GBlock(16 * chn, 16 * chn, n_class=n_class), + GBlock(16 * chn, 8 * chn, n_class=n_class), + GBlock(8 * chn, 8 * chn, n_class=n_class), + GBlock(8 * chn, 4 * chn, n_class=n_class), + GBlock(4 * chn, 2 * chn, n_class=n_class), + GBlock(2 * chn, 1 * chn, n_class=n_class), + ]) + + self.sa_id = 5 + self.num_split = len(self.GBlock) + 1 + self.attention = SelfAttention(2 * chn) + self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4) + self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) + + def forward(self, input, class_id): + codes = torch.chunk(input, self.num_split, 1) + class_emb = self.linear(class_id) # 128 + + out = self.G_linear(codes[0]) + out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) + for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): + if i == self.sa_id: + out = self.attention(out) + condition = torch.cat([code, class_emb], 1) + out = GBlock(out, condition) + + out = self.ScaledCrossReplicaBN(out) + out = F.relu(out) + out = self.colorize(out) + return torch.tanh(out) + + +class Generator512(nn.Module): + def __init__(self, code_dim=128, n_class=1000, chn=96, debug=False): + super().__init__() + + self.linear = nn.Linear(n_class, 128, bias=False) + + if debug: + chn = 8 + + self.first_view = 16 * chn + + self.G_linear = SpectralNorm(nn.Linear(16, 4 * 4 * 16 * chn)) + + z_dim = code_dim + 16 + + self.GBlock = nn.ModuleList([ + GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim), + GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), + GBlock(8 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), + GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim), + GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim), + GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), + GBlock(1 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), + ]) + + self.sa_id = 4 + self.num_split = len(self.GBlock) + 1 + self.attention = SelfAttention(4 * chn) + self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn) + self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) + + def forward(self, input, class_id): + codes = torch.chunk(input, self.num_split, 1) + class_emb = self.linear(class_id) # 128 + + out = self.G_linear(codes[0]) + out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) + for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): + if i == self.sa_id: + out = self.attention(out) + condition = torch.cat([code, class_emb], 1) + out = GBlock(out, condition) + + out = self.ScaledCrossReplicaBN(out) + out = F.relu(out) + out = self.colorize(out) + return torch.tanh(out) + + +class Discriminator(nn.Module): + def __init__(self, n_class=1000, chn=96, debug=False): + super().__init__() + + def conv(in_channel, out_channel, downsample=True): + return GBlock(in_channel, out_channel, bn=False, upsample=False, downsample=downsample) + + if debug: + chn = 8 + self.debug = debug + + self.pre_conv = nn.Sequential( + SpectralNorm(nn.Conv2d(3, 1 * chn, 3, padding=1)), + nn.ReLU(), + SpectralNorm(nn.Conv2d(1 * chn, 1 * chn, 3, padding=1)), + nn.AvgPool2d(2), + ) + self.pre_skip = SpectralNorm(nn.Conv2d(3, 1 * chn, 1)) + + self.conv = nn.Sequential( + conv(1 * chn, 1 * chn, downsample=True), + conv(1 * chn, 2 * chn, downsample=True), + SelfAttention(2 * chn), + conv(2 
* chn, 2 * chn, downsample=True), + conv(2 * chn, 4 * chn, downsample=True), + conv(4 * chn, 8 * chn, downsample=True), + conv(8 * chn, 8 * chn, downsample=True), + conv(8 * chn, 16 * chn, downsample=True), + conv(16 * chn, 16 * chn, downsample=False), + ) + + self.linear = SpectralNorm(nn.Linear(16 * chn, 1)) + + self.embed = nn.Embedding(n_class, 16 * chn) + self.embed.weight.data.uniform_(-0.1, 0.1) + self.embed = SpectralNorm(self.embed) + + def forward(self, input, class_id): + + out = self.pre_conv(input) + out += self.pre_skip(F.avg_pool2d(input, 2)) + out = self.conv(out) + out = F.relu(out) + out = out.view(out.size(0), out.size(1), -1) + out = out.sum(2) + out_linear = self.linear(out).squeeze(1) + embed = self.embed(class_id) + + prod = (out * embed).sum(1) + + return out_linear + prod diff --git a/text2image/BigGAN_utils/TFHub/converter.py b/text2image/BigGAN_utils/TFHub/converter.py new file mode 100644 index 0000000..a0aa322 --- /dev/null +++ b/text2image/BigGAN_utils/TFHub/converter.py @@ -0,0 +1,396 @@ +"""Utilities for converting TFHub BigGAN generator weights to PyTorch. +Recommended usage: +To convert all BigGAN variants and generate test samples, use: +```bash +CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples +``` +See `parse_args` for additional options. +""" + +import argparse +import os +import sys + +import h5py +import torch +import torch.nn as nn +from torchvision.utils import save_image +import tensorflow as tf +import tensorflow_hub as hub +import parse + +# import reference biggan from this folder +import biggan_v1 as biggan_for_conversion + +# Import model from main folder +sys.path.append('..') +import BigGAN + + + + +DEVICE = 'cuda' +HDF5_TMPL = 'biggan-{}.h5' +PTH_TMPL = 'biggan-{}.pth' +MODULE_PATH_TMPL = 'https://tfhub.dev/deepmind/biggan-{}/2' +Z_DIMS = { + 128: 120, + 256: 140, + 512: 128} +RESOLUTIONS = list(Z_DIMS) + + +def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False): + """Loads TFHub weights and saves them to intermediate HDF5 file. + Args: + module_path ([Path-like]): Path to TFHub module. + hdf5_path ([Path-like]): Path to output HDF5 file. + Returns: + [h5py.File]: Loaded hdf5 file containing module weights. 
+ """ + if os.path.exists(hdf5_path) and (not redownload): + print('Loading BigGAN hdf5 file from:', hdf5_path) + return h5py.File(hdf5_path, 'r') + + print('Loading BigGAN module from:', module_path) + tf.reset_default_graph() + hub.Module(module_path) + print('Loaded BigGAN module from:', module_path) + + initializer = tf.global_variables_initializer() + sess = tf.Session() + sess.run(initializer) + + print('Saving BigGAN weights to :', hdf5_path) + h5f = h5py.File(hdf5_path, 'w') + for var in tf.global_variables(): + val = sess.run(var) + h5f.create_dataset(var.name, data=val) + print(f'Saving {var.name} with shape {val.shape}') + h5f.close() + return h5py.File(hdf5_path, 'r') + + +class TFHub2Pytorch(object): + + TF_ROOT = 'module' + + NUM_GBLOCK = { + 128: 5, + 256: 6, + 512: 7 + } + + w = 'w' + b = 'b' + u = 'u0' + v = 'u1' + gamma = 'gamma' + beta = 'beta' + + def __init__(self, state_dict, tf_weights, resolution=256, load_ema=True, verbose=False): + self.state_dict = state_dict + self.tf_weights = tf_weights + self.resolution = resolution + self.verbose = verbose + if load_ema: + for name in ['w', 'b', 'gamma', 'beta']: + setattr(self, name, getattr(self, name) + '/ema_b999900') + + def load(self): + self.load_generator() + return self.state_dict + + def load_generator(self): + GENERATOR_ROOT = os.path.join(self.TF_ROOT, 'Generator') + + for i in range(self.NUM_GBLOCK[self.resolution]): + name_tf = os.path.join(GENERATOR_ROOT, 'GBlock') + name_tf += f'_{i}' if i != 0 else '' + self.load_GBlock(f'GBlock.{i}.', name_tf) + + self.load_attention('attention.', os.path.join(GENERATOR_ROOT, 'attention')) + self.load_linear('linear', os.path.join(self.TF_ROOT, 'linear'), bias=False) + self.load_snlinear('G_linear', os.path.join(GENERATOR_ROOT, 'G_Z', 'G_linear')) + self.load_colorize('colorize', os.path.join(GENERATOR_ROOT, 'conv_2d')) + self.load_ScaledCrossReplicaBNs('ScaledCrossReplicaBN', + os.path.join(GENERATOR_ROOT, 'ScaledCrossReplicaBN')) + + def load_linear(self, name_pth, name_tf, bias=True): + self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0) + if bias: + self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b) + + def load_snlinear(self, name_pth, name_tf, bias=True): + self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() + self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() + self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0) + if bias: + self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b) + + def load_colorize(self, name_pth, name_tf): + self.load_snconv(name_pth, name_tf) + + def load_GBlock(self, name_pth, name_tf): + self.load_convs(name_pth, name_tf) + self.load_HyperBNs(name_pth, name_tf) + + def load_convs(self, name_pth, name_tf): + self.load_snconv(name_pth + 'conv0', os.path.join(name_tf, 'conv0')) + self.load_snconv(name_pth + 'conv1', os.path.join(name_tf, 'conv1')) + self.load_snconv(name_pth + 'conv_sc', os.path.join(name_tf, 'conv_sc')) + + def load_snconv(self, name_pth, name_tf, bias=True): + if self.verbose: + print(f'loading: {name_pth} from {name_tf}') + self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() + self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() + self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, 
self.w).permute(3, 2, 0, 1) + if bias: + self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b).squeeze() + + def load_conv(self, name_pth, name_tf, bias=True): + + self.state_dict[name_pth + '.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() + self.state_dict[name_pth + '.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() + self.state_dict[name_pth + '.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1) + if bias: + self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b) + + def load_HyperBNs(self, name_pth, name_tf): + self.load_HyperBN(name_pth + 'HyperBN', os.path.join(name_tf, 'HyperBN')) + self.load_HyperBN(name_pth + 'HyperBN_1', os.path.join(name_tf, 'HyperBN_1')) + + def load_ScaledCrossReplicaBNs(self, name_pth, name_tf): + self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.beta).squeeze() + self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.gamma).squeeze() + self.state_dict[name_pth + '.running_mean'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_mean') + self.state_dict[name_pth + '.running_var'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_var') + self.state_dict[name_pth + '.num_batches_tracked'] = torch.tensor( + self.tf_weights[os.path.join(name_tf + 'bn', 'accumulation_counter:0')][()], dtype=torch.float32) + + def load_HyperBN(self, name_pth, name_tf): + if self.verbose: + print(f'loading: {name_pth} from {name_tf}') + beta = name_pth + '.beta_embed.module' + gamma = name_pth + '.gamma_embed.module' + self.state_dict[beta + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.u).squeeze() + self.state_dict[gamma + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.u).squeeze() + self.state_dict[beta + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.v).squeeze() + self.state_dict[gamma + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.v).squeeze() + self.state_dict[beta + '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.w).permute(1, 0) + self.state_dict[gamma + + '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.w).permute(1, 0) + + cr_bn_name = name_tf.replace('HyperBN', 'CrossReplicaBN') + self.state_dict[name_pth + '.bn.running_mean'] = self.load_tf_tensor(cr_bn_name, 'accumulated_mean') + self.state_dict[name_pth + '.bn.running_var'] = self.load_tf_tensor(cr_bn_name, 'accumulated_var') + self.state_dict[name_pth + '.bn.num_batches_tracked'] = torch.tensor( + self.tf_weights[os.path.join(cr_bn_name, 'accumulation_counter:0')][()], dtype=torch.float32) + + def load_attention(self, name_pth, name_tf): + + self.load_snconv(name_pth + 'theta', os.path.join(name_tf, 'theta'), bias=False) + self.load_snconv(name_pth + 'phi', os.path.join(name_tf, 'phi'), bias=False) + self.load_snconv(name_pth + 'g', os.path.join(name_tf, 'g'), bias=False) + self.load_snconv(name_pth + 'o_conv', os.path.join(name_tf, 'o_conv'), bias=False) + self.state_dict[name_pth + 'gamma'] = self.load_tf_tensor(name_tf, self.gamma) + + def load_tf_tensor(self, prefix, var, device='0'): + name = os.path.join(prefix, var) + f':{device}' + return torch.from_numpy(self.tf_weights[name][:]) + +# Convert from v1: This function maps +def convert_from_v1(hub_dict, resolution=128): + weightname_dict = {'weight_u': 'u0', 'weight_bar': 'weight', 'bias': 'bias'} + convnum_dict = {'conv0': 'conv1', 'conv1': 'conv2', 'conv_sc': 'conv_sc'} + 
attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution] + hub2me = {'linear.weight': 'shared.weight', # This is actually the shared weight + # Linear stuff + 'G_linear.module.weight_bar': 'linear.weight', + 'G_linear.module.bias': 'linear.bias', + 'G_linear.module.weight_u': 'linear.u0', + # output layer stuff + 'ScaledCrossReplicaBN.weight': 'output_layer.0.gain', + 'ScaledCrossReplicaBN.bias': 'output_layer.0.bias', + 'ScaledCrossReplicaBN.running_mean': 'output_layer.0.stored_mean', + 'ScaledCrossReplicaBN.running_var': 'output_layer.0.stored_var', + 'colorize.module.weight_bar': 'output_layer.2.weight', + 'colorize.module.bias': 'output_layer.2.bias', + 'colorize.module.weight_u': 'output_layer.2.u0', + # Attention stuff + 'attention.gamma': 'blocks.%d.1.gamma' % attention_blocknum, + 'attention.theta.module.weight_u': 'blocks.%d.1.theta.u0' % attention_blocknum, + 'attention.theta.module.weight_bar': 'blocks.%d.1.theta.weight' % attention_blocknum, + 'attention.phi.module.weight_u': 'blocks.%d.1.phi.u0' % attention_blocknum, + 'attention.phi.module.weight_bar': 'blocks.%d.1.phi.weight' % attention_blocknum, + 'attention.g.module.weight_u': 'blocks.%d.1.g.u0' % attention_blocknum, + 'attention.g.module.weight_bar': 'blocks.%d.1.g.weight' % attention_blocknum, + 'attention.o_conv.module.weight_u': 'blocks.%d.1.o.u0' % attention_blocknum, + 'attention.o_conv.module.weight_bar':'blocks.%d.1.o.weight' % attention_blocknum, + } + + # Loop over the hub dict and build the hub2me map + for name in hub_dict.keys(): + if 'GBlock' in name: + if 'HyperBN' not in name: # it's a conv + out = parse.parse('GBlock.{:d}.{}.module.{}',name) + blocknum, convnum, weightname = out + if weightname not in weightname_dict: + continue # else hyperBN in + out_name = 'blocks.%d.0.%s.%s' % (blocknum, convnum_dict[convnum], weightname_dict[weightname]) # Increment conv number by 1 + else: # hyperbn not conv + BNnum = 2 if 'HyperBN_1' in name else 1 + if 'embed' in name: + out = parse.parse('GBlock.{:d}.{}.module.{}',name) + blocknum, gamma_or_beta, weightname = out + if weightname not in weightname_dict: # Ignore weight_v + continue + out_name = 'blocks.%d.0.bn%d.%s.%s' % (blocknum, BNnum, 'gain' if 'gamma' in gamma_or_beta else 'bias', weightname_dict[weightname]) + else: + out = parse.parse('GBlock.{:d}.{}.bn.{}',name) + blocknum, dummy, mean_or_var = out + if 'num_batches_tracked' in mean_or_var: + continue + out_name = 'blocks.%d.0.bn%d.%s' % (blocknum, BNnum, 'stored_mean' if 'mean' in mean_or_var else 'stored_var') + hub2me[name] = out_name + + + # Invert the hub2me map + me2hub = {hub2me[item]: item for item in hub2me} + new_dict = {} + dimz_dict = {128: 20, 256: 20, 512:16} + for item in me2hub: + # Swap input dim ordering on batchnorm bois to account for my arbitrary change of ordering when concatenating Ys and Zs + if ('bn' in item and 'weight' in item) and ('gain' in item or 'bias' in item) and ('output_layer' not in item): + new_dict[item] = torch.cat([hub_dict[me2hub[item]][:, -128:], hub_dict[me2hub[item]][:, :dimz_dict[resolution]]], 1) + # Reshape the first linear weight, bias, and u0 + elif item == 'linear.weight': + new_dict[item] = hub_dict[me2hub[item]].contiguous().view(4, 4, 96 * 16, -1).permute(2,0,1,3).contiguous().view(-1,dimz_dict[resolution]) + elif item == 'linear.bias': + new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 16).permute(2,0,1).contiguous().view(-1) + elif item == 'linear.u0': + new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 
16).permute(2,0,1).contiguous().view(1, -1) + elif me2hub[item] == 'linear.weight': # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER + # Transpose shared weight so that it's an embedding + new_dict[item] = hub_dict[me2hub[item]].t() + elif 'weight_u' in me2hub[item]: # Unsqueeze u0s + new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0) + else: + new_dict[item] = hub_dict[me2hub[item]] + return new_dict + +def get_config(resolution): + attn_dict = {128: '64', 256: '128', 512: '64'} + dim_z_dict = {128: 120, 256: 140, 512: 128} + config = {'G_param': 'SN', 'D_param': 'SN', + 'G_ch': 96, 'D_ch': 96, + 'D_wide': True, 'G_shared': True, + 'shared_dim': 128, 'dim_z': dim_z_dict[resolution], + 'hier': True, 'cross_replica': False, + 'mybn': False, 'G_activation': nn.ReLU(inplace=True), + 'G_attn': attn_dict[resolution], + 'norm_style': 'bn', + 'G_init': 'ortho', 'skip_init': True, 'no_optim': True, + 'G_fp16': False, 'G_mixed_precision': False, + 'accumulate_stats': False, 'num_standing_accumulations': 16, + 'G_eval_mode': True, + 'BN_eps': 1e-04, 'SN_eps': 1e-04, + 'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution, + 'n_classes': 1000} + return config + + +def convert_biggan(resolution, weight_dir, redownload=False, no_ema=False, verbose=False): + module_path = MODULE_PATH_TMPL.format(resolution) + hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution)) + pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution)) + + tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload) + G_temp = getattr(biggan_for_conversion, f'Generator{resolution}')() + state_dict_temp = G_temp.state_dict() + + converter = TFHub2Pytorch(state_dict_temp, tf_weights, resolution=resolution, + load_ema=(not no_ema), verbose=verbose) + state_dict_v1 = converter.load() + state_dict = convert_from_v1(state_dict_v1, resolution) + # Get the config, build the model + config = get_config(resolution) + G = BigGAN.Generator(**config) + G.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries + torch.save(state_dict, pth_path) + + # output_location ='pretrained_weights/TFHub-PyTorch-128.pth' + + return G + + +def generate_sample(G, z_dim, batch_size, filename, parallel=False): + + G.eval() + G.to(DEVICE) + with torch.no_grad(): + z = torch.randn(batch_size, G.dim_z).to(DEVICE) + y = torch.randint(low=0, high=1000, size=(batch_size,), + device=DEVICE, dtype=torch.int64, requires_grad=False) + if parallel: + images = nn.parallel.data_parallel(G, (z, G.shared(y))) + else: + images = G(z, G.shared(y)) + save_image(images, filename, scale_each=True, normalize=True) + +def parse_args(): + usage = 'Parser for conversion script.' + parser = argparse.ArgumentParser(description=usage) + parser.add_argument( + '--resolution', '-r', type=int, default=None, choices=[128, 256, 512], + help='Resolution of TFHub module to convert. 
Converts all resolutions if None.') + parser.add_argument( + '--redownload', action='store_true', default=False, + help='Redownload weights and overwrite current hdf5 file, if present.') + parser.add_argument( + '--weights_dir', type=str, default='pretrained_weights') + parser.add_argument( + '--samples_dir', type=str, default='pretrained_samples') + parser.add_argument( + '--no_ema', action='store_true', default=False, + help='Do not load ema weights.') + parser.add_argument( + '--verbose', action='store_true', default=False, + help='Additionally logging.') + parser.add_argument( + '--generate_samples', action='store_true', default=False, + help='Generate test sample with pretrained model.') + parser.add_argument( + '--batch_size', type=int, default=64, + help='Batch size used for test sample.') + parser.add_argument( + '--parallel', action='store_true', default=False, + help='Parallelize G?') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + + args = parse_args() + os.makedirs(args.weights_dir, exist_ok=True) + os.makedirs(args.samples_dir, exist_ok=True) + + if args.resolution is not None: + G = convert_biggan(args.resolution, args.weights_dir, + redownload=args.redownload, + no_ema=args.no_ema, verbose=args.verbose) + if args.generate_samples: + filename = os.path.join(args.samples_dir, f'biggan{args.resolution}_samples.jpg') + print('Generating samples...') + generate_sample(G, Z_DIMS[args.resolution], args.batch_size, filename, args.parallel) + else: + for res in RESOLUTIONS: + G = convert_biggan(res, args.weights_dir, + redownload=args.redownload, + no_ema=args.no_ema, verbose=args.verbose) + if args.generate_samples: + filename = os.path.join(args.samples_dir, f'biggan{res}_samples.jpg') + print('Generating samples...') + generate_sample(G, Z_DIMS[res], args.batch_size, filename, args.parallel) diff --git a/text2image/BigGAN_utils/__init__.py b/text2image/BigGAN_utils/__init__.py new file mode 100644 index 0000000..493d56b --- /dev/null +++ b/text2image/BigGAN_utils/__init__.py @@ -0,0 +1,2 @@ +import sys +sys.path.append('./BigGAN_utils/') diff --git a/text2image/BigGAN_utils/__pycache__/BigGAN.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/BigGAN.cpython-37.pyc new file mode 100644 index 0000000..4a6ab4b Binary files /dev/null and b/text2image/BigGAN_utils/__pycache__/BigGAN.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/__pycache__/__init__.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..8543bbb Binary files /dev/null and b/text2image/BigGAN_utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/__pycache__/animal_hash.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/animal_hash.cpython-37.pyc new file mode 100644 index 0000000..6669dd8 Binary files /dev/null and b/text2image/BigGAN_utils/__pycache__/animal_hash.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/__pycache__/datasets.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/datasets.cpython-37.pyc new file mode 100644 index 0000000..7bf77c6 Binary files /dev/null and b/text2image/BigGAN_utils/__pycache__/datasets.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/__pycache__/layers.cpython-37.pyc b/text2image/BigGAN_utils/__pycache__/layers.cpython-37.pyc new file mode 100644 index 0000000..9f6bf9a Binary files /dev/null and b/text2image/BigGAN_utils/__pycache__/layers.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/__pycache__/utils.cpython-37.pyc 
b/text2image/BigGAN_utils/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000..ef37c33 Binary files /dev/null and b/text2image/BigGAN_utils/__pycache__/utils.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/animal_hash.py b/text2image/BigGAN_utils/animal_hash.py new file mode 100644 index 0000000..66c2be2 --- /dev/null +++ b/text2image/BigGAN_utils/animal_hash.py @@ -0,0 +1,439 @@ +c = ['Aardvark', 'Abyssinian', 'Affenpinscher', 'Akbash', 'Akita', 'Albatross', + 'Alligator', 'Alpaca', 'Angelfish', 'Ant', 'Anteater', 'Antelope', 'Ape', + 'Armadillo', 'Ass', 'Avocet', 'Axolotl', 'Baboon', 'Badger', 'Balinese', + 'Bandicoot', 'Barb', 'Barnacle', 'Barracuda', 'Bat', 'Beagle', 'Bear', + 'Beaver', 'Bee', 'Beetle', 'Binturong', 'Bird', 'Birman', 'Bison', + 'Bloodhound', 'Boar', 'Bobcat', 'Bombay', 'Bongo', 'Bonobo', 'Booby', + 'Budgerigar', 'Buffalo', 'Bulldog', 'Bullfrog', 'Burmese', 'Butterfly', + 'Caiman', 'Camel', 'Capybara', 'Caracal', 'Caribou', 'Cassowary', 'Cat', + 'Caterpillar', 'Catfish', 'Cattle', 'Centipede', 'Chameleon', 'Chamois', + 'Cheetah', 'Chicken', 'Chihuahua', 'Chimpanzee', 'Chinchilla', 'Chinook', + 'Chipmunk', 'Chough', 'Cichlid', 'Clam', 'Coati', 'Cobra', 'Cockroach', + 'Cod', 'Collie', 'Coral', 'Cormorant', 'Cougar', 'Cow', 'Coyote', + 'Crab', 'Crane', 'Crocodile', 'Crow', 'Curlew', 'Cuscus', 'Cuttlefish', + 'Dachshund', 'Dalmatian', 'Deer', 'Dhole', 'Dingo', 'Dinosaur', 'Discus', + 'Dodo', 'Dog', 'Dogball', 'Dogfish', 'Dolphin', 'Donkey', 'Dormouse', + 'Dove', 'Dragonfly', 'Drever', 'Duck', 'Dugong', 'Dunker', 'Dunlin', + 'Eagle', 'Earwig', 'Echidna', 'Eel', 'Eland', 'Elephant', 'ElephantSeal', + 'Elk', 'Emu', 'Falcon', 'Ferret', 'Finch', 'Fish', 'Flamingo', 'Flounder', + 'Fly', 'Fossa', 'Fox', 'Frigatebird', 'Frog', 'Galago', 'Gar', 'Gaur', + 'Gazelle', 'Gecko', 'Gerbil', 'Gharial', 'GiantPanda', 'Gibbon', 'Giraffe', + 'Gnat', 'Gnu', 'Goat', 'Goldfinch', 'Goldfish', 'Goose', 'Gopher', + 'Gorilla', 'Goshawk', 'Grasshopper', 'Greyhound', 'Grouse', 'Guanaco', + 'GuineaFowl', 'GuineaPig', 'Gull', 'Guppy', 'Hamster', 'Hare', 'Harrier', + 'Havanese', 'Hawk', 'Hedgehog', 'Heron', 'Herring', 'Himalayan', + 'Hippopotamus', 'Hornet', 'Horse', 'Human', 'Hummingbird', 'Hyena', + 'Ibis', 'Iguana', 'Impala', 'Indri', 'Insect', 'Jackal', 'Jaguar', + 'Javanese', 'Jay', 'Jellyfish', 'Kakapo', 'Kangaroo', 'Kingfisher', + 'Kiwi', 'Koala', 'KomodoDragon', 'Kouprey', 'Kudu', 'Labradoodle', + 'Ladybird', 'Lapwing', 'Lark', 'Lemming', 'Lemur', 'Leopard', 'Liger', + 'Lion', 'Lionfish', 'Lizard', 'Llama', 'Lobster', 'Locust', 'Loris', + 'Louse', 'Lynx', 'Lyrebird', 'Macaw', 'Magpie', 'Mallard', 'Maltese', + 'Manatee', 'Mandrill', 'Markhor', 'Marten', 'Mastiff', 'Mayfly', 'Meerkat', + 'Millipede', 'Mink', 'Mole', 'Molly', 'Mongoose', 'Mongrel', 'Monkey', + 'Moorhen', 'Moose', 'Mosquito', 'Moth', 'Mouse', 'Mule', 'Narwhal', + 'Neanderthal', 'Newfoundland', 'Newt', 'Nightingale', 'Numbat', 'Ocelot', + 'Octopus', 'Okapi', 'Olm', 'Opossum', 'Orang-utan', 'Oryx', 'Ostrich', + 'Otter', 'Owl', 'Ox', 'Oyster', 'Pademelon', 'Panther', 'Parrot', + 'Partridge', 'Peacock', 'Peafowl', 'Pekingese', 'Pelican', 'Penguin', + 'Persian', 'Pheasant', 'Pig', 'Pigeon', 'Pika', 'Pike', 'Piranha', + 'Platypus', 'Pointer', 'Pony', 'Poodle', 'Porcupine', 'Porpoise', + 'Possum', 'PrairieDog', 'Prawn', 'Puffin', 'Pug', 'Puma', 'Quail', + 'Quelea', 'Quetzal', 'Quokka', 'Quoll', 'Rabbit', 'Raccoon', 'Ragdoll', + 'Rail', 'Ram', 'Rat', 'Rattlesnake', 'Raven', 'RedDeer', 'RedPanda', + 'Reindeer', 
'Rhinoceros', 'Robin', 'Rook', 'Rottweiler', 'Ruff', + 'Salamander', 'Salmon', 'SandDollar', 'Sandpiper', 'Saola', + 'Sardine', 'Scorpion', 'SeaLion', 'SeaUrchin', 'Seahorse', + 'Seal', 'Serval', 'Shark', 'Sheep', 'Shrew', 'Shrimp', 'Siamese', + 'Siberian', 'Skunk', 'Sloth', 'Snail', 'Snake', 'Snowshoe', 'Somali', + 'Sparrow', 'Spider', 'Sponge', 'Squid', 'Squirrel', 'Starfish', 'Starling', + 'Stingray', 'Stinkbug', 'Stoat', 'Stork', 'Swallow', 'Swan', 'Tang', + 'Tapir', 'Tarsier', 'Termite', 'Tetra', 'Tiffany', 'Tiger', 'Toad', + 'Tortoise', 'Toucan', 'Tropicbird', 'Trout', 'Tuatara', 'Turkey', + 'Turtle', 'Uakari', 'Uguisu', 'Umbrellabird', 'Viper', 'Vulture', + 'Wallaby', 'Walrus', 'Warthog', 'Wasp', 'WaterBuffalo', 'Weasel', + 'Whale', 'Whippet', 'Wildebeest', 'Wolf', 'Wolverine', 'Wombat', + 'Woodcock', 'Woodlouse', 'Woodpecker', 'Worm', 'Wrasse', 'Wren', + 'Yak', 'Zebra', 'Zebu', 'Zonkey'] +a = ['able', 'above', 'absent', 'absolute', 'abstract', 'abundant', 'academic', + 'acceptable', 'accepted', 'accessible', 'accurate', 'accused', 'active', + 'actual', 'acute', 'added', 'additional', 'adequate', 'adjacent', + 'administrative', 'adorable', 'advanced', 'adverse', 'advisory', + 'aesthetic', 'afraid', 'african', 'aggregate', 'aggressive', 'agreeable', + 'agreed', 'agricultural', 'alert', 'alive', 'alleged', 'allied', 'alone', + 'alright', 'alternative', 'amateur', 'amazing', 'ambitious', 'american', + 'amused', 'ancient', 'angry', 'annoyed', 'annual', 'anonymous', 'anxious', + 'appalling', 'apparent', 'applicable', 'appropriate', 'arab', 'arbitrary', + 'architectural', 'armed', 'arrogant', 'artificial', 'artistic', 'ashamed', + 'asian', 'asleep', 'assistant', 'associated', 'atomic', 'attractive', + 'australian', 'automatic', 'autonomous', 'available', 'average', + 'awake', 'aware', 'awful', 'awkward', 'back', 'bad', 'balanced', 'bare', + 'basic', 'beautiful', 'beneficial', 'better', 'bewildered', 'big', + 'binding', 'biological', 'bitter', 'bizarre', 'black', 'blank', 'blind', + 'blonde', 'bloody', 'blue', 'blushing', 'boiling', 'bold', 'bored', + 'boring', 'bottom', 'brainy', 'brave', 'breakable', 'breezy', 'brief', + 'bright', 'brilliant', 'british', 'broad', 'broken', 'brown', 'bumpy', + 'burning', 'busy', 'calm', 'canadian', 'capable', 'capitalist', 'careful', + 'casual', 'catholic', 'causal', 'cautious', 'central', 'certain', + 'changing', 'characteristic', 'charming', 'cheap', 'cheerful', 'chemical', + 'chief', 'chilly', 'chinese', 'chosen', 'christian', 'chronic', 'chubby', + 'circular', 'civic', 'civil', 'civilian', 'classic', 'classical', 'clean', + 'clear', 'clever', 'clinical', 'close', 'closed', 'cloudy', 'clumsy', + 'coastal', 'cognitive', 'coherent', 'cold', 'collective', 'colonial', + 'colorful', 'colossal', 'coloured', 'colourful', 'combative', 'combined', + 'comfortable', 'coming', 'commercial', 'common', 'communist', 'compact', + 'comparable', 'comparative', 'compatible', 'competent', 'competitive', + 'complete', 'complex', 'complicated', 'comprehensive', 'compulsory', + 'conceptual', 'concerned', 'concrete', 'condemned', 'confident', + 'confidential', 'confused', 'conscious', 'conservation', 'conservative', + 'considerable', 'consistent', 'constant', 'constitutional', + 'contemporary', 'content', 'continental', 'continued', 'continuing', + 'continuous', 'controlled', 'controversial', 'convenient', 'conventional', + 'convinced', 'convincing', 'cooing', 'cool', 'cooperative', 'corporate', + 'correct', 'corresponding', 'costly', 'courageous', 'crazy', 'creative', + 
'creepy', 'criminal', 'critical', 'crooked', 'crowded', 'crucial', + 'crude', 'cruel', 'cuddly', 'cultural', 'curious', 'curly', 'current', + 'curved', 'cute', 'daily', 'damaged', 'damp', 'dangerous', 'dark', 'dead', + 'deaf', 'deafening', 'dear', 'decent', 'decisive', 'deep', 'defeated', + 'defensive', 'defiant', 'definite', 'deliberate', 'delicate', 'delicious', + 'delighted', 'delightful', 'democratic', 'dependent', 'depressed', + 'desirable', 'desperate', 'detailed', 'determined', 'developed', + 'developing', 'devoted', 'different', 'difficult', 'digital', 'diplomatic', + 'direct', 'dirty', 'disabled', 'disappointed', 'disastrous', + 'disciplinary', 'disgusted', 'distant', 'distinct', 'distinctive', + 'distinguished', 'disturbed', 'disturbing', 'diverse', 'divine', 'dizzy', + 'domestic', 'dominant', 'double', 'doubtful', 'drab', 'dramatic', + 'dreadful', 'driving', 'drunk', 'dry', 'dual', 'due', 'dull', 'dusty', + 'dutch', 'dying', 'dynamic', 'eager', 'early', 'eastern', 'easy', + 'economic', 'educational', 'eerie', 'effective', 'efficient', + 'elaborate', 'elated', 'elderly', 'eldest', 'electoral', 'electric', + 'electrical', 'electronic', 'elegant', 'eligible', 'embarrassed', + 'embarrassing', 'emotional', 'empirical', 'empty', 'enchanting', + 'encouraging', 'endless', 'energetic', 'english', 'enormous', + 'enthusiastic', 'entire', 'entitled', 'envious', 'environmental', 'equal', + 'equivalent', 'essential', 'established', 'estimated', 'ethical', + 'ethnic', 'european', 'eventual', 'everyday', 'evident', 'evil', + 'evolutionary', 'exact', 'excellent', 'exceptional', 'excess', + 'excessive', 'excited', 'exciting', 'exclusive', 'existing', 'exotic', + 'expected', 'expensive', 'experienced', 'experimental', 'explicit', + 'extended', 'extensive', 'external', 'extra', 'extraordinary', 'extreme', + 'exuberant', 'faint', 'fair', 'faithful', 'familiar', 'famous', 'fancy', + 'fantastic', 'far', 'fascinating', 'fashionable', 'fast', 'fat', 'fatal', + 'favourable', 'favourite', 'federal', 'fellow', 'female', 'feminist', + 'few', 'fierce', 'filthy', 'final', 'financial', 'fine', 'firm', 'fiscal', + 'fit', 'fixed', 'flaky', 'flat', 'flexible', 'fluffy', 'fluttering', + 'flying', 'following', 'fond', 'foolish', 'foreign', 'formal', + 'formidable', 'forthcoming', 'fortunate', 'forward', 'fragile', + 'frail', 'frantic', 'free', 'french', 'frequent', 'fresh', 'friendly', + 'frightened', 'front', 'frozen', 'fucking', 'full', 'full-time', 'fun', + 'functional', 'fundamental', 'funny', 'furious', 'future', 'fuzzy', + 'gastric', 'gay', 'general', 'generous', 'genetic', 'gentle', 'genuine', + 'geographical', 'german', 'giant', 'gigantic', 'given', 'glad', + 'glamorous', 'gleaming', 'global', 'glorious', 'golden', 'good', + 'gorgeous', 'gothic', 'governing', 'graceful', 'gradual', 'grand', + 'grateful', 'greasy', 'great', 'greek', 'green', 'grey', 'grieving', + 'grim', 'gross', 'grotesque', 'growing', 'grubby', 'grumpy', 'guilty', + 'handicapped', 'handsome', 'happy', 'hard', 'harsh', 'head', 'healthy', + 'heavy', 'helpful', 'helpless', 'hidden', 'high', 'high-pitched', + 'hilarious', 'hissing', 'historic', 'historical', 'hollow', 'holy', + 'homeless', 'homely', 'hon', 'honest', 'horizontal', 'horrible', + 'hostile', 'hot', 'huge', 'human', 'hungry', 'hurt', 'hushed', 'husky', + 'icy', 'ideal', 'identical', 'ideological', 'ill', 'illegal', + 'imaginative', 'immediate', 'immense', 'imperial', 'implicit', + 'important', 'impossible', 'impressed', 'impressive', 'improved', + 'inadequate', 'inappropriate', 
'inc', 'inclined', 'increased', + 'increasing', 'incredible', 'independent', 'indian', 'indirect', + 'individual', 'industrial', 'inevitable', 'influential', 'informal', + 'inherent', 'initial', 'injured', 'inland', 'inner', 'innocent', + 'innovative', 'inquisitive', 'instant', 'institutional', 'insufficient', + 'intact', 'integral', 'integrated', 'intellectual', 'intelligent', + 'intense', 'intensive', 'interested', 'interesting', 'interim', + 'interior', 'intermediate', 'internal', 'international', 'intimate', + 'invisible', 'involved', 'iraqi', 'irish', 'irrelevant', 'islamic', + 'isolated', 'israeli', 'italian', 'itchy', 'japanese', 'jealous', + 'jewish', 'jittery', 'joint', 'jolly', 'joyous', 'judicial', 'juicy', + 'junior', 'just', 'keen', 'key', 'kind', 'known', 'korean', 'labour', + 'large', 'large-scale', 'late', 'latin', 'lazy', 'leading', 'left', + 'legal', 'legislative', 'legitimate', 'lengthy', 'lesser', 'level', + 'lexical', 'liable', 'liberal', 'light', 'like', 'likely', 'limited', + 'linear', 'linguistic', 'liquid', 'literary', 'little', 'live', 'lively', + 'living', 'local', 'logical', 'lonely', 'long', 'long-term', 'loose', + 'lost', 'loud', 'lovely', 'low', 'loyal', 'ltd', 'lucky', 'mad', + 'magenta', 'magic', 'magnetic', 'magnificent', 'main', 'major', 'male', + 'mammoth', 'managerial', 'managing', 'manual', 'many', 'marginal', + 'marine', 'marked', 'married', 'marvellous', 'marxist', 'mass', 'massive', + 'mathematical', 'mature', 'maximum', 'mean', 'meaningful', 'mechanical', + 'medical', 'medieval', 'melodic', 'melted', 'mental', 'mere', + 'metropolitan', 'mid', 'middle', 'middle-class', 'mighty', 'mild', + 'military', 'miniature', 'minimal', 'minimum', 'ministerial', 'minor', + 'miserable', 'misleading', 'missing', 'misty', 'mixed', 'moaning', + 'mobile', 'moderate', 'modern', 'modest', 'molecular', 'monetary', + 'monthly', 'moral', 'motionless', 'muddy', 'multiple', 'mushy', + 'musical', 'mute', 'mutual', 'mysterious', 'naked', 'narrow', 'nasty', + 'national', 'native', 'natural', 'naughty', 'naval', 'near', 'nearby', + 'neat', 'necessary', 'negative', 'neighbouring', 'nervous', 'net', + 'neutral', 'new', 'nice', 'nineteenth-century', 'noble', 'noisy', + 'normal', 'northern', 'nosy', 'notable', 'novel', 'nuclear', 'numerous', + 'nursing', 'nutritious', 'nutty', 'obedient', 'objective', 'obliged', + 'obnoxious', 'obvious', 'occasional', 'occupational', 'odd', 'official', + 'ok', 'okay', 'old', 'old-fashioned', 'olympic', 'only', 'open', + 'operational', 'opposite', 'optimistic', 'oral', 'orange', 'ordinary', + 'organic', 'organisational', 'original', 'orthodox', 'other', 'outdoor', + 'outer', 'outrageous', 'outside', 'outstanding', 'overall', 'overseas', + 'overwhelming', 'painful', 'pale', 'palestinian', 'panicky', 'parallel', + 'parental', 'parliamentary', 'part-time', 'partial', 'particular', + 'passing', 'passive', 'past', 'patient', 'payable', 'peaceful', + 'peculiar', 'perfect', 'permanent', 'persistent', 'personal', 'petite', + 'philosophical', 'physical', 'pink', 'plain', 'planned', 'plastic', + 'pleasant', 'pleased', 'poised', 'polish', 'polite', 'political', 'poor', + 'popular', 'positive', 'possible', 'post-war', 'potential', 'powerful', + 'practical', 'precious', 'precise', 'preferred', 'pregnant', + 'preliminary', 'premier', 'prepared', 'present', 'presidential', + 'pretty', 'previous', 'prickly', 'primary', 'prime', 'primitive', + 'principal', 'printed', 'prior', 'private', 'probable', 'productive', + 'professional', 'profitable', 'profound', 
'progressive', 'prominent', + 'promising', 'proper', 'proposed', 'prospective', 'protective', + 'protestant', 'proud', 'provincial', 'psychiatric', 'psychological', + 'public', 'puny', 'pure', 'purple', 'purring', 'puzzled', 'quaint', + 'qualified', 'quick', 'quickest', 'quiet', 'racial', 'radical', 'rainy', + 'random', 'rapid', 'rare', 'raspy', 'rational', 'ratty', 'raw', 'ready', + 'real', 'realistic', 'rear', 'reasonable', 'recent', 'red', 'reduced', + 'redundant', 'regional', 'registered', 'regular', 'regulatory', 'related', + 'relative', 'relaxed', 'relevant', 'reliable', 'relieved', 'religious', + 'reluctant', 'remaining', 'remarkable', 'remote', 'renewed', + 'representative', 'repulsive', 'required', 'resident', 'residential', + 'resonant', 'respectable', 'respective', 'responsible', 'resulting', + 'retail', 'retired', 'revolutionary', 'rich', 'ridiculous', 'right', + 'rigid', 'ripe', 'rising', 'rival', 'roasted', 'robust', 'rolling', + 'roman', 'romantic', 'rotten', 'rough', 'round', 'royal', 'rubber', + 'rude', 'ruling', 'running', 'rural', 'russian', 'sacred', 'sad', 'safe', + 'salty', 'satisfactory', 'satisfied', 'scared', 'scary', 'scattered', + 'scientific', 'scornful', 'scottish', 'scrawny', 'screeching', + 'secondary', 'secret', 'secure', 'select', 'selected', 'selective', + 'selfish', 'semantic', 'senior', 'sensible', 'sensitive', 'separate', + 'serious', 'severe', 'sexual', 'shaggy', 'shaky', 'shallow', 'shared', + 'sharp', 'sheer', 'shiny', 'shivering', 'shocked', 'short', 'short-term', + 'shrill', 'shy', 'sick', 'significant', 'silent', 'silky', 'silly', + 'similar', 'simple', 'single', 'skilled', 'skinny', 'sleepy', 'slight', + 'slim', 'slimy', 'slippery', 'slow', 'small', 'smart', 'smiling', + 'smoggy', 'smooth', 'so-called', 'social', 'socialist', 'soft', 'solar', + 'sole', 'solid', 'sophisticated', 'sore', 'sorry', 'sound', 'sour', + 'southern', 'soviet', 'spanish', 'spare', 'sparkling', 'spatial', + 'special', 'specific', 'specified', 'spectacular', 'spicy', 'spiritual', + 'splendid', 'spontaneous', 'sporting', 'spotless', 'spotty', 'square', + 'squealing', 'stable', 'stale', 'standard', 'static', 'statistical', + 'statutory', 'steady', 'steep', 'sticky', 'stiff', 'still', 'stingy', + 'stormy', 'straight', 'straightforward', 'strange', 'strategic', + 'strict', 'striking', 'striped', 'strong', 'structural', 'stuck', + 'stupid', 'subjective', 'subsequent', 'substantial', 'subtle', + 'successful', 'successive', 'sudden', 'sufficient', 'suitable', + 'sunny', 'super', 'superb', 'superior', 'supporting', 'supposed', + 'supreme', 'sure', 'surprised', 'surprising', 'surrounding', + 'surviving', 'suspicious', 'sweet', 'swift', 'swiss', 'symbolic', + 'sympathetic', 'systematic', 'tall', 'tame', 'tan', 'tart', + 'tasteless', 'tasty', 'technical', 'technological', 'teenage', + 'temporary', 'tender', 'tense', 'terrible', 'territorial', 'testy', + 'then', 'theoretical', 'thick', 'thin', 'thirsty', 'thorough', + 'thoughtful', 'thoughtless', 'thundering', 'tight', 'tiny', 'tired', + 'top', 'tory', 'total', 'tough', 'toxic', 'traditional', 'tragic', + 'tremendous', 'tricky', 'tropical', 'troubled', 'turkish', 'typical', + 'ugliest', 'ugly', 'ultimate', 'unable', 'unacceptable', 'unaware', + 'uncertain', 'unchanged', 'uncomfortable', 'unconscious', 'underground', + 'underlying', 'unemployed', 'uneven', 'unexpected', 'unfair', + 'unfortunate', 'unhappy', 'uniform', 'uninterested', 'unique', 'united', + 'universal', 'unknown', 'unlikely', 'unnecessary', 'unpleasant', + 'unsightly', 
'unusual', 'unwilling', 'upper', 'upset', 'uptight', + 'urban', 'urgent', 'used', 'useful', 'useless', 'usual', 'vague', + 'valid', 'valuable', 'variable', 'varied', 'various', 'varying', 'vast', + 'verbal', 'vertical', 'very', 'victorian', 'victorious', 'video-taped', + 'violent', 'visible', 'visiting', 'visual', 'vital', 'vivacious', + 'vivid', 'vocational', 'voiceless', 'voluntary', 'vulnerable', + 'wandering', 'warm', 'wasteful', 'watery', 'weak', 'wealthy', 'weary', + 'wee', 'weekly', 'weird', 'welcome', 'well', 'well-known', 'welsh', + 'western', 'wet', 'whispering', 'white', 'whole', 'wicked', 'wide', + 'wide-eyed', 'widespread', 'wild', 'willing', 'wise', 'witty', + 'wonderful', 'wooden', 'working', 'working-class', 'worldwide', + 'worried', 'worrying', 'worthwhile', 'worthy', 'written', 'wrong', + 'yellow', 'young', 'yummy', 'zany', 'zealous'] +b = ['abiding', 'accelerating', 'accepting', 'accomplishing', 'achieving', +'acquiring', 'acteding', 'activating', 'adapting', 'adding', 'addressing', +'administering', 'admiring', 'admiting', 'adopting', 'advising', 'affording', +'agreeing', 'alerting', 'alighting', 'allowing', 'altereding', 'amusing', +'analyzing', 'announcing', 'annoying', 'answering', 'anticipating', +'apologizing', 'appearing', 'applauding', 'applieding', 'appointing', + 'appraising', 'appreciating', 'approving', 'arbitrating', 'arguing', + 'arising', 'arranging', 'arresting', 'arriving', 'ascertaining', 'asking', + 'assembling', 'assessing', 'assisting', 'assuring', 'attaching', 'attacking', + 'attaining', 'attempting', 'attending', 'attracting', 'auditeding', 'avoiding', + 'awaking', 'backing', 'baking', 'balancing', 'baning', 'banging', 'baring', + 'bating', 'bathing', 'battling', 'bing', 'beaming', 'bearing', 'beating', + 'becoming', 'beging', 'begining', 'behaving', 'beholding', 'belonging', + 'bending', 'beseting', 'beting', 'biding', 'binding', 'biting', 'bleaching', + 'bleeding', 'blessing', 'blinding', 'blinking', 'bloting', 'blowing', + 'blushing', 'boasting', 'boiling', 'bolting', 'bombing', 'booking', + 'boring', 'borrowing', 'bouncing', 'bowing', 'boxing', 'braking', + 'branching', 'breaking', 'breathing', 'breeding', 'briefing', 'bringing', + 'broadcasting', 'bruising', 'brushing', 'bubbling', 'budgeting', 'building', + 'bumping', 'burning', 'bursting', 'burying', 'busting', 'buying', 'buzing', + 'calculating', 'calling', 'camping', 'caring', 'carrying', 'carving', + 'casting', 'cataloging', 'catching', 'causing', 'challenging', 'changing', + 'charging', 'charting', 'chasing', 'cheating', 'checking', 'cheering', + 'chewing', 'choking', 'choosing', 'choping', 'claiming', 'claping', + 'clarifying', 'classifying', 'cleaning', 'clearing', 'clinging', 'cliping', + 'closing', 'clothing', 'coaching', 'coiling', 'collecting', 'coloring', + 'combing', 'coming', 'commanding', 'communicating', 'comparing', 'competing', + 'compiling', 'complaining', 'completing', 'composing', 'computing', + 'conceiving', 'concentrating', 'conceptualizing', 'concerning', 'concluding', + 'conducting', 'confessing', 'confronting', 'confusing', 'connecting', + 'conserving', 'considering', 'consisting', 'consolidating', 'constructing', + 'consulting', 'containing', 'continuing', 'contracting', 'controling', + 'converting', 'coordinating', 'copying', 'correcting', 'correlating', + 'costing', 'coughing', 'counseling', 'counting', 'covering', 'cracking', + 'crashing', 'crawling', 'creating', 'creeping', 'critiquing', 'crossing', + 'crushing', 'crying', 'curing', 'curling', 'curving', 
'cuting', 'cycling', + 'daming', 'damaging', 'dancing', 'daring', 'dealing', 'decaying', 'deceiving', + 'deciding', 'decorating', 'defining', 'delaying', 'delegating', 'delighting', + 'delivering', 'demonstrating', 'depending', 'describing', 'deserting', + 'deserving', 'designing', 'destroying', 'detailing', 'detecting', + 'determining', 'developing', 'devising', 'diagnosing', 'diging', + 'directing', 'disagreing', 'disappearing', 'disapproving', 'disarming', + 'discovering', 'disliking', 'dispensing', 'displaying', 'disproving', + 'dissecting', 'distributing', 'diving', 'diverting', 'dividing', 'doing', + 'doubling', 'doubting', 'drafting', 'draging', 'draining', 'dramatizing', + 'drawing', 'dreaming', 'dressing', 'drinking', 'driping', 'driving', + 'dropping', 'drowning', 'druming', 'drying', 'dusting', 'dwelling', + 'earning', 'eating', 'editeding', 'educating', 'eliminating', + 'embarrassing', 'employing', 'emptying', 'enacteding', 'encouraging', + 'ending', 'enduring', 'enforcing', 'engineering', 'enhancing', + 'enjoying', 'enlisting', 'ensuring', 'entering', 'entertaining', + 'escaping', 'establishing', 'estimating', 'evaluating', 'examining', + 'exceeding', 'exciting', 'excusing', 'executing', 'exercising', 'exhibiting', + 'existing', 'expanding', 'expecting', 'expediting', 'experimenting', + 'explaining', 'exploding', 'expressing', 'extending', 'extracting', + 'facing', 'facilitating', 'fading', 'failing', 'fancying', 'fastening', + 'faxing', 'fearing', 'feeding', 'feeling', 'fencing', 'fetching', 'fighting', + 'filing', 'filling', 'filming', 'finalizing', 'financing', 'finding', + 'firing', 'fiting', 'fixing', 'flaping', 'flashing', 'fleing', 'flinging', + 'floating', 'flooding', 'flowing', 'flowering', 'flying', 'folding', + 'following', 'fooling', 'forbiding', 'forcing', 'forecasting', 'foregoing', + 'foreseing', 'foretelling', 'forgeting', 'forgiving', 'forming', + 'formulating', 'forsaking', 'framing', 'freezing', 'frightening', 'frying', + 'gathering', 'gazing', 'generating', 'geting', 'giving', 'glowing', 'gluing', + 'going', 'governing', 'grabing', 'graduating', 'grating', 'greasing', 'greeting', + 'grinning', 'grinding', 'griping', 'groaning', 'growing', 'guaranteeing', + 'guarding', 'guessing', 'guiding', 'hammering', 'handing', 'handling', + 'handwriting', 'hanging', 'happening', 'harassing', 'harming', 'hating', + 'haunting', 'heading', 'healing', 'heaping', 'hearing', 'heating', 'helping', + 'hiding', 'hitting', 'holding', 'hooking', 'hoping', 'hopping', 'hovering', + 'hugging', 'hmuming', 'hunting', 'hurrying', 'hurting', 'hypothesizing', + 'identifying', 'ignoring', 'illustrating', 'imagining', 'implementing', + 'impressing', 'improving', 'improvising', 'including', 'increasing', + 'inducing', 'influencing', 'informing', 'initiating', 'injecting', + 'injuring', 'inlaying', 'innovating', 'inputing', 'inspecting', + 'inspiring', 'installing', 'instituting', 'instructing', 'insuring', + 'integrating', 'intending', 'intensifying', 'interesting', + 'interfering', 'interlaying', 'interpreting', 'interrupting', + 'interviewing', 'introducing', 'inventing', 'inventorying', + 'investigating', 'inviting', 'irritating', 'itching', 'jailing', + 'jamming', 'jogging', 'joining', 'joking', 'judging', 'juggling', 'jumping', + 'justifying', 'keeping', 'kepting', 'kicking', 'killing', 'kissing', 'kneeling', + 'kniting', 'knocking', 'knotting', 'knowing', 'labeling', 'landing', 'lasting', + 'laughing', 'launching', 'laying', 'leading', 'leaning', 'leaping', 'learning', + 'leaving', 
'lecturing', 'leding', 'lending', 'leting', 'leveling', + 'licensing', 'licking', 'lying', 'lifteding', 'lighting', 'lightening', + 'liking', 'listing', 'listening', 'living', 'loading', 'locating', + 'locking', 'loging', 'longing', 'looking', 'losing', 'loving', + 'maintaining', 'making', 'maning', 'managing', 'manipulating', + 'manufacturing', 'mapping', 'marching', 'marking', 'marketing', + 'marrying', 'matching', 'mating', 'mattering', 'meaning', 'measuring', + 'meddling', 'mediating', 'meeting', 'melting', 'melting', 'memorizing', + 'mending', 'mentoring', 'milking', 'mining', 'misleading', 'missing', + 'misspelling', 'mistaking', 'misunderstanding', 'mixing', 'moaning', + 'modeling', 'modifying', 'monitoring', 'mooring', 'motivating', + 'mourning', 'moving', 'mowing', 'muddling', 'muging', 'multiplying', + 'murdering', 'nailing', 'naming', 'navigating', 'needing', 'negotiating', + 'nesting', 'noding', 'nominating', 'normalizing', 'noting', 'noticing', + 'numbering', 'obeying', 'objecting', 'observing', 'obtaining', 'occuring', + 'offending', 'offering', 'officiating', 'opening', 'operating', 'ordering', + 'organizing', 'orienteding', 'originating', 'overcoming', 'overdoing', + 'overdrawing', 'overflowing', 'overhearing', 'overtaking', 'overthrowing', + 'owing', 'owning', 'packing', 'paddling', 'painting', 'parking', 'parting', + 'participating', 'passing', 'pasting', 'pating', 'pausing', 'paying', + 'pecking', 'pedaling', 'peeling', 'peeping', 'perceiving', 'perfecting', + 'performing', 'permiting', 'persuading', 'phoning', 'photographing', + 'picking', 'piloting', 'pinching', 'pining', 'pinpointing', 'pioneering', + 'placing', 'planing', 'planting', 'playing', 'pleading', 'pleasing', + 'plugging', 'pointing', 'poking', 'polishing', 'poping', 'possessing', + 'posting', 'pouring', 'practicing', 'praiseding', 'praying', 'preaching', + 'preceding', 'predicting', 'prefering', 'preparing', 'prescribing', + 'presenting', 'preserving', 'preseting', 'presiding', 'pressing', + 'pretending', 'preventing', 'pricking', 'printing', 'processing', + 'procuring', 'producing', 'professing', 'programing', 'progressing', + 'projecting', 'promising', 'promoting', 'proofreading', 'proposing', + 'protecting', 'proving', 'providing', 'publicizing', 'pulling', 'pumping', + 'punching', 'puncturing', 'punishing', 'purchasing', 'pushing', 'puting', + 'qualifying', 'questioning', 'queuing', 'quiting', 'racing', 'radiating', + 'raining', 'raising', 'ranking', 'rating', 'reaching', 'reading', + 'realigning', 'realizing', 'reasoning', 'receiving', 'recognizing', + 'recommending', 'reconciling', 'recording', 'recruiting', 'reducing', + 'referring', 'reflecting', 'refusing', 'regreting', 'regulating', + 'rehabilitating', 'reigning', 'reinforcing', 'rejecting', 'rejoicing', + 'relating', 'relaxing', 'releasing', 'relying', 'remaining', 'remembering', + 'reminding', 'removing', 'rendering', 'reorganizing', 'repairing', + 'repeating', 'replacing', 'replying', 'reporting', 'representing', + 'reproducing', 'requesting', 'rescuing', 'researching', 'resolving', + 'responding', 'restoreding', 'restructuring', 'retiring', 'retrieving', + 'returning', 'reviewing', 'revising', 'rhyming', 'riding', 'riding', + 'ringing', 'rinsing', 'rising', 'risking', 'robing', 'rocking', 'rolling', + 'roting', 'rubing', 'ruining', 'ruling', 'runing', 'rushing', 'sacking', + 'sailing', 'satisfying', 'saving', 'sawing', 'saying', 'scaring', + 'scattering', 'scheduling', 'scolding', 'scorching', 'scraping', + 'scratching', 'screaming', 
'screwing', 'scribbling', 'scrubing', + 'sealing', 'searching', 'securing', 'seing', 'seeking', 'selecting', + 'selling', 'sending', 'sensing', 'separating', 'serving', 'servicing', + 'seting', 'settling', 'sewing', 'shading', 'shaking', 'shaping', + 'sharing', 'shaving', 'shearing', 'sheding', 'sheltering', 'shining', + 'shivering', 'shocking', 'shoing', 'shooting', 'shoping', 'showing', + 'shrinking', 'shruging', 'shuting', 'sighing', 'signing', 'signaling', + 'simplifying', 'sining', 'singing', 'sinking', 'siping', 'siting', + 'sketching', 'skiing', 'skiping', 'slaping', 'slaying', 'sleeping', + 'sliding', 'slinging', 'slinking', 'sliping', 'sliting', 'slowing', + 'smashing', 'smelling', 'smiling', 'smiting', 'smoking', 'snatching', + 'sneaking', 'sneezing', 'sniffing', 'snoring', 'snowing', 'soaking', + 'solving', 'soothing', 'soothsaying', 'sorting', 'sounding', 'sowing', + 'sparing', 'sparking', 'sparkling', 'speaking', 'specifying', 'speeding', + 'spelling', 'spending', 'spilling', 'spining', 'spiting', 'spliting', + 'spoiling', 'spoting', 'spraying', 'spreading', 'springing', 'sprouting', + 'squashing', 'squeaking', 'squealing', 'squeezing', 'staining', 'stamping', + 'standing', 'staring', 'starting', 'staying', 'stealing', 'steering', + 'stepping', 'sticking', 'stimulating', 'stinging', 'stinking', 'stirring', + 'stitching', 'stoping', 'storing', 'straping', 'streamlining', + 'strengthening', 'stretching', 'striding', 'striking', 'stringing', + 'stripping', 'striving', 'stroking', 'structuring', 'studying', + 'stuffing', 'subleting', 'subtracting', 'succeeding', 'sucking', + 'suffering', 'suggesting', 'suiting', 'summarizing', 'supervising', + 'supplying', 'supporting', 'supposing', 'surprising', 'surrounding', + 'suspecting', 'suspending', 'swearing', 'sweating', 'sweeping', 'swelling', + 'swimming', 'swinging', 'switching', 'symbolizing', 'synthesizing', + 'systemizing', 'tabulating', 'taking', 'talking', 'taming', 'taping', + 'targeting', 'tasting', 'teaching', 'tearing', 'teasing', 'telephoning', + 'telling', 'tempting', 'terrifying', 'testing', 'thanking', 'thawing', + 'thinking', 'thriving', 'throwing', 'thrusting', 'ticking', 'tickling', + 'tying', 'timing', 'tiping', 'tiring', 'touching', 'touring', 'towing', + 'tracing', 'trading', 'training', 'transcribing', 'transfering', + 'transforming', 'translating', 'transporting', 'traping', 'traveling', + 'treading', 'treating', 'trembling', 'tricking', 'triping', 'troting', + 'troubling', 'troubleshooting', 'trusting', 'trying', 'tuging', 'tumbling', + 'turning', 'tutoring', 'twisting', 'typing', 'undergoing', 'understanding', + 'undertaking', 'undressing', 'unfastening', 'unifying', 'uniting', + 'unlocking', 'unpacking', 'untidying', 'updating', 'upgrading', + 'upholding', 'upseting', 'using', 'utilizing', 'vanishing', 'verbalizing', + 'verifying', 'vexing', 'visiting', 'wailing', 'waiting', 'waking', + 'walking', 'wandering', 'wanting', 'warming', 'warning', 'washing', + 'wasting', 'watching', 'watering', 'waving', 'wearing', 'weaving', + 'wedding', 'weeping', 'weighing', 'welcoming', 'wending', 'weting', + 'whining', 'whiping', 'whirling', 'whispering', 'whistling', 'wining', + 'winding', 'winking', 'wiping', 'wishing', 'withdrawing', 'withholding', + 'withstanding', 'wobbling', 'wondering', 'working', 'worrying', 'wrapping', + 'wrecking', 'wrestling', 'wriggling', 'wringing', 'writing', 'x-raying', + 'yawning', 'yelling', 'zipping', 'zooming'] \ No newline at end of file diff --git a/text2image/BigGAN_utils/binary_utils.py 
b/text2image/BigGAN_utils/binary_utils.py new file mode 100644 index 0000000..04eb4f9 --- /dev/null +++ b/text2image/BigGAN_utils/binary_utils.py @@ -0,0 +1,14 @@ +from torch.autograd import Function +from torch.optim import SGD + + +class BinaryActivation(Function): + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return (x.sign() + 1.) / 2. + + @staticmethod + def backward(ctx, grad_output): + return grad_output.clone() diff --git a/text2image/BigGAN_utils/calculate_inception_moments.py b/text2image/BigGAN_utils/calculate_inception_moments.py new file mode 100644 index 0000000..6ff0967 --- /dev/null +++ b/text2image/BigGAN_utils/calculate_inception_moments.py @@ -0,0 +1,91 @@ +''' Calculate Inception Moments + This script iterates over the dataset and calculates the moments of the + activations of the Inception net (needed for FID), and also returns + the Inception Score of the training data. + + Note that if you don't shuffle the data, the IS of true data will be under- + estimated as it is label-ordered. By default, the data is not shuffled + so as to reduce non-determinism. ''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import utils +import inception_utils +from tqdm import tqdm, trange +from argparse import ArgumentParser + +def prepare_parser(): + usage = 'Calculate and store inception metrics.' + parser = ArgumentParser(description=usage) + parser.add_argument( + '--dataset', type=str, default='I128_hdf5', + help='Which Dataset to train on, out of I128, I256, C10, C100...' + 'Append _hdf5 to use the hdf5 version of the dataset. (default: %(default)s)') + parser.add_argument( + '--data_root', type=str, default='data', + help='Default location where data is stored (default: %(default)s)') + parser.add_argument( + '--batch_size', type=int, default=64, + help='Default overall batchsize (default: %(default)s)') + parser.add_argument( + '--parallel', action='store_true', default=False, + help='Train with multiple GPUs (default: %(default)s)') + parser.add_argument( + '--augment', action='store_true', default=False, + help='Augment with random crops and flips (default: %(default)s)') + parser.add_argument( + '--num_workers', type=int, default=8, + help='Number of dataloader workers (default: %(default)s)') + parser.add_argument( + '--shuffle', action='store_true', default=False, + help='Shuffle the data? 
(default: %(default)s)') + parser.add_argument( + '--seed', type=int, default=0, + help='Random seed to use.') + return parser + +def run(config): + # Get loader + config['drop_last'] = False + loaders = utils.get_data_loaders(**config) + + # Load inception net + net = inception_utils.load_inception_net(parallel=config['parallel']) + pool, logits, labels = [], [], [] + device = 'cuda' + for i, (x, y) in enumerate(tqdm(loaders[0])): + x = x.to(device) + with torch.no_grad(): + pool_val, logits_val = net(x) + pool += [np.asarray(pool_val.cpu())] + logits += [np.asarray(F.softmax(logits_val, 1).cpu())] + labels += [np.asarray(y.cpu())] + + pool, logits, labels = [np.concatenate(item, 0) for item in [pool, logits, labels]] + # uncomment to save pool, logits, and labels to disk + # print('Saving pool, logits, and labels to disk...') + # np.savez(config['dataset']+'_inception_activations.npz', + # {'pool': pool, 'logits': logits, 'labels': labels}) + # Calculate inception metrics and report them + print('Calculating inception metrics...') + IS_mean, IS_std = inception_utils.calculate_inception_score(logits) + print('Training data from dataset %s has IS of %5.5f +/- %5.5f' % (config['dataset'], IS_mean, IS_std)) + # Prepare mu and sigma, save to disk. Remove "hdf5" by default + # (the FID code also knows to strip "hdf5") + print('Calculating means and covariances...') + mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False) + print('Saving calculated means and covariances to disk...') + np.savez(config['dataset'].strip('_hdf5')+'_inception_moments.npz', **{'mu' : mu, 'sigma' : sigma}) + +def main(): + # parse command line + parser = prepare_parser() + config = vars(parser.parse_args()) + print(config) + run(config) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/datasets.py b/text2image/BigGAN_utils/datasets.py new file mode 100644 index 0000000..386bf3b --- /dev/null +++ b/text2image/BigGAN_utils/datasets.py @@ -0,0 +1,362 @@ +''' Datasets + This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets +''' +import os +import os.path +import sys +from PIL import Image +import numpy as np +from tqdm import tqdm, trange + +import torchvision.datasets as dset +import torchvision.transforms as transforms +from torchvision.datasets.utils import download_url, check_integrity +import torch.utils.data as data +from torch.utils.data import DataLoader + +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] + + +def is_image_file(filename): + """Checks if a file is an image. 
+ + Args: + filename (string): path to a file + + Returns: + bool: True if the filename ends with a known image extension + """ + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) + + +def find_classes(dir): + classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + + +def make_dataset(dir, class_to_idx): + images = [] + dir = os.path.expanduser(dir) + for target in tqdm(sorted(os.listdir(dir))): + d = os.path.join(dir, target) + if not os.path.isdir(d): + continue + + for root, _, fnames in sorted(os.walk(d)): + for fname in sorted(fnames): + if is_image_file(fname): + path = os.path.join(root, fname) + item = (path, class_to_idx[target]) + images.append(item) + + return images + + +def pil_loader(path): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + +def accimage_loader(path): + import accimage + try: + return accimage.Image(path) + except IOError: + # Potentially a decoding problem, fall back to PIL.Image + return pil_loader(path) + + +def default_loader(path): + from torchvision import get_image_backend + if get_image_backend() == 'accimage': + return accimage_loader(path) + else: + return pil_loader(path) + + +class ImageFolder(data.Dataset): + """A generic data loader where the images are arranged in this way: :: + + root/dogball/xxx.png + root/dogball/xxy.png + root/dogball/xxz.png + + root/cat/123.png + root/cat/nsdf3.png + root/cat/asd932_.png + + Args: + root (string): Root directory path. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + + Attributes: + classes (list): List of the class names. + class_to_idx (dict): Dict with items (class_name, class_index). + imgs (list): List of (image path, class_index) tuples + """ + + def __init__(self, root, transform=None, target_transform=None, + loader=default_loader, load_in_mem=False, + index_filename='imagenet_imgs.npz', **kwargs): + classes, class_to_idx = find_classes(root) + # Load pre-computed image directory walk + if os.path.exists(index_filename): + print('Loading pre-saved Index file %s...' % index_filename) + imgs = np.load(index_filename)['imgs'] + # If first time, walk the folder directory and save the + # results to a pre-computed file. + else: + print('Generating Index file %s...' 
% index_filename) + imgs = make_dataset(root, class_to_idx) + np.savez_compressed(index_filename, **{'imgs' : imgs}) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.classes = classes + self.class_to_idx = class_to_idx + self.transform = transform + self.target_transform = target_transform + self.loader = loader + self.load_in_mem = load_in_mem + + if self.load_in_mem: + print('Loading all images into memory...') + self.data, self.labels = [], [] + for index in tqdm(range(len(self.imgs))): + path, target = imgs[index][0], imgs[index][1] + self.data.append(self.transform(self.loader(path))) + self.labels.append(target) + + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is class_index of the target class. + """ + if self.load_in_mem: + img = self.data[index] + target = self.labels[index] + else: + path, target = self.imgs[index] + img = self.loader(str(path)) + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + # print(img.size(), target) + return img, int(target) + + def __len__(self): + return len(self.imgs) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + +''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid + having to load individual images all the time. ''' +import h5py as h5 +import torch +class ILSVRC_HDF5(data.Dataset): + def __init__(self, root, transform=None, target_transform=None, + load_in_mem=False, train=True,download=False, validate_seed=0, + val_split=0, **kwargs): # last four are dummies + + self.root = root + self.num_imgs = len(h5.File(root, 'r')['labels']) + + # self.transform = transform + self.target_transform = target_transform + + # Set the transform here + self.transform = transform + + # load the entire dataset into memory? + self.load_in_mem = load_in_mem + + # If loading into memory, do so now + if self.load_in_mem: + print('Loading %s into memory...' % root) + with h5.File(root,'r') as f: + self.data = f['imgs'][:] + self.labels = f['labels'][:] + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is class_index of the target class. 
+ """ + # If loaded the entire dataset in RAM, get image from memory + if self.load_in_mem: + img = self.data[index] + target = self.labels[index] + + # Else load it from disk + else: + with h5.File(self.root,'r') as f: + img = f['imgs'][index] + target = f['labels'][index] + + + # if self.transform is not None: + # img = self.transform(img) + # Apply my own transform + img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2 + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, int(target) + + def __len__(self): + return self.num_imgs + # return len(self.f['imgs']) + +import pickle +class CIFAR10(dset.CIFAR10): + + def __init__(self, root, train=True, + transform=None, target_transform=None, + download=True, validate_seed=0, + val_split=0, load_in_mem=True, **kwargs): + self.root = os.path.expanduser(root) + self.transform = transform + self.target_transform = target_transform + self.train = train # training set or test set + self.val_split = val_split + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + # now load the picked numpy arrays + self.data = [] + self.labels= [] + for fentry in self.train_list: + f = fentry[0] + file = os.path.join(self.root, self.base_folder, f) + fo = open(file, 'rb') + if sys.version_info[0] == 2: + entry = pickle.load(fo) + else: + entry = pickle.load(fo, encoding='latin1') + self.data.append(entry['data']) + if 'labels' in entry: + self.labels += entry['labels'] + else: + self.labels += entry['fine_labels'] + fo.close() + + self.data = np.concatenate(self.data) + # Randomly select indices for validation + if self.val_split > 0: + label_indices = [[] for _ in range(max(self.labels)+1)] + for i,l in enumerate(self.labels): + label_indices[l] += [i] + label_indices = np.asarray(label_indices) + + # randomly grab 500 elements of each class + np.random.seed(validate_seed) + self.val_indices = [] + for l_i in label_indices: + self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1) ,replace=False)]) + + if self.train=='validate': + self.data = self.data[self.val_indices] + self.labels = list(np.asarray(self.labels)[self.val_indices]) + + self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32)) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + + elif self.train: + print(np.shape(self.data)) + if self.val_split > 0: + self.data = np.delete(self.data,self.val_indices,axis=0) + self.labels = list(np.delete(np.asarray(self.labels),self.val_indices,axis=0)) + + self.data = self.data.reshape((int(50e3 * (1.-self.val_split)), 3, 32, 32)) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + else: + f = self.test_list[0][0] + file = os.path.join(self.root, self.base_folder, f) + fo = open(file, 'rb') + if sys.version_info[0] == 2: + entry = pickle.load(fo) + else: + entry = pickle.load(fo, encoding='latin1') + self.data = entry['data'] + if 'labels' in entry: + self.labels = entry['labels'] + else: + self.labels = entry['fine_labels'] + fo.close() + self.data = self.data.reshape((10000, 3, 32, 32)) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is index of the target class. 
+ """ + img, target = self.data[index], self.labels[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.data) + + +class CIFAR100(CIFAR10): + base_folder = 'cifar-100-python' + url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" + filename = "cifar-100-python.tar.gz" + tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' + train_list = [ + ['train', '16019d7e3df5f24257cddd939b257f8d'], + ] + + test_list = [ + ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], + ] diff --git a/text2image/BigGAN_utils/inception_tf13.py b/text2image/BigGAN_utils/inception_tf13.py new file mode 100644 index 0000000..f43ecac --- /dev/null +++ b/text2image/BigGAN_utils/inception_tf13.py @@ -0,0 +1,138 @@ +''' Tensorflow inception score code +Derived from https://github.com/openai/improved-gan +Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py +THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE + +To use this code, run sample.py on your model with --sample_npz, and then +pass the experiment name in the --experiment_name. +This code also saves pool3 stats to an npz file for FID calculation +''' +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import sys +import tarfile +import math +from tqdm import tqdm, trange +from argparse import ArgumentParser + +import numpy as np +from six.moves import urllib +import tensorflow as tf + +MODEL_DIR = '' +DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' +softmax = None + +def prepare_parser(): + usage = 'Parser for TF1.3- Inception Score scripts.' + parser = ArgumentParser(description=usage) + parser.add_argument( + '--experiment_name', type=str, default='', + help='Which experiment''s samples.npz file to pull and evaluate') + parser.add_argument( + '--experiment_root', type=str, default='samples', + help='Default location where samples are stored (default: %(default)s)') + parser.add_argument( + '--batch_size', type=int, default=500, + help='Default overall batchsize (default: %(default)s)') + return parser + + +def run(config): + # Inception with TF1.3 or earlier. + # Call this function with list of images. Each of elements should be a + # numpy array with values ranging from 0 to 255. 
+ def get_inception_score(images, splits=10): + assert(type(images) == list) + assert(type(images[0]) == np.ndarray) + assert(len(images[0].shape) == 3) + assert(np.max(images[0]) > 10) + assert(np.min(images[0]) >= 0.0) + inps = [] + for img in images: + img = img.astype(np.float32) + inps.append(np.expand_dims(img, 0)) + bs = config['batch_size'] + with tf.Session() as sess: + preds, pools = [], [] + n_batches = int(math.ceil(float(len(inps)) / float(bs))) + for i in trange(n_batches): + inp = inps[(i * bs):min((i + 1) * bs, len(inps))] + inp = np.concatenate(inp, 0) + pred, pool = sess.run([softmax, pool3], {'ExpandDims:0': inp}) + preds.append(pred) + pools.append(pool) + preds = np.concatenate(preds, 0) + scores = [] + for i in range(splits): + part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return np.mean(scores), np.std(scores), np.squeeze(np.concatenate(pools, 0)) + # Init inception + def _init_inception(): + global softmax, pool3 + if not os.path.exists(MODEL_DIR): + os.makedirs(MODEL_DIR) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(MODEL_DIR, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % ( + filename, float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') + tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) + with tf.gfile.FastGFile(os.path.join( + MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + _ = tf.import_graph_def(graph_def, name='') + # Works with an arbitrary minibatch size. + with tf.Session() as sess: + pool3 = sess.graph.get_tensor_by_name('pool_3:0') + ops = pool3.graph.get_operations() + for op_idx, op in enumerate(ops): + for o in op.outputs: + shape = o.get_shape() + shape = [s.value for s in shape] + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o._shape = tf.TensorShape(new_shape) + w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] + logits = tf.matmul(tf.squeeze(pool3), w) + softmax = tf.nn.softmax(logits) + + # if softmax is None: # No need to functionalize like this. 
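+  # _init_inception() below downloads and extracts the 2015-12-05 Inception
+  # snapshot into MODEL_DIR if it is not already present, imports the frozen
+  # graph, and binds the module-level `pool3` and `softmax` tensors that
+  # get_inception_score above reads.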
+ _init_inception() + + fname = '%s/%s/samples.npz' % (config['experiment_root'], config['experiment_name']) + print('loading %s ...'%fname) + ims = np.load(fname)['x'] + import time + t0 = time.time() + inc_mean, inc_std, pool_activations = get_inception_score(list(ims.swapaxes(1,2).swapaxes(2,3)), splits=10) + t1 = time.time() + print('Saving pool to numpy file for FID calculations...') + np.savez('%s/%s/TF_pool.npz' % (config['experiment_root'], config['experiment_name']), **{'pool_mean': np.mean(pool_activations,axis=0), 'pool_var': np.cov(pool_activations, rowvar=False)}) + print('Inception took %3f seconds, score of %3f +/- %3f.'%(t1-t0, inc_mean, inc_std)) +def main(): + # parse command line and run + parser = prepare_parser() + config = vars(parser.parse_args()) + print(config) + run(config) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/inception_utils.py b/text2image/BigGAN_utils/inception_utils.py new file mode 100644 index 0000000..373d3cd --- /dev/null +++ b/text2image/BigGAN_utils/inception_utils.py @@ -0,0 +1,310 @@ +''' Inception utilities + This file contains methods for calculating IS and FID, using either + the original numpy code or an accelerated fully-pytorch version that + uses a fast newton-schulz approximation for the matrix sqrt. There are also + methods for acquiring a desired number of samples from the Generator, + and parallelizing the inbuilt PyTorch inception network. + + NOTE that Inception Scores and FIDs calculated using these methods will + *not* be directly comparable to values calculated using the original TF + IS/FID code. You *must* use the TF model if you wish to report and compare + numbers. This code tends to produce IS values that are 5-10% lower than + those obtained through TF. +''' +import numpy as np +from scipy import linalg # For numpy FID +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter as P +from torchvision.models.inception import inception_v3 + + +# Module that wraps the inception network to enable use with dataparallel and +# returning pool features and logits. +class WrapInception(nn.Module): + def __init__(self, net): + super(WrapInception,self).__init__() + self.net = net + self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1), + requires_grad=False) + self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1), + requires_grad=False) + def forward(self, x): + # Normalize x + x = (x + 1.) 
/ 2.0 + x = (x - self.mean) / self.std + # Upsample if necessary + if x.shape[2] != 299 or x.shape[3] != 299: + x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) + # 299 x 299 x 3 + x = self.net.Conv2d_1a_3x3(x) + # 149 x 149 x 32 + x = self.net.Conv2d_2a_3x3(x) + # 147 x 147 x 32 + x = self.net.Conv2d_2b_3x3(x) + # 147 x 147 x 64 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # 73 x 73 x 64 + x = self.net.Conv2d_3b_1x1(x) + # 73 x 73 x 80 + x = self.net.Conv2d_4a_3x3(x) + # 71 x 71 x 192 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # 35 x 35 x 192 + x = self.net.Mixed_5b(x) + # 35 x 35 x 256 + x = self.net.Mixed_5c(x) + # 35 x 35 x 288 + x = self.net.Mixed_5d(x) + # 35 x 35 x 288 + x = self.net.Mixed_6a(x) + # 17 x 17 x 768 + x = self.net.Mixed_6b(x) + # 17 x 17 x 768 + x = self.net.Mixed_6c(x) + # 17 x 17 x 768 + x = self.net.Mixed_6d(x) + # 17 x 17 x 768 + x = self.net.Mixed_6e(x) + # 17 x 17 x 768 + # 17 x 17 x 768 + x = self.net.Mixed_7a(x) + # 8 x 8 x 1280 + x = self.net.Mixed_7b(x) + # 8 x 8 x 2048 + x = self.net.Mixed_7c(x) + # 8 x 8 x 2048 + pool = torch.mean(x.view(x.size(0), x.size(1), -1), 2) + # 1 x 1 x 2048 + logits = self.net.fc(F.dropout(pool, training=False).view(pool.size(0), -1)) + # 1000 (num_classes) + return pool, logits + + +# A pytorch implementation of cov, from Modar M. Alfadly +# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 +def torch_cov(m, rowvar=False): + '''Estimate a covariance matrix given data. + + Covariance indicates the level to which two variables vary together. + If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, + then the covariance matrix element `C_{ij}` is the covariance of + `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. + + Args: + m: A 1-D or 2-D array containing multiple variables and observations. + Each row of `m` represents a variable, and each column a single + observation of all those variables. + rowvar: If `rowvar` is True, then each row represents a + variable, with observations in the columns. Otherwise, the + relationship is transposed: each column represents a variable, + while the rows contain observations. + + Returns: + The covariance matrix of the variables. + ''' + if m.dim() > 2: + raise ValueError('m has more than 2 dimensions') + if m.dim() < 2: + m = m.view(1, -1) + if not rowvar and m.size(0) != 1: + m = m.t() + # m = m.type(torch.double) # uncomment this line if desired + fact = 1.0 / (m.size(1) - 1) + m -= torch.mean(m, dim=1, keepdim=True) + mt = m.t() # if complex: mt = m.t().conj() + return fact * m.matmul(mt).squeeze() + + +# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji +# https://github.com/msubhransu/matrix-sqrt +def sqrt_newton_schulz(A, numIters, dtype=None): + with torch.no_grad(): + if dtype is None: + dtype = A.type() + batchSize = A.shape[0] + dim = A.shape[1] + normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt() + Y = A.div(normA.view(batchSize, 1, 1).expand_as(A)); + I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) + Z = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) + for i in range(numIters): + T = 0.5*(3.0*I - Z.bmm(Y)) + Y = Y.bmm(T) + Z = T.bmm(Z) + sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A) + return sA + + +# FID calculator from TTUR--consider replacing this with GPU-accelerated cov +# calculations using torch? 
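+# Usage sketch (illustrative only): `fake_pool` is a hypothetical tensor of
+# pooled Inception features for generated samples, and `real_mu`/`real_sigma`
+# are precomputed dataset moments; none of these names are defined in this file.
+#   mu, sigma = torch.mean(fake_pool, 0), torch_cov(fake_pool, rowvar=False)
+#   fid = torch_calculate_frechet_distance(mu, sigma, real_mu, real_sigma)
+# The torch variant further below uses sqrt_newton_schulz for the matrix square
+# root, while the numpy variant relies on scipy.linalg.sqrtm.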
+def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + Taken from https://github.com/bioinf-jku/TTUR + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representive data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representive data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + print('wat') + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + return out + + +def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Pytorch implementation of the Frechet Distance. + Taken from https://github.com/bioinf-jku/TTUR + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representive data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representive data set. + Returns: + -- : The Frechet Distance. 
+ """ + + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2 + covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze() + out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) + - 2 * torch.trace(covmean)) + return out + + +# Calculate Inception Score mean + std given softmax'd logits and number of splits +def calculate_inception_score(pred, num_splits=10): + scores = [] + for index in range(num_splits): + pred_chunk = pred[index * (pred.shape[0] // num_splits): (index + 1) * (pred.shape[0] // num_splits), :] + kl_inception = pred_chunk * (np.log(pred_chunk) - np.log(np.expand_dims(np.mean(pred_chunk, 0), 0))) + kl_inception = np.mean(np.sum(kl_inception, 1)) + scores.append(np.exp(kl_inception)) + return np.mean(scores), np.std(scores) + + +# Loop and run the sampler and the net until it accumulates num_inception_images +# activations. Return the pool, the logits, and the labels (if one wants +# Inception Accuracy the labels of the generated class will be needed) +def accumulate_inception_activations(sample, net, num_inception_images=50000): + pool, logits, labels = [], [], [] + while (torch.cat(logits, 0).shape[0] if len(logits) else 0) < num_inception_images: + with torch.no_grad(): + images, labels_val = sample() + pool_val, logits_val = net(images.float()) + pool += [pool_val] + logits += [F.softmax(logits_val, 1)] + labels += [labels_val] + return torch.cat(pool, 0), torch.cat(logits, 0), torch.cat(labels, 0) + + +# Load and wrap the Inception model +def load_inception_net(parallel=False): + inception_model = inception_v3(pretrained=True, transform_input=False) + inception_model = WrapInception(inception_model.eval()).cuda() + if parallel: + print('Parallelizing Inception module...') + inception_model = nn.DataParallel(inception_model) + return inception_model + + +# This produces a function which takes in an iterator which returns a set number of samples +# and iterates until it accumulates config['num_inception_images'] images. +# The iterator can return samples with a different batch size than used in +# training, using the setting confg['inception_batchsize'] +def prepare_inception_metrics(dataset, parallel, no_fid=False): + # Load metrics; this is intentionally not in a try-except loop so that + # the script will crash here if it cannot find the Inception moments. 
+ # By default, remove the "hdf5" from dataset + dataset = dataset.strip('_hdf5') + data_mu = np.load(dataset+'_inception_moments.npz')['mu'] + data_sigma = np.load(dataset+'_inception_moments.npz')['sigma'] + # Load network + net = load_inception_net(parallel) + def get_inception_metrics(sample, num_inception_images, num_splits=10, + prints=True, use_torch=True): + if prints: + print('Gathering activations...') + pool, logits, labels = accumulate_inception_activations(sample, net, num_inception_images) + if prints: + print('Calculating Inception Score...') + IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits) + if no_fid: + FID = 9999.0 + else: + if prints: + print('Calculating means and covariances...') + if use_torch: + mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False) + else: + mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False) + if prints: + print('Covariances calculated, getting FID...') + if use_torch: + FID = torch_calculate_frechet_distance(mu, sigma, torch.tensor(data_mu).float().cuda(), torch.tensor(data_sigma).float().cuda()) + FID = float(FID.cpu().numpy()) + else: + FID = numpy_calculate_frechet_distance(mu.cpu().numpy(), sigma.cpu().numpy(), data_mu, data_sigma) + # Delete mu, sigma, pool, logits, and labels, just in case + del mu, sigma, pool, logits, labels + return IS_mean, IS_std, FID + return get_inception_metrics \ No newline at end of file diff --git a/text2image/BigGAN_utils/layers.py b/text2image/BigGAN_utils/layers.py new file mode 100644 index 0000000..55aaab1 --- /dev/null +++ b/text2image/BigGAN_utils/layers.py @@ -0,0 +1,459 @@ +''' Layers + This file contains various layers for the BigGAN models. +''' +import numpy as np +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + +from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d + + +# Projection of x onto y +def proj(x, y): + return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) + + +# Orthogonalize x wrt list of vectors ys +def gram_schmidt(x, ys): + for y in ys: + x = x - proj(x, y) + return x + + +# Apply num_itrs steps of the power method to estimate top N singular values. +def power_iteration(W, u_, update=True, eps=1e-12): + # Lists holding singular vectors and values + us, vs, svs = [], [], [] + for i, u in enumerate(u_): + # Run one step of the power iteration + with torch.no_grad(): + v = torch.matmul(u, W) + # Run Gram-Schmidt to subtract components of all other singular vectors + v = F.normalize(gram_schmidt(v, vs), eps=eps) + # Add to the list + vs += [v] + # Update the other singular vector + u = torch.matmul(v, W.t()) + # Run Gram-Schmidt to subtract components of all other singular vectors + u = F.normalize(gram_schmidt(u, us), eps=eps) + # Add to the list + us += [u] + if update: + u_[i][:] = u + # Compute this singular value and add it to the list + svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] + #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] + return svs, us, vs + + +# Convenience passthrough function +class identity(nn.Module): + def forward(self, input): + return input + + +# Spectral normalization base class +class SN(object): + def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): + # Number of power iterations per step + self.num_itrs = num_itrs + # Number of singular values + self.num_svs = num_svs + # Transposed? 
+ self.transpose = transpose + # Epsilon value for avoiding divide-by-0 + self.eps = eps + # Register a singular vector for each sv + for i in range(self.num_svs): + self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) + self.register_buffer('sv%d' % i, torch.ones(1)) + + # Singular vectors (u side) + @property + def u(self): + return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] + + # Singular values; + # note that these buffers are just for logging and are not used in training. + @property + def sv(self): + return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] + + # Compute the spectrally-normalized weight + def W_(self): + W_mat = self.weight.view(self.weight.size(0), -1) + if self.transpose: + W_mat = W_mat.t() + # Apply num_itrs power iterations + for _ in range(self.num_itrs): + svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) + # Update the svs + if self.training: + with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! + for i, sv in enumerate(svs): + self.sv[i][:] = sv + return self.weight / svs[0] + + +# 2D Conv layer with spectral norm +class SNConv2d(nn.Conv2d, SN): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) + def forward(self, x): + return F.conv2d(x, self.W_(), self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# Linear layer with spectral norm +class SNLinear(nn.Linear, SN): + def __init__(self, in_features, out_features, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Linear.__init__(self, in_features, out_features, bias) + SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) + def forward(self, x): + return F.linear(x, self.W_(), self.bias) + + +# Embedding layer with spectral norm +# We use num_embeddings as the dim instead of embedding_dim here +# for convenience sake +class SNEmbedding(nn.Embedding, SN): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, + max_norm=None, norm_type=2, scale_grad_by_freq=False, + sparse=False, _weight=None, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, + max_norm, norm_type, scale_grad_by_freq, + sparse, _weight) + SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) + def forward(self, x): + return F.embedding(x, self.W_()) + + +# A non-local block as used in SA-GAN +# Note that the implementation as described in the paper is largely incorrect; +# refer to the released code for the actual implementation. 
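+# Shape note for the block below: for input x of shape (N, ch, H, W), theta
+# projects to ch//8 channels, phi and g are 2x2 max-pooled projections to ch//8
+# and ch//2 channels, the attention map is a softmax over theta^T * phi, and the
+# output `gamma * o + x` keeps the shape of x, with gamma learnable and
+# initialised to zero.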
+class Attention(nn.Module): + def __init__(self, ch, which_conv=SNConv2d, name='attention'): + super(Attention, self).__init__() + # Channel multiplier + self.ch = ch + self.which_conv = which_conv + self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) + self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) + # Learnable gain parameter + self.gamma = P(torch.tensor(0.), requires_grad=True) + def forward(self, x, y=None): + # Apply convs + theta = self.theta(x) + phi = F.max_pool2d(self.phi(x), [2,2]) + g = F.max_pool2d(self.g(x), [2,2]) + # Perform reshapes + theta = theta.view(-1, self. ch // 8, x.shape[2] * x.shape[3]) + phi = phi.view(-1, self. ch // 8, x.shape[2] * x.shape[3] // 4) + g = g.view(-1, self. ch // 2, x.shape[2] * x.shape[3] // 4) + # Matmul and softmax to get attention maps + beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) + # Attention map times g path + o = self.o(torch.bmm(g, beta.transpose(1,2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) + return self.gamma * o + x + + +# Fused batchnorm op +def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): + # Apply scale and shift--if gain and bias are provided, fuse them here + # Prepare scale + scale = torch.rsqrt(var + eps) + # If a gain is provided, use it + if gain is not None: + scale = scale * gain + # Prepare shift + shift = mean * scale + # If bias is provided, use it + if bias is not None: + shift = shift - bias + return x * scale - shift + #return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. + + +# Manual BN +# Calculate means and variances using mean-of-squares minus mean-squared +def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): + # Cast x to float32 if necessary + float_x = x.float() + # Calculate expected value of x (m) and expected value of x**2 (m2) + # Mean of x + m = torch.mean(float_x, [0, 2, 3], keepdim=True) + # Mean of x squared + m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True) + # Calculate variance as mean of squared minus mean squared. 
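  # i.e. the shortcut identity Var[x] = E[x**2] - (E[x])**2; it is evaluated on
  # float_x (float32) above so that fp16 inputs do not lose precision.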
+ var = (m2 - m **2) + # Cast back to float 16 if necessary + var = var.type(x.type()) + m = m.type(x.type()) + # Return mean and variance for updating stored mean/var if requested + if return_mean_var: + return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() + else: + return fused_bn(x, m, var, gain, bias, eps) + + +# My batchnorm, supports standing stats +class myBN(nn.Module): + def __init__(self, num_channels, eps=1e-5, momentum=0.1): + super(myBN, self).__init__() + # momentum for updating running stats + self.momentum = momentum + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Register buffers + self.register_buffer('stored_mean', torch.zeros(num_channels)) + self.register_buffer('stored_var', torch.ones(num_channels)) + self.register_buffer('accumulation_counter', torch.zeros(1)) + # Accumulate running means and vars + self.accumulate_standing = False + + # reset standing stats + def reset_stats(self): + self.stored_mean[:] = 0 + self.stored_var[:] = 0 + self.accumulation_counter[:] = 0 + + def forward(self, x, gain, bias): + if self.training: + out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) + # If accumulating standing stats, increment them + if self.accumulate_standing: + self.stored_mean[:] = self.stored_mean + mean.data + self.stored_var[:] = self.stored_var + var.data + self.accumulation_counter += 1.0 + # If not accumulating standing stats, take running averages + else: + self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum + self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum + return out + # If not in training mode, use the stored statistics + else: + mean = self.stored_mean.view(1, -1, 1, 1) + var = self.stored_var.view(1, -1, 1, 1) + # If using standing stats, divide them by the accumulation counter + if self.accumulate_standing: + mean = mean / self.accumulation_counter + var = var / self.accumulation_counter + return fused_bn(x, mean, var, gain, bias, self.eps) + + +# Simple function to handle groupnorm norm stylization +def groupnorm(x, norm_style): + # If number of channels specified in norm_style: + if 'ch' in norm_style: + ch = int(norm_style.split('_')[-1]) + groups = max(int(x.shape[1]) // ch, 1) + # If number of groups specified in norm style + elif 'grp' in norm_style: + groups = int(norm_style.split('_')[-1]) + # If neither, default to groups = 16 + else: + groups = 16 + return F.group_norm(x, groups) + + +# Class-conditional bn +# output size is the number of channels, input size is for the linear layers +# Andy's Note: this class feels messy but I'm not really sure how to clean it up +# Suggestions welcome! (By which I mean, refactor this and make a pull request +# if you want to make this more readable/usable). +class ccbn(nn.Module): + def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1, + cross_replica=False, mybn=False, norm_style='bn',): + super(ccbn, self).__init__() + self.output_size, self.input_size = output_size, input_size + # Prepare gain and bias layers + self.gain = which_linear(input_size, output_size) + self.bias = which_linear(input_size, output_size) + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Use cross-replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # Norm style? 
+ self.norm_style = norm_style + + if self.cross_replica: + self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) + elif self.mybn: + self.bn = myBN(output_size, self.eps, self.momentum) + elif self.norm_style in ['bn', 'in']: + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + + def forward(self, x, y): + # Calculate class-conditional gains and biases + gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) + bias = self.bias(y).view(y.size(0), -1, 1, 1) + # If using my batchnorm + if self.mybn or self.cross_replica: + return self.bn(x, gain=gain, bias=bias) + # else: + else: + if self.norm_style == 'bn': + out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + elif self.norm_style == 'in': + out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + elif self.norm_style == 'gn': + out = groupnorm(x, self.normstyle) + elif self.norm_style == 'nonorm': + out = x + return out * gain + bias + def extra_repr(self): + s = 'out: {output_size}, in: {input_size},' + s +=' cross_replica={cross_replica}' + return s.format(**self.__dict__) + + +# Normal, non-class-conditional BN +class bn(nn.Module): + def __init__(self, output_size, eps=1e-5, momentum=0.1, + cross_replica=False, mybn=False): + super(bn, self).__init__() + self.output_size= output_size + # Prepare gain and bias layers + self.gain = P(torch.ones(output_size), requires_grad=True) + self.bias = P(torch.zeros(output_size), requires_grad=True) + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Use cross-replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + + if self.cross_replica: + self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) + elif mybn: + self.bn = myBN(output_size, self.eps, self.momentum) + # Register buffers if neither of the above + else: + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + def forward(self, x, y=None): + if self.cross_replica or self.mybn: + gain = self.gain.view(1,-1,1,1) + bias = self.bias.view(1,-1,1,1) + return self.bn(x, gain=gain, bias=bias) + else: + return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, + self.bias, self.training, self.momentum, self.eps) + + +# Generator blocks +# Note that this class assumes the kernel size and padding (and any other +# settings) have been selected in the main generator module and passed in +# through the which_conv arg. 
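Before the generator blocks, a brief usage sketch for the ccbn module defined above (illustrative only; the sizes are arbitrary and the snippet assumes layers.py is in scope). The conditioning vector y plays the role of the conditional info mentioned below for which_bn: in the BigGAN setup it is typically the shared class embedding, optionally concatenated with a chunk of the latent z.

    import torch
    import torch.nn as nn

    cbn = ccbn(output_size=256, input_size=148, which_linear=nn.Linear)
    x = torch.randn(4, 256, 32, 32)    # feature map to normalize
    y = torch.randn(4, 148)            # conditioning vector
    out = cbn(x, y)                    # batch-norm x, then scale by (1 + gain(y)) and shift by bias(y)
    print(out.shape)                   # torch.Size([4, 256, 32, 32])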
Similar rules apply with which_bn (the input +# size [which is actually the number of channels of the conditional info] must +# be preselected) +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv=nn.Conv2d, which_bn=bn, activation=None, + upsample=None): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.which_conv, self.which_bn = which_conv, which_bn + self.activation = activation + self.upsample = upsample + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.out_channels) + self.conv2 = self.which_conv(self.out_channels, self.out_channels) + self.learnable_sc = in_channels != out_channels or upsample + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(in_channels) + self.bn2 = self.which_bn(out_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + h = self.activation(self.bn1(x, y)) + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + h = self.conv1(h) + h = self.activation(self.bn2(h, y)) + h = self.conv2(h) + if self.learnable_sc: + x = self.conv_sc(x) + return h + x + + +# Residual block for the discriminator +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True, + preactivation=False, activation=None, downsample=None,): + super(DBlock, self).__init__() + self.in_channels, self.out_channels = in_channels, out_channels + # If using wide D (as in SA-GAN and BigGAN), change the channel pattern + self.hidden_channels = self.out_channels if wide else self.in_channels + self.which_conv = which_conv + self.preactivation = preactivation + self.activation = activation + self.downsample = downsample + + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) + self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) + self.learnable_sc = True if (in_channels != out_channels) or downsample else False + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + def shortcut(self, x): + if self.preactivation: + if self.learnable_sc: + x = self.conv_sc(x) + if self.downsample: + x = self.downsample(x) + else: + if self.downsample: + x = self.downsample(x) + if self.learnable_sc: + x = self.conv_sc(x) + return x + + def forward(self, x): + if self.preactivation: + # h = self.activation(x) # NOT TODAY SATAN + # Andy's note: This line *must* be an out-of-place ReLU or it + # will negatively affect the shortcut connection. 
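  # (Out-of-place F.relu(x) returns a new tensor; an in-place activation, e.g.
  #  F.relu_(x) or the inplace_relu passed in via `activation`, would overwrite
  #  x itself, so self.shortcut(x) below would no longer see the pre-activation
  #  input.)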
+ h = F.relu(x) + else: + h = x + h = self.conv1(h) + h = self.conv2(self.activation(h)) + if self.downsample: + h = self.downsample(h) + + return h + self.shortcut(x) + +# dogball \ No newline at end of file diff --git a/text2image/BigGAN_utils/logs/BigGAN_ch96_bs256x8.jsonl b/text2image/BigGAN_utils/logs/BigGAN_ch96_bs256x8.jsonl new file mode 100644 index 0000000..cfa578f --- /dev/null +++ b/text2image/BigGAN_utils/logs/BigGAN_ch96_bs256x8.jsonl @@ -0,0 +1,68 @@ +{"itr": 2000, "IS_mean": 2.806771755218506, "IS_std": 0.019480662420392036, "FID": 173.76484159711126, "_stamp": 1551403232.0425167} +{"itr": 4000, "IS_mean": 4.962374687194824, "IS_std": 0.07276841998100281, "FID": 113.86730514283107, "_stamp": 1551422228.743057} +{"itr": 6000, "IS_mean": 6.939817905426025, "IS_std": 0.11417163163423538, "FID": 101.63548498447199, "_stamp": 1551457139.3400874} +{"itr": 8000, "IS_mean": 8.142985343933105, "IS_std": 0.11931543797254562, "FID": 92.0014385772705, "_stamp": 1551476217.2409613} +{"itr": 10000, "IS_mean": 10.355518341064453, "IS_std": 0.09094739705324173, "FID": 83.58068997965364, "_stamp": 1551494854.2419689} +{"itr": 12000, "IS_mean": 11.288347244262695, "IS_std": 0.14952820539474487, "FID": 80.98066299357106, "_stamp": 1551513232.5049698} +{"itr": 14000, "IS_mean": 11.755794525146484, "IS_std": 0.17969024181365967, "FID": 76.80603924280956, "_stamp": 1551531425.150371} +{"itr": 18000, "IS_mean": 13.65534496307373, "IS_std": 0.11151058971881866, "FID": 65.95736694335938, "_stamp": 1551588271.9177916} +{"itr": 20000, "IS_mean": 14.817827224731445, "IS_std": 0.23588882386684418, "FID": 61.32061767578125, "_stamp": 1551606713.6567464} +{"itr": 22000, "IS_mean": 17.16551399230957, "IS_std": 0.19506946206092834, "FID": 53.387969970703125, "_stamp": 1551624876.6513028} +{"itr": 24000, "IS_mean": 19.60654067993164, "IS_std": 0.5591856837272644, "FID": 46.5386962890625, "_stamp": 1551642822.6126688} +{"itr": 26000, "IS_mean": 21.74416732788086, "IS_std": 0.2850531041622162, "FID": 41.595001220703125, "_stamp": 1551663522.6019194} +{"itr": 28000, "IS_mean": 23.923612594604492, "IS_std": 0.41587772965431213, "FID": 37.894744873046875, "_stamp": 1551681794.6567173} +{"itr": 30000, "IS_mean": 25.569377899169922, "IS_std": 0.3333457112312317, "FID": 35.49310302734375, "_stamp": 1551699773.7080302} +{"itr": 32000, "IS_mean": 26.867944717407227, "IS_std": 0.5968036651611328, "FID": 33.4849853515625, "_stamp": 1551717623.887933} +{"itr": 34000, "IS_mean": 28.719074249267578, "IS_std": 0.5698027014732361, "FID": 31.375518798828125, "_stamp": 1551735411.1578612} +{"itr": 36000, "IS_mean": 30.587574005126953, "IS_std": 0.5044271349906921, "FID": 29.432281494140625, "_stamp": 1551783380.6357439} +{"itr": 38000, "IS_mean": 32.08299255371094, "IS_std": 0.49342143535614014, "FID": 28.099456787109375, "_stamp": 1551801179.6495197} +{"itr": 40000, "IS_mean": 34.24657440185547, "IS_std": 0.7709177732467651, "FID": 26.53802490234375, "_stamp": 1551818775.171794} +{"itr": 42000, "IS_mean": 35.891212463378906, "IS_std": 0.7036871314048767, "FID": 25.03021240234375, "_stamp": 1551836329.6873965} +{"itr": 44000, "IS_mean": 38.184898376464844, "IS_std": 0.32996198534965515, "FID": 23.4940185546875, "_stamp": 1551897864.911537} +{"itr": 46000, "IS_mean": 40.239479064941406, "IS_std": 0.7761151194572449, "FID": 22.53167724609375, "_stamp": 1551915406.4840703} +{"itr": 48000, "IS_mean": 41.46656036376953, "IS_std": 1.1031498908996582, "FID": 21.5338134765625, "_stamp": 1551932899.6074848} +{"itr": 50000, "IS_mean": 
43.31670379638672, "IS_std": 0.7796809077262878, "FID": 20.53253173828125, "_stamp": 1551950390.345334} +{"itr": 52000, "IS_mean": 45.1517333984375, "IS_std": 1.2925242185592651, "FID": 19.656646728515625, "_stamp": 1551967838.1501615} +{"itr": 54000, "IS_mean": 47.638771057128906, "IS_std": 1.0689665079116821, "FID": 18.898162841796875, "_stamp": 1552044534.5349634} +{"itr": 56000, "IS_mean": 48.87520217895508, "IS_std": 1.1317559480667114, "FID": 18.1248779296875, "_stamp": 1552061763.3080354} +{"itr": 58000, "IS_mean": 49.40987014770508, "IS_std": 1.1866596937179565, "FID": 17.751922607421875, "_stamp": 1552078939.9828825} +{"itr": 60000, "IS_mean": 51.051334381103516, "IS_std": 1.2281248569488525, "FID": 17.19964599609375, "_stamp": 1552096167.889482} +{"itr": 62000, "IS_mean": 52.0235481262207, "IS_std": 0.5391153693199158, "FID": 16.62115478515625, "_stamp": 1552113417.9520617} +{"itr": 64000, "IS_mean": 53.868492126464844, "IS_std": 1.327082633972168, "FID": 16.237335205078125, "_stamp": 1552142961.09602} +{"itr": 66000, "IS_mean": 54.978721618652344, "IS_std": 0.9502049088478088, "FID": 15.81170654296875, "_stamp": 1552162403.2232807} +{"itr": 68000, "IS_mean": 55.73248291015625, "IS_std": 1.0323851108551025, "FID": 15.545623779296875, "_stamp": 1552181112.676657} +{"itr": 70000, "IS_mean": 56.78422927856445, "IS_std": 1.211003303527832, "FID": 15.28369140625, "_stamp": 1552199498.887533} +{"itr": 72000, "IS_mean": 57.972999572753906, "IS_std": 0.8668608665466309, "FID": 14.86395263671875, "_stamp": 1552217782.2738616} +{"itr": 74000, "IS_mean": 58.845054626464844, "IS_std": 1.4297977685928345, "FID": 14.620635986328125, "_stamp": 1552251085.1781816} +{"itr": 76000, "IS_mean": 59.60982131958008, "IS_std": 0.9095696210861206, "FID": 14.360198974609375, "_stamp": 1552270214.9345307} +{"itr": 78000, "IS_mean": 60.71195602416992, "IS_std": 0.960899829864502, "FID": 14.07183837890625, "_stamp": 1552288697.1580262} +{"itr": 80000, "IS_mean": 61.772125244140625, "IS_std": 0.6913255453109741, "FID": 13.781585693359375, "_stamp": 1552307170.0280282} +{"itr": 82000, "IS_mean": 62.98079299926758, "IS_std": 1.4735801219940186, "FID": 13.55389404296875, "_stamp": 1552325252.8553352} +{"itr": 84000, "IS_mean": 64.95240783691406, "IS_std": 0.9018951654434204, "FID": 13.231689453125, "_stamp": 1552344135.3111835} +{"itr": 86000, "IS_mean": 65.13968658447266, "IS_std": 0.8772205114364624, "FID": 13.176849365234375, "_stamp": 1552362429.6782444} +{"itr": 88000, "IS_mean": 65.84476470947266, "IS_std": 1.167534351348877, "FID": 12.87078857421875, "_stamp": 1552380560.7988124} +{"itr": 90000, "IS_mean": 67.41099548339844, "IS_std": 1.6899267435073853, "FID": 12.586517333984375, "_stamp": 1552398550.2060475} +{"itr": 92000, "IS_mean": 68.63685607910156, "IS_std": 1.9431978464126587, "FID": 12.49505615234375, "_stamp": 1552430781.6406457} +{"itr": 94000, "IS_mean": 70.09907531738281, "IS_std": 1.0715738534927368, "FID": 12.047607421875, "_stamp": 1552449001.1950285} +{"itr": 96000, "IS_mean": 70.34623718261719, "IS_std": 1.7962944507598877, "FID": 11.896697998046875, "_stamp": 1552466989.3587568} +{"itr": 98000, "IS_mean": 71.08210754394531, "IS_std": 1.458209753036499, "FID": 11.73046875, "_stamp": 1552484800.7138846} +{"itr": 100000, "IS_mean": 72.24256896972656, "IS_std": 1.3259714841842651, "FID": 11.7386474609375, "_stamp": 1552502538.0269725} +{"itr": 102000, "IS_mean": 73.19488525390625, "IS_std": 1.3439149856567383, "FID": 11.50494384765625, "_stamp": 1552523284.4514356} +{"itr": 104000, 
"IS_mean": 73.38243103027344, "IS_std": 1.4162707328796387, "FID": 11.374542236328125, "_stamp": 1552541012.0651608} +{"itr": 106000, "IS_mean": 74.95563507080078, "IS_std": 1.089124083518982, "FID": 11.10479736328125, "_stamp": 1552558577.7458107} +{"itr": 108000, "IS_mean": 76.42997741699219, "IS_std": 1.9282453060150146, "FID": 10.998870849609375, "_stamp": 1552576111.9480467} +{"itr": 110000, "IS_mean": 76.89225769042969, "IS_std": 1.4771150350570679, "FID": 10.847015380859375, "_stamp": 1552593659.445132} +{"itr": 112000, "IS_mean": 78.04684448242188, "IS_std": 1.4850096702575684, "FID": 10.772552490234375, "_stamp": 1552616479.5201895} +{"itr": 114000, "IS_mean": 79.67677307128906, "IS_std": 2.0147368907928467, "FID": 10.528045654296875, "_stamp": 1552633850.9315467} +{"itr": 116000, "IS_mean": 79.8828125, "IS_std": 0.978247344493866, "FID": 10.626068115234375, "_stamp": 1552651198.9012825} +{"itr": 118000, "IS_mean": 79.95381164550781, "IS_std": 1.8608143329620361, "FID": 10.46771240234375, "_stamp": 1552668560.4420238} +{"itr": 120000, "IS_mean": 82.37217712402344, "IS_std": 1.8909310102462769, "FID": 10.259033203125, "_stamp": 1552749673.4319007} +{"itr": 122000, "IS_mean": 83.49666595458984, "IS_std": 2.38446044921875, "FID": 9.996185302734375, "_stamp": 1552766698.2706933} +{"itr": 124000, "IS_mean": 83.05189514160156, "IS_std": 1.8844469785690308, "FID": 10.164398193359375, "_stamp": 1552783762.891172} +{"itr": 126000, "IS_mean": 84.27763366699219, "IS_std": 0.9329544901847839, "FID": 10.03509521484375, "_stamp": 1552800953.5724175} +{"itr": 128000, "IS_mean": 85.84852600097656, "IS_std": 2.2698562145233154, "FID": 9.91644287109375, "_stamp": 1552818112.227726} +{"itr": 130000, "IS_mean": 87.356689453125, "IS_std": 2.0958640575408936, "FID": 9.771148681640625, "_stamp": 1552837539.995247} +{"itr": 132000, "IS_mean": 88.72562408447266, "IS_std": 1.7551432847976685, "FID": 9.8258056640625, "_stamp": 1552859685.9305944} +{"itr": 134000, "IS_mean": 88.0631103515625, "IS_std": 1.8199039697647095, "FID": 9.957183837890625, "_stamp": 1552880037.5408435} +{"itr": 136000, "IS_mean": 91.50938415527344, "IS_std": 1.9926033020019531, "FID": 9.876556396484375, "_stamp": 1552899854.652669} +{"itr": 138000, "IS_mean": 93.09217834472656, "IS_std": 2.3062736988067627, "FID": 9.908477783203125, "_stamp": 1552921580.958927} \ No newline at end of file diff --git a/text2image/BigGAN_utils/logs/compare_IS.m b/text2image/BigGAN_utils/logs/compare_IS.m new file mode 100644 index 0000000..c72b079 --- /dev/null +++ b/text2image/BigGAN_utils/logs/compare_IS.m @@ -0,0 +1,89 @@ +clc +clear all +close all +fclose all; + + + +%% Get All logs and sort them +s = {}; +d = dir(); +j = 1; +for i = 1:length(d) + if any(strfind(d(i).name,'.jsonl')) + s = [s; d(i).name]; + end +end + + +j = 1; +for i = 1:length(s) + fname = s{i,1}; + % Check if the Inception metrics log exists, and if so, plot it + [itr, IS, FID, t] = process_inception_log(fname(1:end - 10), 'log.jsonl'); + s{i,2} = itr; + s{i,3} = IS; + s{i,4} = FID; + s{i,5} = max(IS); + s{i,6} = min(FID); + s{i,7} = t; +end +% Sort by Inception Score? +[IS_sorted, IS_index] = sort(cell2mat(s(:,5))); +% Cutoff inception scores below a certain value? +threshold = 22; +IS_index = IS_index(IS_sorted > threshold); + +% Sort by FID? +[FID_sorted, FID_index] = sort(cell2mat(s(:,6))); +% Cutoff also based on IS? +% threshold = 0; +FID_index = FID_index(IS_sorted > threshold); + + + +%% Plot things? 
+cc = hsv(length(IS_index)); +legend1 = {}; +legend2 = {}; +make_axis=true;%false % Turn this on to see the axis out to 1e6 iterations +for i=1:length(IS_index) + legend1 = [legend1; s{IS_index(i), 1}]; + figure(1) + plot(s{IS_index(i),2}, s{IS_index(i),3}, 'color', cc(i,:),'linewidth',2) + hold on; + xlabel('itr'); ylabel('IS'); + grid on; + if make_axis + axis([0,1e6,0,80]); % 50% grid on; + end + legend(legend1,'Interpreter','none') + %pause(1) % Turn this on to animate stuff + legend2 = [legend2; s{IS_index(i), 1}]; + figure(2) + plot(s{IS_index(i),2}, s{IS_index(i),4}, 'color', cc(i,:),'linewidth',2) + hold on; + xlabel('itr'); ylabel('FID'); + j = j + 1; + grid on; + if make_axis + axis([0,1e6,0,50]);% grid on; + end + legend(legend2, 'Interpreter','none') + +end + +%% Quick script to plot IS versus timesteps +if 0 + figure(3); + this_index=4; + subplot(2,1,1); + %plot(s{this_index, 2}(2:end), s{this_index, 7}(2:end) - s{this_index, 7}(1:end-1), 'r*'); + % xlabel('Iteration');ylabel('\Delta T') + plot(s{this_index, 2}, s{this_index, 7}, 'r*'); + xlabel('Iteration');ylabel('T') + subplot(2,1,2); + plot(s{this_index, 2}, s{this_index, 3}, 'r', 'linewidth',2); + xlabel('Iteration'), ylabel('Inception score') + title(s{this_index,1}) +end \ No newline at end of file diff --git a/text2image/BigGAN_utils/logs/metalog.txt b/text2image/BigGAN_utils/logs/metalog.txt new file mode 100644 index 0000000..9214b1b --- /dev/null +++ b/text2image/BigGAN_utils/logs/metalog.txt @@ -0,0 +1,3 @@ +datetime: 2019-03-18 13:27:59.181225 +config: {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 400, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': True, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)} +state: {'itr': 137500, 'epoch': 2, 'save_num': 0, 'save_best_num': 1, 'best_IS': 91.509384, 'best_FID': tensor(9.7711, 
'config': {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': False, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'BN_sync': False, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)}} diff --git a/text2image/BigGAN_utils/logs/process_inception_log.m b/text2image/BigGAN_utils/logs/process_inception_log.m new file mode 100644 index 0000000..42e7480 --- /dev/null +++ b/text2image/BigGAN_utils/logs/process_inception_log.m @@ -0,0 +1,19 @@ +function [itr, IS, FID, t] = process_inception_log(fname, which_log) +f = sprintf('%s_%s',fname, which_log);%'G_loss.log'); +fid = fopen(f,'r'); +itr = []; +IS = []; +FID = []; +t = []; +i = 1; +while ~feof(fid); + s = fgets(fid); + parsed = sscanf(s,'{"itr": %d, "IS_mean": %f, "IS_std": %f, "FID": %f, "_stamp": %f}'); + itr(i) = parsed(1); + IS(i) = parsed(2); + FID(i) = parsed(4); + t(i) = parsed(5); + i = i + 1; +end +fclose(fid); +end \ No newline at end of file diff --git a/text2image/BigGAN_utils/logs/process_training.m b/text2image/BigGAN_utils/logs/process_training.m new file mode 100644 index 0000000..bfc0788 --- /dev/null +++ b/text2image/BigGAN_utils/logs/process_training.m @@ -0,0 +1,109 @@ +clc +clear all +close all +fclose all; + + + +%% Get all training logs for a given run +target_dir = '.'; +s = {}; +nm = {}; +d = dir(target_dir); +j = 1; +for i = 1:length(d) + if any(strfind(d(i).name,'.log')) + s = [s; sprintf('%s\\%s', target_dir, d(i).name)]; + nm = [nm; d(i).name]; + end +end +%% Loop over training logs and acquire data +D_count = 0; +G_count = 0; +for i = 1:length(s) + fname = s{i,1}; + fid = fopen(s{i,1},'r'); + % Prepare bookkeeping for sv0 + if any(strfind(s{i,1},'sv')) + if any(strfind(s{i,1},'G_')) + G_count = G_count +1; + else + D_count = D_count + 1; + end + end + itr = []; + val = []; + j = 1; + while ~feof(fid); + line = fgets(fid); + parsed = sscanf(line, '%d: %e'); + itr(j) = parsed(1); + val(j) = 
parsed(2); + j = j + 1; + end + s{i,2} = itr; + s{i,3} = val; + fclose(fid); +end + +%% Plot SVs and losses +close all; +Gcc = hsv(G_count); +Dcc = hsv(D_count); +gi = 1; +di = 1; +li = 1; +legendG = {}; +legendD = {}; +legendL = {}; +thresh=2; % wavelet denoising threshold +losses = {}; +for i=1:length(s) + if any(strfind(s{i,1},'D_loss_real.log')) || any(strfind(s{i,1},'D_loss_fake.log')) || any(strfind(s{i,1},'G_loss.log')) + % Select colors + if any(strfind(s{i,1},'D_loss_real.log')) + color1 = [0.7,0.7,1.0]; + color2 = [0, 0, 1]; + dlr = {s{i,2}, s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2}; + losses = [losses; dlr]; + elseif any(strfind(s{i,1},'D_loss_fake.log')) + color1 = [0.7,1.0,0.7]; + color2 = [0, 1, 0]; + dlf = {s{i,2},s{i,3} wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2}; + losses = [losses; dlf]; + else % g loss + color1 = [1.0, 0.7,0.7]; + color2 = [1, 0, 0]; + gl = {s{i,2},s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1 color2}; + losses = [losses; gl]; + end + figure(1); hold on; + % Plot the unsmoothed losses; we'll plot the smoothed losses later + plot(s{i,2},s{i,3},'color', color1, 'HandleVisibility','off'); + legendL = [legendL; nm{i}]; + continue + end + if any(strfind(s{i,1},'G_')) + legendG = [legendG; nm{i}]; + figure(2); hold on; + plot(s{i,2},s{i,3},'color',Gcc(gi,:),'linewidth',2); + gi = gi+1; + elseif any(strfind(s{i,1},'D_')) + legendD = [legendD; nm{i}]; + figure(3); hold on; + plot(s{i,2},s{i,3},'color',Dcc(di,:),'linewidth',2); + di = di+1; + else + s{i,1} % Debug print to show the name of the log that was not processed. + end +end +figure(1); +% Plot the smoothed losses last +for i = 1:3 +% plot(losses{i,1}, losses{i,2},'color', losses{i,4}, 'HandleVisibility','off'); +plot(losses{i,1},losses{i,3},'color',losses{i,5}); +end +legend(legendL, 'Interpreter', 'none'); title('Losses'); xlabel('Generator itr'); ylabel('loss'); axis([0, max(s{end,2}), -1, 4]); + +figure(2); legend(legendG,'Interpreter','none'); title('Singular Values in G'); xlabel('Generator itr'); ylabel('SV0'); +figure(3); legend(legendD, 'Interpreter', 'none'); title('Singular Values in D'); xlabel('Generator itr'); ylabel('SV0'); diff --git a/text2image/BigGAN_utils/losses.py b/text2image/BigGAN_utils/losses.py new file mode 100644 index 0000000..6a7467f --- /dev/null +++ b/text2image/BigGAN_utils/losses.py @@ -0,0 +1,33 @@ +import torch +import torch.nn.functional as F + +# DCGAN loss +def loss_dcgan_dis(dis_fake, dis_real): + L1 = torch.mean(F.softplus(-dis_real)) + L2 = torch.mean(F.softplus(dis_fake)) + return L1, L2 + + +def loss_dcgan_gen(dis_fake): + loss = torch.mean(F.softplus(-dis_fake)) + return loss + + +# Hinge Loss +def loss_hinge_dis(dis_fake, dis_real): + loss_real = torch.mean(F.relu(1. - dis_real)) + loss_fake = torch.mean(F.relu(1. + dis_fake)) + return loss_real, loss_fake +# def loss_hinge_dis(dis_fake, dis_real): # This version returns a single loss + # loss = torch.mean(F.relu(1. - dis_real)) + # loss += torch.mean(F.relu(1. 
+ dis_fake)) + # return loss + + +def loss_hinge_gen(dis_fake): + loss = -torch.mean(dis_fake) + return loss + +# Default to hinge loss +generator_loss = loss_hinge_gen +discriminator_loss = loss_hinge_dis \ No newline at end of file diff --git a/text2image/BigGAN_utils/make_hdf5.py b/text2image/BigGAN_utils/make_hdf5.py new file mode 100644 index 0000000..b21d323 --- /dev/null +++ b/text2image/BigGAN_utils/make_hdf5.py @@ -0,0 +1,110 @@ +""" Convert dataset to HDF5 + This script preprocesses a dataset and saves it (images and labels) to + an HDF5 file for improved I/O. """ +import os +import sys +from argparse import ArgumentParser +from tqdm import tqdm, trange +import h5py as h5 + +import numpy as np +import torch +import torchvision.datasets as dset +import torchvision.transforms as transforms +from torchvision.utils import save_image +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +import utils + +def prepare_parser(): + usage = 'Parser for ImageNet HDF5 scripts.' + parser = ArgumentParser(description=usage) + parser.add_argument( + '--dataset', type=str, default='I128', + help='Which Dataset to train on, out of I128, I256, C10, C100;' + 'Append "_hdf5" to use the hdf5 version for ISLVRC (default: %(default)s)') + parser.add_argument( + '--data_root', type=str, default='data', + help='Default location where data is stored (default: %(default)s)') + parser.add_argument( + '--batch_size', type=int, default=256, + help='Default overall batchsize (default: %(default)s)') + parser.add_argument( + '--num_workers', type=int, default=16, + help='Number of dataloader workers (default: %(default)s)') + parser.add_argument( + '--chunk_size', type=int, default=500, + help='Default overall batchsize (default: %(default)s)') + parser.add_argument( + '--compression', action='store_true', default=False, + help='Use LZF compression? (default: %(default)s)') + return parser + + +def run(config): + if 'hdf5' in config['dataset']: + raise ValueError('Reading from an HDF5 file which you will probably be ' + 'about to overwrite! Override this error only if you know ' + 'what you''re doing!') + # Get image size + config['image_size'] = utils.imsize_dict[config['dataset']] + + # Update compression entry + config['compression'] = 'lzf' if config['compression'] else None #No compression; can also use 'lzf' + + # Get dataset + kwargs = {'num_workers': config['num_workers'], 'pin_memory': False, 'drop_last': False} + train_loader = utils.get_data_loaders(dataset=config['dataset'], + batch_size=config['batch_size'], + shuffle=False, + data_root=config['data_root'], + use_multiepoch_sampler=False, + **kwargs)[0] + + # HDF5 supports chunking and compression. You may want to experiment + # with different chunk sizes to see how it runs on your machines. + # Chunk Size/compression Read speed @ 256x256 Read speed @ 128x128 Filesize @ 128x128 Time to write @128x128 + # 1 / None 20/s + # 500 / None ramps up to 77/s 102/s 61GB 23min + # 500 / LZF 8/s 56GB 23min + # 1000 / None 78/s + # 5000 / None 81/s + # auto:(125,1,16,32) / None 11/s 61GB + + print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' 
% (config['dataset'], config['chunk_size'], config['compression'])) + # Loop over train loader + for i,(x,y) in enumerate(tqdm(train_loader)): + # Stick X into the range [0, 255] since it's coming from the train loader + x = (255 * ((x + 1) / 2.0)).byte().numpy() + # Numpyify y + y = y.numpy() + # If we're on the first batch, prepare the hdf5 + if i==0: + with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'w') as f: + print('Producing dataset of len %d' % len(train_loader.dataset)) + imgs_dset = f.create_dataset('imgs', x.shape,dtype='uint8', maxshape=(len(train_loader.dataset), 3, config['image_size'], config['image_size']), + chunks=(config['chunk_size'], 3, config['image_size'], config['image_size']), compression=config['compression']) + print('Image chunks chosen as ' + str(imgs_dset.chunks)) + imgs_dset[...] = x + labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(train_loader.dataset),), chunks=(config['chunk_size'],), compression=config['compression']) + print('Label chunks chosen as ' + str(labels_dset.chunks)) + labels_dset[...] = y + # Else append to the hdf5 + else: + with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'a') as f: + f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0) + f['imgs'][-x.shape[0]:] = x + f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0) + f['labels'][-y.shape[0]:] = y + + +def main(): + # parse command line and run + parser = prepare_parser() + config = vars(parser.parse_args()) + print(config) + run(config) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/sample.py b/text2image/BigGAN_utils/sample.py new file mode 100644 index 0000000..663880d --- /dev/null +++ b/text2image/BigGAN_utils/sample.py @@ -0,0 +1,183 @@ +''' Sample + This script loads a pretrained net and a weightsfile and sample ''' +import functools +import math +import numpy as np +from tqdm import tqdm, trange + + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P +import torchvision + +# Import my stuff +import inception_utils +import utils +import losses + + + +def run(config): + # Prepare state dict, which holds things like epoch # and itr # + state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, + 'best_IS': 0, 'best_FID': 999999, 'config': config} + + # Optionally, get the configuration from the state dict. This allows for + # recovery of the config provided only a state dict and experiment name, + # and can be convenient for writing less verbose sample shell scripts. 
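For reference, a small sketch (path and resolution assumed, matching the I128 defaults) of reading back the HDF5 file that make_hdf5.py above writes; images are stored as uint8 NCHW in [0, 255] and the training pipeline maps them back to [-1, 1]:

    import h5py
    import numpy as np

    with h5py.File('data/ILSVRC128.hdf5', 'r') as f:
        x = f['imgs'][:16].astype(np.float32) / 255.0 * 2 - 1   # uint8 [0, 255] -> float [-1, 1]
        y = f['labels'][:16]
    print(x.shape, y[:4])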
+ if config['config_from_name']: + utils.load_weights(None, None, state_dict, config['weights_root'], + config['experiment_name'], config['load_weights'], None, + strict=False, load_optim=False) + # Ignore items which we might want to overwrite from the command line + for item in state_dict['config']: + if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']: + config[item] = state_dict['config'][item] + + # update config (see train.py for explanation) + config['resolution'] = utils.imsize_dict[config['dataset']] + config['n_classes'] = utils.nclass_dict[config['dataset']] + config['G_activation'] = utils.activation_dict[config['G_nl']] + config['D_activation'] = utils.activation_dict[config['D_nl']] + config = utils.update_config_roots(config) + config['skip_init'] = True + config['no_optim'] = True + device = 'cuda' + + # Seed RNG + utils.seed_rng(config['seed']) + + # Setup cudnn.benchmark for free speed + torch.backends.cudnn.benchmark = True + + # Import the model--this line allows us to dynamically select different files. + model = __import__(config['model']) + experiment_name = (config['experiment_name'] if config['experiment_name'] + else utils.name_from_config(config)) + print('Experiment name is %s' % experiment_name) + + G = model.Generator(**config).cuda() + utils.count_parameters(G) + + # Load weights + print('Loading weights...') + # Here is where we deal with the ema--load ema weights or load normal weights + utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, + config['weights_root'], experiment_name, config['load_weights'], + G if config['ema'] and config['use_ema'] else None, + strict=False, load_optim=False) + # Update batch size setting used for G + G_batch_size = max(config['G_batch_size'], config['batch_size']) + z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], + device=device, fp16=config['G_fp16'], + z_var=config['z_var']) + + if config['G_eval_mode']: + print('Putting G in eval mode..') + G.eval() + else: + print('G is in %s mode...' % ('training' if G.training else 'eval')) + + #Sample function + sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) + if config['accumulate_stats']: + print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations']) + utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], + config['num_standing_accumulations']) + + + # Sample a number of images and save them to an NPZ, for use with TF-Inception + if config['sample_npz']: + # Lists to hold images and labels for images + x, y = [], [] + print('Sampling %d images and saving them to npz...' % config['sample_num_npz']) + for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))): + with torch.no_grad(): + images, labels = sample() + x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)] + y += [labels.cpu().numpy()] + x = np.concatenate(x, 0)[:config['sample_num_npz']] + y = np.concatenate(y, 0)[:config['sample_num_npz']] + print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) + npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name) + print('Saving npz to %s...' 
% npz_filename) + np.savez(npz_filename, **{'x' : x, 'y' : y}) + + # Prepare sample sheets + if config['sample_sheets']: + print('Preparing conditional sample sheets...') + utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], + num_classes=config['n_classes'], + samples_per_class=10, parallel=config['parallel'], + samples_root=config['samples_root'], + experiment_name=experiment_name, + folder_number=config['sample_sheet_folder_num'], + z_=z_,) + # Sample interp sheets + if config['sample_interps']: + print('Preparing interp sheets...') + for fix_z, fix_y in zip([False, False, True], [False, True, False]): + utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8, + num_classes=config['n_classes'], + parallel=config['parallel'], + samples_root=config['samples_root'], + experiment_name=experiment_name, + folder_number=config['sample_sheet_folder_num'], + sheet_number=0, + fix_z=fix_z, fix_y=fix_y, device='cuda') + # Sample random sheet + if config['sample_random']: + print('Preparing random sample sheet...') + images, labels = sample() + torchvision.utils.save_image(images.float(), + '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name), + nrow=int(G_batch_size**0.5), + normalize=True) + + # Get Inception Score and FID + get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid']) + # Prepare a simple function get metrics that we use for trunc curves + def get_metrics(): + sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) + IS_mean, IS_std, FID = get_inception_metrics(sample, config['num_inception_images'], num_splits=10, prints=False) + # Prepare output string + outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema') + outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else 'training') + outstring += 'with noise variance %3.3f, ' % z_.var + outstring += 'over %d images, ' % config['num_inception_images'] + if config['accumulate_stats'] or not config['G_eval_mode']: + outstring += 'with batch size %d, ' % G_batch_size + if config['accumulate_stats']: + outstring += 'using %d standing stat accumulations, ' % config['num_standing_accumulations'] + outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID) + print(outstring) + if config['sample_inception_metrics']: + print('Calculating Inception metrics...') + get_metrics() + + # Sample truncation curve stuff. This is basically the same as the inception metrics code + if config['sample_trunc_curves']: + start, step, end = [float(item) for item in config['sample_trunc_curves'].split('_')] + print('Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' 
% (start, step, end)) + for var in np.arange(start, end + step, step): + z_.var = var + # Optionally comment this out if you want to run with standing stats + # accumulated at one z variance setting + if config['accumulate_stats']: + utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], + config['num_standing_accumulations']) + get_metrics() +def main(): + # parse command line and run + parser = utils.prepare_parser() + parser = utils.add_sample_parser(parser) + config = vars(parser.parse_args()) + print(config) + run(config) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_BigGAN_bs256x8.sh b/text2image/BigGAN_utils/scripts/launch_BigGAN_bs256x8.sh new file mode 100644 index 0000000..0c4831c --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_BigGAN_bs256x8.sh @@ -0,0 +1,17 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \ +--num_G_accumulations 8 --num_D_accumulations 8 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho \ +--hier --dim_z 120 --shared_dim 128 \ +--G_eval_mode \ +--G_ch 96 --D_ch 96 \ +--ema --use_ema --ema_start 20000 \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--use_multiepoch_sampler \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_BigGAN_bs512x4.sh b/text2image/BigGAN_utils/scripts/launch_BigGAN_bs512x4.sh new file mode 100644 index 0000000..114023f --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_BigGAN_bs512x4.sh @@ -0,0 +1,17 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 512 --load_in_mem \ +--num_G_accumulations 4 --num_D_accumulations 4 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho \ +--hier --dim_z 120 --shared_dim 128 \ +--G_eval_mode \ +--G_ch 96 --D_ch 96 \ +--ema --use_ema --ema_start 20000 \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--use_multiepoch_sampler \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_BigGAN_ch64_bs256x8.sh b/text2image/BigGAN_utils/scripts/launch_BigGAN_ch64_bs256x8.sh new file mode 100644 index 0000000..650d2fc --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_BigGAN_ch64_bs256x8.sh @@ -0,0 +1,17 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \ +--num_G_accumulations 8 --num_D_accumulations 8 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho \ +--hier --dim_z 120 --shared_dim 128 \ +--G_eval_mode \ +--G_ch 64 --G_ch 64 \ +--ema --use_ema --ema_start 20000 \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--use_multiepoch_sampler \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_BigGAN_deep.sh 
b/text2image/BigGAN_utils/scripts/launch_BigGAN_deep.sh new file mode 100644 index 0000000..5e83ef1 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_BigGAN_deep.sh @@ -0,0 +1,18 @@ +#!/bin/bash +python train.py \ +--model BigGANdeep \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \ +--num_G_accumulations 8 --num_D_accumulations 8 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_ch 128 --D_ch 128 \ +--G_depth 2 --D_depth 2 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho \ +--hier --dim_z 128 --shared_dim 128 \ +--ema --use_ema --ema_start 20000 --G_eval_mode \ +--test_every 2000 --save_every 500 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--use_multiepoch_sampler \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_SAGAN_bs128x2_ema.sh b/text2image/BigGAN_utils/scripts/launch_SAGAN_bs128x2_ema.sh new file mode 100644 index 0000000..0e99a0e --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_SAGAN_bs128x2_ema.sh @@ -0,0 +1,13 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 128 \ +--num_G_accumulations 2 --num_D_accumulations 2 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_nl relu --D_nl relu \ +--SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \ +--G_ortho 0.0 \ +--G_init xavier --D_init xavier \ +--ema --use_ema --ema_start 2000 --G_eval_mode \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--name_suffix SAGAN_ema \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_SNGAN.sh b/text2image/BigGAN_utils/scripts/launch_SNGAN.sh new file mode 100644 index 0000000..a2c9d66 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_SNGAN.sh @@ -0,0 +1,14 @@ +#!/bin/bash +python train.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 64 \ +--num_G_accumulations 1 --num_D_accumulations 1 \ +--num_D_steps 5 --G_lr 2e-4 --D_lr 2e-4 --D_B2 0.900 --G_B2 0.900 \ +--G_attn 0 --D_attn 0 \ +--G_nl relu --D_nl relu \ +--SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \ +--G_ortho 0.0 \ +--D_thin \ +--G_init xavier --D_init xavier \ + --G_eval_mode \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--name_suffix SNGAN \ \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/launch_cifar_ema.sh b/text2image/BigGAN_utils/scripts/launch_cifar_ema.sh new file mode 100644 index 0000000..07e15c2 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/launch_cifar_ema.sh @@ -0,0 +1,11 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0,1 python train.py \ +--shuffle --batch_size 50 --parallel \ +--num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 \ +--num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 \ +--dataset C10 \ +--G_ortho 0.0 \ +--G_attn 0 --D_attn 0 \ +--G_init N02 --D_init N02 \ +--ema --use_ema --ema_start 1000 \ +--test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/sample_BigGAN_bs256x8.sh b/text2image/BigGAN_utils/scripts/sample_BigGAN_bs256x8.sh new file mode 100644 index 0000000..c228da6 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/sample_BigGAN_bs256x8.sh @@ -0,0 +1,20 @@ +# use z_var to change the variance of z 
for all the sampling +# use --mybn --accumulate_stats --num_standing_accumulations 32 to +# use running stats +python sample.py \ +--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \ +--num_G_accumulations 8 --num_D_accumulations 8 \ +--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ +--G_attn 64 --D_attn 64 \ +--G_ch 96 --D_ch 96 \ +--G_nl inplace_relu --D_nl inplace_relu \ +--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ +--G_ortho 0.0 \ +--G_shared \ +--G_init ortho --D_init ortho --skip_init \ +--hier --dim_z 120 --shared_dim 128 \ +--ema --ema_start 20000 \ +--use_multiepoch_sampler \ +--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ +--skip_init --G_batch_size 512 --use_ema --G_eval_mode --sample_trunc_curves 0.05_0.05_1.0 \ +--sample_inception_metrics --sample_npz --sample_random --sample_sheets --sample_interps diff --git a/text2image/BigGAN_utils/scripts/sample_cifar_ema.sh b/text2image/BigGAN_utils/scripts/sample_cifar_ema.sh new file mode 100644 index 0000000..e6db0a1 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/sample_cifar_ema.sh @@ -0,0 +1,11 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0,1 python sample.py \ +--shuffle --batch_size 50 --G_batch_size 256 --parallel \ +--num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 \ +--num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 \ +--dataset C10 \ +--G_ortho 0.0 \ +--G_attn 0 --D_attn 0 \ +--G_init N02 --D_init N02 \ +--ema --use_ema --ema_start 1000 \ +--test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/utils/duplicate.sh b/text2image/BigGAN_utils/scripts/utils/duplicate.sh new file mode 100644 index 0000000..06da14a --- /dev/null +++ b/text2image/BigGAN_utils/scripts/utils/duplicate.sh @@ -0,0 +1,14 @@ +#duplicate.sh +source=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0 +target=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0A +logs_root=logs +weights_root=weights +echo "copying ${source} to ${target}" +cp -r ${logs_root}/${source} ${logs_root}/${target} +cp ${logs_root}/${source}_log.jsonl ${logs_root}/${target}_log.jsonl +cp ${weights_root}/${source}_G.pth ${weights_root}/${target}_G.pth +cp ${weights_root}/${source}_G_ema.pth ${weights_root}/${target}_G_ema.pth +cp ${weights_root}/${source}_D.pth ${weights_root}/${target}_D.pth +cp ${weights_root}/${source}_G_optim.pth ${weights_root}/${target}_G_optim.pth +cp ${weights_root}/${source}_D_optim.pth ${weights_root}/${target}_D_optim.pth +cp ${weights_root}/${source}_state_dict.pth ${weights_root}/${target}_state_dict.pth \ No newline at end of file diff --git a/text2image/BigGAN_utils/scripts/utils/prepare_data.sh b/text2image/BigGAN_utils/scripts/utils/prepare_data.sh new file mode 100644 index 0000000..e75c660 --- /dev/null +++ b/text2image/BigGAN_utils/scripts/utils/prepare_data.sh @@ -0,0 +1,3 @@ +#!/bin/bash +python make_hdf5.py --dataset I128 --batch_size 256 --data_root data +python calculate_inception_moments.py --dataset I128_hdf5 --data_root data \ No newline at end of file diff --git a/text2image/BigGAN_utils/sync_batchnorm/__init__.py b/text2image/BigGAN_utils/sync_batchnorm/__init__.py new file mode 100644 index 0000000..bc8709d --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ 
+# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..a3fb1ab Binary files /dev/null and b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-38.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..78a61f6 Binary files /dev/null and b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/__init__.cpython-38.pyc differ diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc new file mode 100644 index 0000000..aee8b6a Binary files /dev/null and b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc new file mode 100644 index 0000000..86d8d62 Binary files /dev/null and b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc differ diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-37.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-37.pyc new file mode 100644 index 0000000..05ea0d9 Binary files /dev/null and b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-38.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-38.pyc new file mode 100644 index 0000000..0143006 Binary files /dev/null and b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/comm.cpython-38.pyc differ diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc new file mode 100644 index 0000000..eb083c2 Binary files /dev/null and b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-37.pyc differ diff --git a/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-38.pyc b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-38.pyc new file mode 100644 index 0000000..f88646f Binary files /dev/null and b/text2image/BigGAN_utils/sync_batchnorm/__pycache__/replicate.cpython-38.pyc differ diff --git a/text2image/BigGAN_utils/sync_batchnorm/batchnorm.py b/text2image/BigGAN_utils/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000..5453729 --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/batchnorm.py @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
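The exports from the package __init__ above are typically used as follows (a hedged sketch; the toy model and device ids are assumptions): replacing nn.DataParallel with DataParallelWithCallback (or patching an existing DataParallel via patch_replication_callback) is what lets SynchronizedBatchNorm layers reduce statistics across all replicas rather than per GPU.

    import torch.nn as nn
    from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

    net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                        SynchronizedBatchNorm2d(16),
                        nn.ReLU())
    net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])  # callback re-wires the SyncBN master/slave pipes after replication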
+ +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) +# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size']) + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input, gain=None, bias=None): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + out = F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + if gain is not None: + out = out + gain + if bias is not None: + out = out + bias + return out + + # Resize the input to (B, C, -1). + input_shape = input.size() + # print(input_shape) + input = input.view(input.size(0), input.size(1), -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + # Reduce-and-broadcast the statistics. + # print('it begins') + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + # if self._parallel_id == 0: + # # print('here') + # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + # else: + # # print('there') + # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # print('how2') + # num = sum_size + # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu()))) + # Fix the graph + # sum = (sum.detach() - input_sum.detach()) + input_sum + # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum + + # mean = sum / num + # var = ssum / num - mean ** 2 + # # var = (ssum - mean * sum) / num + # inv_std = torch.rsqrt(var + self.eps) + + # Compute the output. + if gain is not None: + # print('gaining') + # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1) + # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1) + # output = input * scale - shift + output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1) + elif self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. 
+ return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + # print('a') + # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size) + # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device)) + # print('b') + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + return mean, torch.rsqrt(bias_var + self.eps) + # return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. 
+ + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) \ No newline at end of file diff --git a/text2image/BigGAN_utils/sync_batchnorm/batchnorm_reimpl.py b/text2image/BigGAN_utils/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000..7afcdaf --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! 
/usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. + + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/text2image/BigGAN_utils/sync_batchnorm/comm.py b/text2image/BigGAN_utils/sync_batchnorm/comm.py new file mode 100644 index 0000000..922f8c4 --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as a one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result hasn\'t been fetched.'
+ self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger a callback on each module, all slave devices should + call `register(id)` and obtain a `SlavePipe` to communicate with the master. + - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine the message to be passed + back to each slave device. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register a slave device. + + Args: + identifier: an identifier, usually the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages are first collected from each device (including the master device), and then + a callback will be invoked to compute the message to be sent back to each device + (including the master device). + + Args: + master_msg: the message that the master wants to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belong to the master.'
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/text2image/BigGAN_utils/sync_batchnorm/replicate.py b/text2image/BigGAN_utils/sync_batchnorm/replicate.py new file mode 100644 index 0000000..b71c7b8 --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphic, we assign each sub-module a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by + the original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have a customized `DataParallel` implementation.
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/text2image/BigGAN_utils/sync_batchnorm/unittest.py b/text2image/BigGAN_utils/sync_batchnorm/unittest.py new file mode 100644 index 0000000..bed56f1 --- /dev/null +++ b/text2image/BigGAN_utils/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y), message) + diff --git a/text2image/BigGAN_utils/train.py b/text2image/BigGAN_utils/train.py new file mode 100644 index 0000000..4fb804c --- /dev/null +++ b/text2image/BigGAN_utils/train.py @@ -0,0 +1,227 @@ +""" BigGAN: The Authorized Unofficial PyTorch release + Code by A. Brock and A. Andonian + This code is an unofficial reimplementation of + "Large-Scale GAN Training for High Fidelity Natural Image Synthesis," + by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096). + + Let's go. +""" + +import os +import functools +import math +import numpy as np +from tqdm import tqdm, trange + + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P +import torchvision + +# Import my stuff +import inception_utils +import utils +import losses +import train_fns +from sync_batchnorm import patch_replication_callback + +# The main training file. Config is a dictionary specifying the configuration +# of this training run. +def run(config): + + # Update the config dict as necessary + # This is for convenience, to add settings derived from the user-specified + # configuration into the config-dict (e.g. inferring the number of classes + # and size of the images from the dataset, passing in a pytorch object + # for the activation specified as a string) + config['resolution'] = utils.imsize_dict[config['dataset']] + config['n_classes'] = utils.nclass_dict[config['dataset']] + config['G_activation'] = utils.activation_dict[config['G_nl']] + config['D_activation'] = utils.activation_dict[config['D_nl']] + # By default, skip init if resuming training. 
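+ # (Reloading the checkpoint overwrites the parameters anyway, so the usual orthogonal / N02 init would just be wasted start-up work.)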
+ if config['resume']: + print('Skipping initialization for training resumption...') + config['skip_init'] = True + config = utils.update_config_roots(config) + device = 'cuda' + + # Seed RNG + utils.seed_rng(config['seed']) + + # Prepare root folders if necessary + utils.prepare_root(config) + + # Setup cudnn.benchmark for free speed + torch.backends.cudnn.benchmark = True + + # Import the model--this line allows us to dynamically select different files. + model = __import__(config['model']) + experiment_name = (config['experiment_name'] if config['experiment_name'] + else utils.name_from_config(config)) + print('Experiment name is %s' % experiment_name) + + # Next, build the model + G = model.Generator(**config).to(device) + D = model.Discriminator(**config).to(device) + + # If using EMA, prepare it + if config['ema']: + print('Preparing EMA for G with decay of {}'.format(config['ema_decay'])) + G_ema = model.Generator(**{**config, 'skip_init':True, + 'no_optim': True}).to(device) + ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start']) + else: + G_ema, ema = None, None + + # FP16? + if config['G_fp16']: + print('Casting G to float16...') + G = G.half() + if config['ema']: + G_ema = G_ema.half() + if config['D_fp16']: + print('Casting D to fp16...') + D = D.half() + # Consider automatically reducing SN_eps? + GD = model.G_D(G, D) + print(G) + print(D) + print('Number of params in G: {} D: {}'.format( + *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]])) + # Prepare state dict, which holds things like epoch # and itr # + state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, + 'best_IS': 0, 'best_FID': 999999, 'config': config} + + # If loading from a pre-trained model, load weights + if config['resume']: + print('Loading weights...') + utils.load_weights(G, D, state_dict, + config['weights_root'], experiment_name, + config['load_weights'] if config['load_weights'] else None, + G_ema if config['ema'] else None) + + # If parallel, parallelize the GD module + if config['parallel']: + GD = nn.DataParallel(GD) + if config['cross_replica']: + patch_replication_callback(GD) + + # Prepare loggers for stats; metrics holds test metrics, + # lmetrics holds any desired training metrics. + test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], + experiment_name) + train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name) + print('Inception Metrics will be saved to {}'.format(test_metrics_fname)) + test_log = utils.MetricsLogger(test_metrics_fname, + reinitialize=(not config['resume'])) + print('Training Metrics will be saved to {}'.format(train_metrics_fname)) + train_log = utils.MyLogger(train_metrics_fname, + reinitialize=(not config['resume']), + logstyle=config['logstyle']) + # Write metadata + utils.write_metadata(config['logs_root'], experiment_name, config, state_dict) + # Prepare data; the Discriminator's batch size is all that needs to be passed + # to the dataloader, as G doesn't require dataloading. 
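+ # (With the settings in the 128x128 sampling script above, batch_size 256, num_D_steps 1 and num_D_accumulations 8, that works out to 256 * 1 * 8 = 2048 images per loader iteration; illustrative numbers only.)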
+ # Note that at every loader iteration we pass in enough data to complete + # a full D iteration (regardless of number of D steps and accumulations) + D_batch_size = (config['batch_size'] * config['num_D_steps'] + * config['num_D_accumulations']) + loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, + 'start_itr': state_dict['itr']}) + + # Prepare inception metrics: FID and IS + get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid']) + + # Prepare noise and randomly sampled label arrays + # Allow for different batch sizes in G + G_batch_size = max(config['G_batch_size'], config['batch_size']) + z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], + device=device, fp16=config['G_fp16']) + # Prepare a fixed z & y to see individual sample evolution throghout training + fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, + config['n_classes'], device=device, + fp16=config['G_fp16']) + fixed_z.sample_() + fixed_y.sample_() + # Loaders are loaded, prepare the training function + if config['which_train_fn'] == 'GAN': + train = train_fns.GAN_training_function(G, D, GD, z_, y_, + ema, state_dict, config) + # Else, assume debugging and use the dummy train fn + else: + train = train_fns.dummy_training_function() + # Prepare Sample function for use with inception metrics + sample = functools.partial(utils.sample, + G=(G_ema if config['ema'] and config['use_ema'] + else G), + z_=z_, y_=y_, config=config) + + print('Beginning training at epoch %d...' % state_dict['epoch']) + # Train for specified number of epochs, although we mostly track G iterations. + for epoch in range(state_dict['epoch'], config['num_epochs']): + # Which progressbar to use? TQDM or my own? + if config['pbar'] == 'mine': + pbar = utils.progress(loaders[0],displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta') + else: + pbar = tqdm(loaders[0]) + for i, (x, y) in enumerate(pbar): + # Increment the iteration counter + state_dict['itr'] += 1 + # Make sure G and D are in training mode, just in case they got set to eval + # For D, which typically doesn't have BN, this shouldn't matter much. + G.train() + D.train() + if config['ema']: + G_ema.train() + if config['D_fp16']: + x, y = x.to(device).half(), y.to(device) + else: + x, y = x.to(device), y.to(device) + metrics = train(x, y) + train_log.log(itr=int(state_dict['itr']), **metrics) + + # Every sv_log_interval, log singular values + if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])): + train_log.log(itr=int(state_dict['itr']), + **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')}) + + # If using my progbar, print metrics. 
+ if config['pbar'] == 'mine': + print(', '.join(['itr: %d' % state_dict['itr']] + + ['%s : %+4.3f' % (key, metrics[key]) + for key in metrics]), end=' ') + + # Save weights and copies as configured at specified interval + if not (state_dict['itr'] % config['save_every']): + if config['G_eval_mode']: + print('Switchin G to eval mode...') + G.eval() + if config['ema']: + G_ema.eval() + train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, + state_dict, config, experiment_name) + + # Test every specified interval + if not (state_dict['itr'] % config['test_every']): + if config['G_eval_mode']: + print('Switchin G to eval mode...') + G.eval() + train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample, + get_inception_metrics, experiment_name, test_log) + # Increment epoch counter at end of epoch + state_dict['epoch'] += 1 + + +def main(): + # parse command line and run + parser = utils.prepare_parser() + config = vars(parser.parse_args()) + print(config) + run(config) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/text2image/BigGAN_utils/train_fns.py b/text2image/BigGAN_utils/train_fns.py new file mode 100644 index 0000000..e03e530 --- /dev/null +++ b/text2image/BigGAN_utils/train_fns.py @@ -0,0 +1,187 @@ +''' train_fns.py +Functions for the main loop of training different conditional image models +''' +import torch +import torch.nn as nn +import torchvision +import os + +import utils +import losses + + +# Dummy training function for debugging +def dummy_training_function(): + def train(x, y): + return {} + return train + + +def GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config): + def train(x, y): + G.optim.zero_grad() + D.optim.zero_grad() + # How many chunks to split x and y into? + x = torch.split(x, config['batch_size']) + y = torch.split(y, config['batch_size']) + counter = 0 + + # Optionally toggle D and G's "require_grad" + if config['toggle_grads']: + utils.toggle_grad(D, True) + utils.toggle_grad(G, False) + + for step_index in range(config['num_D_steps']): + # If accumulating gradients, loop multiple times before an optimizer step + D.optim.zero_grad() + for accumulation_index in range(config['num_D_accumulations']): + z_.sample_() + y_.sample_() + D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']], + x[counter], y[counter], train_G=False, + split_D=config['split_D']) + + # Compute components of D's loss, average them, and divide by + # the number of gradient accumulations + D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real) + D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations']) + D_loss.backward() + counter += 1 + + # Optionally apply ortho reg in D + if config['D_ortho'] > 0.0: + # Debug print to indicate we're using ortho reg in D. 
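+ # ('Modified' ortho reg, as in the BigGAN paper, penalizes only the off-diagonal entries of W^T W, encouraging filters to be orthogonal to each other without constraining their norms.)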
+ print('using modified ortho reg in D') + utils.ortho(D, config['D_ortho']) + + D.optim.step() + + # Optionally toggle "requires_grad" + if config['toggle_grads']: + utils.toggle_grad(D, False) + utils.toggle_grad(G, True) + + # Zero G's gradients by default before training G, for safety + G.optim.zero_grad() + + # If accumulating gradients, loop multiple times + for accumulation_index in range(config['num_G_accumulations']): + z_.sample_() + y_.sample_() + D_fake = GD(z_, y_, train_G=True, split_D=config['split_D']) + G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations']) + G_loss.backward() + + # Optionally apply modified ortho reg in G + if config['G_ortho'] > 0.0: + print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G + # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this + utils.ortho(G, config['G_ortho'], + blacklist=[param for param in G.shared.parameters()]) + G.optim.step() + + # If we have an ema, update it, regardless of if we test with it or not + if config['ema']: + ema.update(state_dict['itr']) + + out = {'G_loss': float(G_loss.item()), + 'D_loss_real': float(D_loss_real.item()), + 'D_loss_fake': float(D_loss_fake.item())} + # Return G's loss and the components of D's loss. + return out + return train + +''' This function takes in the model, saves the weights (multiple copies if + requested), and prepares sample sheets: one consisting of samples given + a fixed noise seed (to show how the model evolves throughout training), + a set of full conditional sample sheets, and a set of interp sheets. ''' +def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, + state_dict, config, experiment_name): + utils.save_weights(G, D, state_dict, config['weights_root'], + experiment_name, None, G_ema if config['ema'] else None) + # Save an additional copy to mitigate accidental corruption if process + # is killed during a save (it's happened to me before -.-) + if config['num_save_copies'] > 0: + utils.save_weights(G, D, state_dict, config['weights_root'], + experiment_name, + 'copy%d' % state_dict['save_num'], + G_ema if config['ema'] else None) + state_dict['save_num'] = (state_dict['save_num'] + 1 ) % config['num_save_copies'] + + # Use EMA G for samples or non-EMA? + which_G = G_ema if config['ema'] and config['use_ema'] else G + + # Accumulate standing statistics? 
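+ # ('Standing' statistics: re-estimate the batchnorm moments by averaging over num_standing_accumulations fresh forward passes right before sampling, rather than relying on the running averages collected during training.)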
+ if config['accumulate_stats']: + utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G, + z_, y_, config['n_classes'], + config['num_standing_accumulations']) + + # Save a random sample sheet with fixed z and y + with torch.no_grad(): + if config['parallel']: + fixed_Gz = nn.parallel.data_parallel(which_G, (fixed_z, which_G.shared(fixed_y))) + else: + fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y)) + if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)): + os.mkdir('%s/%s' % (config['samples_root'], experiment_name)) + image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'], + experiment_name, + state_dict['itr']) + torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename, ## NOTE: xcliu for torchvision 0.8.2 + nrow=int(fixed_Gz.shape[0] **0.5), normalize=True) + #torchvision.utils.save_image(torch.from_numpy(fixed_Gz.float().cpu().numpy()), image_filename, + # nrow=int(fixed_Gz.shape[0] **0.5), normalize=True) + + # For now, every time we save, also save sample sheets + utils.sample_sheet(which_G, + classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], + num_classes=config['n_classes'], + samples_per_class=10, parallel=config['parallel'], + samples_root=config['samples_root'], + experiment_name=experiment_name, + folder_number=state_dict['itr'], + z_=z_) + # Also save interp sheets + for fix_z, fix_y in zip([False, False, True], [False, True, False]): + utils.interp_sheet(which_G, + num_per_sheet=16, + num_midpoints=8, + num_classes=config['n_classes'], + parallel=config['parallel'], + samples_root=config['samples_root'], + experiment_name=experiment_name, + folder_number=state_dict['itr'], + sheet_number=0, + fix_z=fix_z, fix_y=fix_y, device='cuda') + + + +''' This function runs the inception metrics code, checks if the results + are an improvement over the previous best (either in IS or FID, + user-specified), logs the results, and saves a best_ copy if it's an + improvement. ''' +def test(G, D, G_ema, z_, y_, state_dict, config, sample, get_inception_metrics, + experiment_name, test_log): + print('Gathering inception metrics...') + if config['accumulate_stats']: + utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G, + z_, y_, config['n_classes'], + config['num_standing_accumulations']) + IS_mean, IS_std, FID = get_inception_metrics(sample, + config['num_inception_images'], + num_splits=10) + print('Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID)) + # If improved over previous best metric, save approrpiate copy + if ((config['which_best'] == 'IS' and IS_mean > state_dict['best_IS']) + or (config['which_best'] == 'FID' and FID < state_dict['best_FID'])): + print('%s improved over previous best, saving checkpoint...' 
% config['which_best']) + utils.save_weights(G, D, state_dict, config['weights_root'], + experiment_name, 'best%d' % state_dict['save_best_num'], + G_ema if config['ema'] else None) + state_dict['save_best_num'] = (state_dict['save_best_num'] + 1 ) % config['num_best_copies'] + state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean) + state_dict['best_FID'] = min(state_dict['best_FID'], FID) + # Log results to file + test_log.log(itr=int(state_dict['itr']), IS_mean=float(IS_mean), + IS_std=float(IS_std), FID=float(FID)) diff --git a/text2image/BigGAN_utils/utils.py b/text2image/BigGAN_utils/utils.py new file mode 100644 index 0000000..9d79f33 --- /dev/null +++ b/text2image/BigGAN_utils/utils.py @@ -0,0 +1,1195 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +''' Utilities file +This file contains utility functions for bookkeeping, logging, and data loading. +Methods which directly affect training should either go in layers, the model, +or train_fns.py. +''' + +from __future__ import print_function +import sys +import os +import numpy as np +import time +import datetime +import json +import pickle +from argparse import ArgumentParser +import animal_hash +import datasets as dset + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +def prepare_parser(): + usage = 'Parser for all scripts.' + parser = ArgumentParser(description=usage) + + ### Dataset/Dataloader stuff ### + parser.add_argument( + '--dataset', type=str, default='I128_hdf5', + help='Which Dataset to train on, out of I128, I256, C10, C100;' + 'Append "_hdf5" to use the hdf5 version for ISLVRC ' + '(default: %(default)s)') + parser.add_argument( + '--augment', action='store_true', default=False, + help='Augment with random crops and flips (default: %(default)s)') + parser.add_argument( + '--num_workers', type=int, default=8, + help='Number of dataloader workers; consider using less for HDF5 ' + '(default: %(default)s)') + parser.add_argument( + '--no_pin_memory', action='store_false', dest='pin_memory', default=True, + help='Pin data into memory through dataloader? (default: %(default)s)') + parser.add_argument( + '--shuffle', action='store_true', default=False, + help='Shuffle the data (strongly recommended)? (default: %(default)s)') + parser.add_argument( + '--load_in_mem', action='store_true', default=False, + help='Load all data into memory? (default: %(default)s)') + parser.add_argument( + '--use_multiepoch_sampler', action='store_true', default=False, + help='Use the multi-epoch sampler for dataloader? (default: %(default)s)') + + + ### Model stuff ### + parser.add_argument( + '--model', type=str, default='BigGAN', + help='Name of the model module (default: %(default)s)') + parser.add_argument( + '--G_param', type=str, default='SN', + help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)' + ' or None (default: %(default)s)') + parser.add_argument( + '--D_param', type=str, default='SN', + help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)' + ' or None (default: %(default)s)') + parser.add_argument( + '--G_ch', type=int, default=64, + help='Channel multiplier for G (default: %(default)s)') + parser.add_argument( + '--D_ch', type=int, default=64, + help='Channel multiplier for D (default: %(default)s)') + parser.add_argument( + '--G_depth', type=int, default=1, + help='Number of resblocks per stage in G? 
(default: %(default)s)') + parser.add_argument( + '--D_depth', type=int, default=1, + help='Number of resblocks per stage in D? (default: %(default)s)') + parser.add_argument( + '--D_thin', action='store_false', dest='D_wide', default=True, + help='Use the SN-GAN channel pattern for D? (default: %(default)s)') + parser.add_argument( + '--G_shared', action='store_true', default=False, + help='Use shared embeddings in G? (default: %(default)s)') + parser.add_argument( + '--shared_dim', type=int, default=0, + help='G''s shared embedding dimensionality; if 0, will be equal to dim_z. ' + '(default: %(default)s)') + parser.add_argument( + '--dim_z', type=int, default=128, + help='Noise dimensionality: %(default)s)') + parser.add_argument( + '--z_var', type=float, default=1.0, + help='Noise variance: %(default)s)') + parser.add_argument( + '--hier', action='store_true', default=False, + help='Use hierarchical z in G? (default: %(default)s)') + parser.add_argument( + '--cross_replica', action='store_true', default=False, + help='Cross_replica batchnorm in G?(default: %(default)s)') + parser.add_argument( + '--mybn', action='store_true', default=False, + help='Use my batchnorm (which supports standing stats?) %(default)s)') + parser.add_argument( + '--G_nl', type=str, default='relu', + help='Activation function for G (default: %(default)s)') + parser.add_argument( + '--D_nl', type=str, default='relu', + help='Activation function for D (default: %(default)s)') + parser.add_argument( + '--G_attn', type=str, default='64', + help='What resolutions to use attention on for G (underscore separated) ' + '(default: %(default)s)') + parser.add_argument( + '--D_attn', type=str, default='64', + help='What resolutions to use attention on for D (underscore separated) ' + '(default: %(default)s)') + parser.add_argument( + '--norm_style', type=str, default='bn', + help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], ' + 'ln [layernorm], gn [groupnorm] (default: %(default)s)') + + ### Model init stuff ### + parser.add_argument( + '--seed', type=int, default=0, + help='Random seed to use; affects both initialization and ' + ' dataloading. 
(default: %(default)s)') + parser.add_argument( + '--G_init', type=str, default='ortho', + help='Init style to use for G (default: %(default)s)') + parser.add_argument( + '--D_init', type=str, default='ortho', + help='Init style to use for D(default: %(default)s)') + parser.add_argument( + '--skip_init', action='store_true', default=False, + help='Skip initialization, ideal for testing when ortho init was used ' + '(default: %(default)s)') + + ### Optimizer stuff ### + parser.add_argument( + '--G_lr', type=float, default=5e-5, + help='Learning rate to use for Generator (default: %(default)s)') + parser.add_argument( + '--D_lr', type=float, default=2e-4, + help='Learning rate to use for Discriminator (default: %(default)s)') + parser.add_argument( + '--G_B1', type=float, default=0.0, + help='Beta1 to use for Generator (default: %(default)s)') + parser.add_argument( + '--D_B1', type=float, default=0.0, + help='Beta1 to use for Discriminator (default: %(default)s)') + parser.add_argument( + '--G_B2', type=float, default=0.999, + help='Beta2 to use for Generator (default: %(default)s)') + parser.add_argument( + '--D_B2', type=float, default=0.999, + help='Beta2 to use for Discriminator (default: %(default)s)') + + ### Batch size, parallel, and precision stuff ### + parser.add_argument( + '--batch_size', type=int, default=64, + help='Default overall batchsize (default: %(default)s)') + parser.add_argument( + '--G_batch_size', type=int, default=0, + help='Batch size to use for G; if 0, same as D (default: %(default)s)') + parser.add_argument( + '--num_G_accumulations', type=int, default=1, + help='Number of passes to accumulate G''s gradients over ' + '(default: %(default)s)') + parser.add_argument( + '--num_D_steps', type=int, default=2, + help='Number of D steps per G step (default: %(default)s)') + parser.add_argument( + '--num_D_accumulations', type=int, default=1, + help='Number of passes to accumulate D''s gradients over ' + '(default: %(default)s)') + parser.add_argument( + '--split_D', action='store_true', default=False, + help='Run D twice rather than concatenating inputs? (default: %(default)s)') + parser.add_argument( + '--num_epochs', type=int, default=100, + help='Number of epochs to train for (default: %(default)s)') + parser.add_argument( + '--parallel', action='store_true', default=False, + help='Train with multiple GPUs (default: %(default)s)') + parser.add_argument( + '--G_fp16', action='store_true', default=False, + help='Train with half-precision in G? (default: %(default)s)') + parser.add_argument( + '--D_fp16', action='store_true', default=False, + help='Train with half-precision in D? (default: %(default)s)') + parser.add_argument( + '--D_mixed_precision', action='store_true', default=False, + help='Train with half-precision activations but fp32 params in D? ' + '(default: %(default)s)') + parser.add_argument( + '--G_mixed_precision', action='store_true', default=False, + help='Train with half-precision activations but fp32 params in G? ' + '(default: %(default)s)') + parser.add_argument( + '--accumulate_stats', action='store_true', default=False, + help='Accumulate "standing" batchnorm stats? (default: %(default)s)') + parser.add_argument( + '--num_standing_accumulations', type=int, default=16, + help='Number of forward passes to use in accumulating standing stats? ' + '(default: %(default)s)') + + ### Bookkeping stuff ### + parser.add_argument( + '--G_eval_mode', action='store_true', default=False, + help='Run G in eval mode (running/standing stats?) at sample/test time? 
' + '(default: %(default)s)') + parser.add_argument( + '--save_every', type=int, default=2000, + help='Save every X iterations (default: %(default)s)') + parser.add_argument( + '--num_save_copies', type=int, default=2, + help='How many copies to save (default: %(default)s)') + parser.add_argument( + '--num_best_copies', type=int, default=2, + help='How many previous best checkpoints to save (default: %(default)s)') + parser.add_argument( + '--which_best', type=str, default='IS', + help='Which metric to use to determine when to save new "best"' + 'checkpoints, one of IS or FID (default: %(default)s)') + parser.add_argument( + '--no_fid', action='store_true', default=False, + help='Calculate IS only, not FID? (default: %(default)s)') + parser.add_argument( + '--test_every', type=int, default=5000, + help='Test every X iterations (default: %(default)s)') + parser.add_argument( + '--num_inception_images', type=int, default=50000, + help='Number of samples to compute inception metrics with ' + '(default: %(default)s)') + parser.add_argument( + '--hashname', action='store_true', default=False, + help='Use a hash of the experiment name instead of the full config ' + '(default: %(default)s)') + parser.add_argument( + '--base_root', type=str, default='', + help='Default location to store all weights, samples, data, and logs ' + ' (default: %(default)s)') + parser.add_argument( + '--data_root', type=str, default='data', + help='Default location where data is stored (default: %(default)s)') + parser.add_argument( + '--weights_root', type=str, default='weights', + help='Default location to store weights (default: %(default)s)') + parser.add_argument( + '--logs_root', type=str, default='logs', + help='Default location to store logs (default: %(default)s)') + parser.add_argument( + '--samples_root', type=str, default='samples', + help='Default location to store samples (default: %(default)s)') + parser.add_argument( + '--pbar', type=str, default='mine', + help='Type of progressbar to use; one of "mine" or "tqdm" ' + '(default: %(default)s)') + parser.add_argument( + '--name_suffix', type=str, default='', + help='Suffix for experiment name for loading weights for sampling ' + '(consider "best0") (default: %(default)s)') + parser.add_argument( + '--experiment_name', type=str, default='', + help='Optionally override the automatic experiment naming with this arg. ' + '(default: %(default)s)') + parser.add_argument( + '--config_from_name', action='store_true', default=False, + help='Use a hash of the experiment name instead of the full config ' + '(default: %(default)s)') + + ### EMA Stuff ### + parser.add_argument( + '--ema', action='store_true', default=False, + help='Keep an ema of G''s weights? (default: %(default)s)') + parser.add_argument( + '--ema_decay', type=float, default=0.9999, + help='EMA decay rate (default: %(default)s)') + parser.add_argument( + '--use_ema', action='store_true', default=False, + help='Use the EMA parameters of G for evaluation? 
(default: %(default)s)') + parser.add_argument( + '--ema_start', type=int, default=0, + help='When to start updating the EMA weights (default: %(default)s)') + + ### Numerical precision and SV stuff ### + parser.add_argument( + '--adam_eps', type=float, default=1e-8, + help='epsilon value to use for Adam (default: %(default)s)') + parser.add_argument( + '--BN_eps', type=float, default=1e-5, + help='epsilon value to use for BatchNorm (default: %(default)s)') + parser.add_argument( + '--SN_eps', type=float, default=1e-8, + help='epsilon value to use for Spectral Norm(default: %(default)s)') + parser.add_argument( + '--num_G_SVs', type=int, default=1, + help='Number of SVs to track in G (default: %(default)s)') + parser.add_argument( + '--num_D_SVs', type=int, default=1, + help='Number of SVs to track in D (default: %(default)s)') + parser.add_argument( + '--num_G_SV_itrs', type=int, default=1, + help='Number of SV itrs in G (default: %(default)s)') + parser.add_argument( + '--num_D_SV_itrs', type=int, default=1, + help='Number of SV itrs in D (default: %(default)s)') + + ### Ortho reg stuff ### + parser.add_argument( + '--G_ortho', type=float, default=0.0, # 1e-4 is default for BigGAN + help='Modified ortho reg coefficient in G(default: %(default)s)') + parser.add_argument( + '--D_ortho', type=float, default=0.0, + help='Modified ortho reg coefficient in D (default: %(default)s)') + parser.add_argument( + '--toggle_grads', action='store_true', default=True, + help='Toggle D and G''s "requires_grad" settings when not training them? ' + ' (default: %(default)s)') + + ### Which train function ### + parser.add_argument( + '--which_train_fn', type=str, default='GAN', + help='How2trainyourbois (default: %(default)s)') + + ### Resume training stuff + parser.add_argument( + '--load_weights', type=str, default='', + help='Suffix for which weights to load (e.g. best0, copy0) ' + '(default: %(default)s)') + parser.add_argument( + '--resume', action='store_true', default=False, + help='Resume training? (default: %(default)s)') + + ### Log stuff ### + parser.add_argument( + '--logstyle', type=str, default='%3.3e', + help='What style to use when logging training metrics?' + 'One of: %#.#f/ %#.#e (float/exp, text),' + 'pickle (python pickle),' + 'npz (numpy zip),' + 'mat (MATLAB .mat file) (default: %(default)s)') + parser.add_argument( + '--log_G_spectra', action='store_true', default=False, + help='Log the top 3 singular values in each SN layer in G? ' + '(default: %(default)s)') + parser.add_argument( + '--log_D_spectra', action='store_true', default=False, + help='Log the top 3 singular values in each SN layer in D? ' + '(default: %(default)s)') + parser.add_argument( + '--sv_log_interval', type=int, default=10, + help='Iteration interval for logging singular values ' + ' (default: %(default)s)') + + parser.add_argument('--text', type=str) + + return parser + +# Arguments for sample.py; not presently used in train.py +def add_sample_parser(parser): + parser.add_argument( + '--sample_npz', action='store_true', default=False, + help='Sample "sample_num_npz" images and save to npz? ' + '(default: %(default)s)') + parser.add_argument( + '--sample_num_npz', type=int, default=50000, + help='Number of images to sample when sampling NPZs ' + '(default: %(default)s)') + parser.add_argument( + '--sample_sheets', action='store_true', default=False, + help='Produce class-conditional sample sheets and stick them in ' + 'the samples root? 
(default: %(default)s)') + parser.add_argument( + '--sample_interps', action='store_true', default=False, + help='Produce interpolation sheets and stick them in ' + 'the samples root? (default: %(default)s)') + parser.add_argument( + '--sample_sheet_folder_num', type=int, default=-1, + help='Number to use for the folder for these sample sheets ' + '(default: %(default)s)') + parser.add_argument( + '--sample_random', action='store_true', default=False, + help='Produce a single random sheet? (default: %(default)s)') + parser.add_argument( + '--sample_trunc_curves', type=str, default='', + help='Get inception metrics with a range of variances?' + 'To use this, specify a startpoint, step, and endpoint, e.g. ' + '--sample_trunc_curves 0.2_0.1_1.0 for a startpoint of 0.2, ' + 'endpoint of 1.0, and stepsize of 1.0. Note that this is ' + 'not exactly identical to using tf.truncated_normal, but should ' + 'have approximately the same effect. (default: %(default)s)') + parser.add_argument( + '--sample_inception_metrics', action='store_true', default=False, + help='Calculate Inception metrics with sample.py? (default: %(default)s)') + return parser + +# Convenience dicts +dset_dict = {'I32': dset.ImageFolder, 'I64': dset.ImageFolder, + 'I128': dset.ImageFolder, 'I256': dset.ImageFolder, + 'I32_hdf5': dset.ILSVRC_HDF5, 'I64_hdf5': dset.ILSVRC_HDF5, + 'I128_hdf5': dset.ILSVRC_HDF5, 'I256_hdf5': dset.ILSVRC_HDF5, + 'C10': dset.CIFAR10, 'C100': dset.CIFAR100} +imsize_dict = {'I32': 32, 'I32_hdf5': 32, + 'I64': 64, 'I64_hdf5': 64, + 'I128': 128, 'I128_hdf5': 128, + 'I256': 256, 'I256_hdf5': 256, + 'C10': 32, 'C100': 32} +root_dict = {'I32': 'ImageNet', 'I32_hdf5': 'ILSVRC32.hdf5', + 'I64': 'ImageNet', 'I64_hdf5': 'ILSVRC64.hdf5', + 'I128': 'ImageNet', 'I128_hdf5': 'ILSVRC128.hdf5', + 'I256': 'ImageNet', 'I256_hdf5': 'ILSVRC256.hdf5', + 'C10': 'cifar', 'C100': 'cifar'} +nclass_dict = {'I32': 1000, 'I32_hdf5': 1000, + 'I64': 1000, 'I64_hdf5': 1000, + 'I128': 1000, 'I128_hdf5': 1000, + 'I256': 1000, 'I256_hdf5': 1000, + 'C10': 10, 'C100': 100} +# Number of classes to put per sample sheet +classes_per_sheet_dict = {'I32': 50, 'I32_hdf5': 50, + 'I64': 50, 'I64_hdf5': 50, + 'I128': 20, 'I128_hdf5': 20, + 'I256': 20, 'I256_hdf5': 20, + 'C10': 10, 'C100': 100} +activation_dict = {'inplace_relu': nn.ReLU(inplace=True), + 'relu': nn.ReLU(inplace=False), + 'ir': nn.ReLU(inplace=True),} + +class CenterCropLongEdge(object): + """Crops the given PIL Image on the long edge. + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped. + Returns: + PIL Image: Cropped image. + """ + return transforms.functional.center_crop(img, min(img.size)) + + def __repr__(self): + return self.__class__.__name__ + +class RandomCropLongEdge(object): + """Crops the given PIL Image on the long edge with a random start point. + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped. + Returns: + PIL Image: Cropped image. 
+ """ + size = (min(img.size), min(img.size)) + # Only step forward along this edge if it's the long edge + i = (0 if size[0] == img.size[0] + else np.random.randint(low=0,high=img.size[0] - size[0])) + j = (0 if size[1] == img.size[1] + else np.random.randint(low=0,high=img.size[1] - size[1])) + return transforms.functional.crop(img, i, j, size[0], size[1]) + + def __repr__(self): + return self.__class__.__name__ + + +# multi-epoch Dataset sampler to avoid memory leakage and enable resumption of +# training from the same sample regardless of if we stop mid-epoch +class MultiEpochSampler(torch.utils.data.Sampler): + r"""Samples elements randomly over multiple epochs + + Arguments: + data_source (Dataset): dataset to sample from + num_epochs (int) : Number of times to loop over the dataset + start_itr (int) : which iteration to begin from + """ + + def __init__(self, data_source, num_epochs, start_itr=0, batch_size=128): + self.data_source = data_source + self.num_samples = len(self.data_source) + self.num_epochs = num_epochs + self.start_itr = start_itr + self.batch_size = batch_size + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError("num_samples should be a positive integeral " + "value, but got num_samples={}".format(self.num_samples)) + + def __iter__(self): + n = len(self.data_source) + # Determine number of epochs + num_epochs = int(np.ceil((n * self.num_epochs + - (self.start_itr * self.batch_size)) / float(n))) + # Sample all the indices, and then grab the last num_epochs index sets; + # This ensures if we're starting at epoch 4, we're still grabbing epoch 4's + # indices + out = [torch.randperm(n) for epoch in range(self.num_epochs)][-num_epochs:] + # Ignore the first start_itr % n indices of the first epoch + out[0] = out[0][(self.start_itr * self.batch_size % n):] + # if self.replacement: + # return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()) + # return iter(.tolist()) + output = torch.cat(out).tolist() + print('Length dataset output is %d' % len(output)) + return iter(output) + + def __len__(self): + return len(self.data_source) * self.num_epochs - self.start_itr * self.batch_size + + +# Convenience function to centralize all data loaders +def get_data_loaders(dataset, data_root=None, augment=False, batch_size=64, + num_workers=8, shuffle=True, load_in_mem=False, hdf5=False, + pin_memory=True, drop_last=True, start_itr=0, + num_epochs=500, use_multiepoch_sampler=False, + **kwargs): + + # Append /FILENAME.hdf5 to root if using hdf5 + data_root += '/%s' % root_dict[dataset] + print('Using dataset root location %s' % data_root) + + which_dataset = dset_dict[dataset] + norm_mean = [0.5,0.5,0.5] + norm_std = [0.5,0.5,0.5] + image_size = imsize_dict[dataset] + # For image folder datasets, name of the file where we store the precomputed + # image locations to avoid having to walk the dirs every time we load. 
+ dataset_kwargs = {'index_filename': '%s_imgs.npz' % dataset} + + # HDF5 datasets have their own inbuilt transform, no need to train_transform + if 'hdf5' in dataset: + train_transform = None + else: + if augment: + print('Data will be augmented...') + if dataset in ['C10', 'C100']: + train_transform = [transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip()] + else: + train_transform = [RandomCropLongEdge(), + transforms.Resize(image_size), + transforms.RandomHorizontalFlip()] + else: + print('Data will not be augmented...') + if dataset in ['C10', 'C100']: + train_transform = [] + else: + train_transform = [CenterCropLongEdge(), transforms.Resize(image_size)] + # train_transform = [transforms.Resize(image_size), transforms.CenterCrop] + train_transform = transforms.Compose(train_transform + [ + transforms.ToTensor(), + transforms.Normalize(norm_mean, norm_std)]) + train_set = which_dataset(root=data_root, transform=train_transform, + load_in_mem=load_in_mem, **dataset_kwargs) + + # Prepare loader; the loaders list is for forward compatibility with + # using validation / test splits. + loaders = [] + if use_multiepoch_sampler: + print('Using multiepoch sampler from start_itr %d...' % start_itr) + loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory} + sampler = MultiEpochSampler(train_set, num_epochs, start_itr, batch_size) + train_loader = DataLoader(train_set, batch_size=batch_size, + sampler=sampler, **loader_kwargs) + else: + loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory, + 'drop_last': drop_last} # Default, drop last incomplete batch + train_loader = DataLoader(train_set, batch_size=batch_size, + shuffle=shuffle, **loader_kwargs) + loaders.append(train_loader) + return loaders + + +# Utility file to seed rngs +def seed_rng(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + + +# Utility to peg all roots to a base root +# If a base root folder is provided, peg all other root folders to it. +def update_config_roots(config): + if config['base_root']: + print('Pegging all root folders to base root %s' % config['base_root']) + for key in ['data', 'weights', 'logs', 'samples']: + config['%s_root' % key] = '%s/%s' % (config['base_root'], key) + return config + + +# Utility to prepare root folders if they don't exist; parent folder must exist +def prepare_root(config): + for key in ['weights_root', 'logs_root', 'samples_root']: + if not os.path.exists(config[key]): + print('Making directory %s for %s...' % (config[key], key)) + os.mkdir(config[key]) + + +# Simple wrapper that applies EMA to a model. COuld be better done in 1.0 using +# the parameters() and buffers() module functions, but for now this works +# with state_dicts using .copy_ +class ema(object): + def __init__(self, source, target, decay=0.9999, start_itr=0): + self.source = source + self.target = target + self.decay = decay + # Optional parameter indicating what iteration to start the decay at + self.start_itr = start_itr + # Initialize target's params to be source's + self.source_dict = self.source.state_dict() + self.target_dict = self.target.state_dict() + print('Initializing EMA parameters to be source parameters...') + with torch.no_grad(): + for key in self.source_dict: + self.target_dict[key].data.copy_(self.source_dict[key].data) + # target_dict[key].data = source_dict[key].data # Doesn't work! 
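+  # The copy_ in update() below implements the usual EMA rule on each
+  # state_dict entry:  target <- decay * target + (1 - decay) * source.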
+ + def update(self, itr=None): + # If an iteration counter is provided and itr is less than the start itr, + # peg the ema weights to the underlying weights. + if itr and itr < self.start_itr: + decay = 0.0 + else: + decay = self.decay + with torch.no_grad(): + for key in self.source_dict: + self.target_dict[key].data.copy_(self.target_dict[key].data * decay + + self.source_dict[key].data * (1 - decay)) + + +# Apply modified ortho reg to a model +# This function is an optimized version that directly computes the gradient, +# instead of computing and then differentiating the loss. +def ortho(model, strength=1e-4, blacklist=[]): + with torch.no_grad(): + for param in model.parameters(): + # Only apply this to parameters with at least 2 axes, and not in the blacklist + if len(param.shape) < 2 or any([param is item for item in blacklist]): + continue + w = param.view(param.shape[0], -1) + grad = (2 * torch.mm(torch.mm(w, w.t()) + * (1. - torch.eye(w.shape[0], device=w.device)), w)) + param.grad.data += strength * grad.view(param.shape) + + +# Default ortho reg +# This function is an optimized version that directly computes the gradient, +# instead of computing and then differentiating the loss. +def default_ortho(model, strength=1e-4, blacklist=[]): + with torch.no_grad(): + for param in model.parameters(): + # Only apply this to parameters with at least 2 axes & not in blacklist + if len(param.shape) < 2 or param in blacklist: + continue + w = param.view(param.shape[0], -1) + grad = (2 * torch.mm(torch.mm(w, w.t()) + - torch.eye(w.shape[0], device=w.device), w)) + param.grad.data += strength * grad.view(param.shape) + + +# Convenience utility to switch off requires_grad +def toggle_grad(model, on_or_off): + for param in model.parameters(): + param.requires_grad = on_or_off + + +# Function to join strings or ignore them +# Base string is the string to link "strings," while strings +# is a list of strings or Nones. +def join_strings(base_string, strings): + return base_string.join([item for item in strings if item]) + + +# Save a model's weights, optimizer, and the state_dict +def save_weights(G, D, state_dict, weights_root, experiment_name, + name_suffix=None, G_ema=None): + root = '/'.join([weights_root, experiment_name]) + if not os.path.exists(root): + os.mkdir(root) + if name_suffix: + print('Saving weights to %s/%s...' % (root, name_suffix)) + else: + print('Saving weights to %s...' % root) + torch.save(G.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['G', name_suffix]))) + torch.save(G.optim.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix]))) + torch.save(D.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['D', name_suffix]))) + torch.save(D.optim.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix]))) + torch.save(state_dict, + '%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix]))) + if G_ema is not None: + torch.save(G_ema.state_dict(), + '%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix]))) + + +# Load a model's weights, optimizer, and the state_dict +def load_weights(G, D, state_dict, weights_root, experiment_name, + name_suffix=None, G_ema=None, strict=True, load_optim=True): + root = '/'.join([weights_root, experiment_name]) + if name_suffix: + print('Loading %s weights from %s...' % (name_suffix, root)) + else: + print('Loading weights from %s...' 
% root) + if G is not None: + G.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['G', name_suffix]))), + strict=strict) + if load_optim: + G.optim.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix])))) + if D is not None: + D.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['D', name_suffix]))), + strict=strict) + if load_optim: + D.optim.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix])))) + # Load state dict + for item in state_dict: + state_dict[item] = torch.load('%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))[item] + if G_ema is not None: + G_ema.load_state_dict( + torch.load('%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix]))), + strict=strict) + + +''' MetricsLogger originally stolen from VoxNet source code. + Used for logging inception metrics''' +class MetricsLogger(object): + def __init__(self, fname, reinitialize=False): + self.fname = fname + self.reinitialize = reinitialize + if os.path.exists(self.fname): + if self.reinitialize: + print('{} exists, deleting...'.format(self.fname)) + os.remove(self.fname) + + def log(self, record=None, **kwargs): + """ + Assumption: no newlines in the input. + """ + if record is None: + record = {} + record.update(kwargs) + record['_stamp'] = time.time() + with open(self.fname, 'a') as f: + f.write(json.dumps(record, ensure_ascii=True) + '\n') + + +# Logstyle is either: +# '%#.#f' for floating point representation in text +# '%#.#e' for exponent representation in text +# 'npz' for output to npz # NOT YET SUPPORTED +# 'pickle' for output to a python pickle # NOT YET SUPPORTED +# 'mat' for output to a MATLAB .mat file # NOT YET SUPPORTED +class MyLogger(object): + def __init__(self, fname, reinitialize=False, logstyle='%3.3f'): + self.root = fname + if not os.path.exists(self.root): + os.mkdir(self.root) + self.reinitialize = reinitialize + self.metrics = [] + self.logstyle = logstyle # One of '%3.3f' or like '%3.3e' + + # Delete log if re-starting and log already exists + def reinit(self, item): + if os.path.exists('%s/%s.log' % (self.root, item)): + if self.reinitialize: + # Only print the removal mess + if 'sv' in item : + if not any('sv' in item for item in self.metrics): + print('Deleting singular value logs...') + else: + print('{} exists, deleting...'.format('%s_%s.log' % (self.root, item))) + os.remove('%s/%s.log' % (self.root, item)) + + # Log in plaintext; this is designed for being read in MATLAB(sorry not sorry) + def log(self, itr, **kwargs): + for arg in kwargs: + if arg not in self.metrics: + if self.reinitialize: + self.reinit(arg) + self.metrics += [arg] + if self.logstyle == 'pickle': + print('Pickle not currently supported...') + # with open('%s/%s.log' % (self.root, arg), 'a') as f: + # pickle.dump(kwargs[arg], f) + elif self.logstyle == 'mat': + print('.mat logstyle not currently supported...') + else: + with open('%s/%s.log' % (self.root, arg), 'a') as f: + f.write('%d: %s\n' % (itr, self.logstyle % kwargs[arg])) + + +# Write some metadata to the logs directory +def write_metadata(logs_root, experiment_name, config, state_dict): + with open(('%s/%s/metalog.txt' % + (logs_root, experiment_name)), 'w') as writefile: + writefile.write('datetime: %s\n' % str(datetime.datetime.now())) + writefile.write('config: %s\n' % str(config)) + writefile.write('state: %s\n' %str(state_dict)) + + +""" +Very basic progress indicator to wrap an iterable in. 
+ +Author: Jan Schlüter +Andy's adds: time elapsed in addition to ETA, makes it possible to add +estimated time to 1k iters instead of estimated time to completion. +""" +def progress(items, desc='', total=None, min_delay=0.1, displaytype='s1k'): + """ + Returns a generator over `items`, printing the number and percentage of + items processed and the estimated remaining processing time before yielding + the next item. `total` gives the total number of items (required if `items` + has no length), and `min_delay` gives the minimum time in seconds between + subsequent prints. `desc` gives an optional prefix text (end with a space). + """ + total = total or len(items) + t_start = time.time() + t_last = 0 + for n, item in enumerate(items): + t_now = time.time() + if t_now - t_last > min_delay: + print("\r%s%d/%d (%6.2f%%)" % ( + desc, n+1, total, n / float(total) * 100), end=" ") + if n > 0: + + if displaytype == 's1k': # minutes/seconds for 1000 iters + next_1000 = n + (1000 - n%1000) + t_done = t_now - t_start + t_1k = t_done / n * next_1000 + outlist = list(divmod(t_done, 60)) + list(divmod(t_1k - t_done, 60)) + print("(TE/ET1k: %d:%02d / %d:%02d)" % tuple(outlist), end=" ") + else:# displaytype == 'eta': + t_done = t_now - t_start + t_total = t_done / n * total + outlist = list(divmod(t_done, 60)) + list(divmod(t_total - t_done, 60)) + print("(TE/ETA: %d:%02d / %d:%02d)" % tuple(outlist), end=" ") + + sys.stdout.flush() + t_last = t_now + yield item + t_total = time.time() - t_start + print("\r%s%d/%d (100.00%%) (took %d:%02d)" % ((desc, total, total) + + divmod(t_total, 60))) + + +# Sample function for use with inception metrics +def sample(G, z_, y_, config): + with torch.no_grad(): + z_.sample_() + y_.sample_() + if config['parallel']: + G_z = nn.parallel.data_parallel(G, (z_, G.shared(y_))) + else: + G_z = G(z_, G.shared(y_)) + return G_z, y_ + + +# Sample function for sample sheets +def sample_sheet(G, classes_per_sheet, num_classes, samples_per_class, parallel, + samples_root, experiment_name, folder_number, z_=None): + # Prepare sample directory + if not os.path.isdir('%s/%s' % (samples_root, experiment_name)): + os.mkdir('%s/%s' % (samples_root, experiment_name)) + if not os.path.isdir('%s/%s/%d' % (samples_root, experiment_name, folder_number)): + os.mkdir('%s/%s/%d' % (samples_root, experiment_name, folder_number)) + # loop over total number of sheets + for i in range(num_classes // classes_per_sheet): + ims = [] + y = torch.arange(i * classes_per_sheet, (i + 1) * classes_per_sheet, device='cuda') + for j in range(samples_per_class): + if (z_ is not None) and hasattr(z_, 'sample_') and classes_per_sheet <= z_.size(0): + z_.sample_() + else: + z_ = torch.randn(classes_per_sheet, G.dim_z, device='cuda') + with torch.no_grad(): + if parallel: + o = nn.parallel.data_parallel(G, (z_[:classes_per_sheet], G.shared(y))) + else: + o = G(z_[:classes_per_sheet], G.shared(y)) + + ims += [o.data.cpu()] + # This line should properly unroll the images + out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2], + ims[0].shape[3]).data.float().cpu() + #out_ims = torch.from_numpy(out_ims.numpy()) ### NOTE: xcliu for torchvision + # The path for the samples + image_filename = '%s/%s/%d/samples%d.jpg' % (samples_root, experiment_name, + folder_number, i) + torchvision.utils.save_image(out_ims, image_filename, + nrow=samples_per_class, normalize=True) + + +# Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..) 
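+# Concretely, interp returns num_midpoints + 2 points per row,
+#   out[:, k] = (1 - t_k) * x0 + t_k * x1   with   t_k = k / (num_midpoints + 1),
+# i.e. both endpoints plus evenly spaced blends between them.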
+def interp(x0, x1, num_midpoints): + lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype) + return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1))) + + +# interp sheet function +# Supports full, class-wise and intra-class interpolation +def interp_sheet(G, num_per_sheet, num_midpoints, num_classes, parallel, + samples_root, experiment_name, folder_number, sheet_number=0, + fix_z=False, fix_y=False, device='cuda'): + # Prepare zs and ys + if fix_z: # If fix Z, only sample 1 z per row + zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device) + zs = zs.repeat(1, num_midpoints + 2, 1).view(-1, G.dim_z) + else: + zs = interp(torch.randn(num_per_sheet, 1, G.dim_z, device=device), + torch.randn(num_per_sheet, 1, G.dim_z, device=device), + num_midpoints).view(-1, G.dim_z) + if fix_y: # If fix y, only sample 1 z per row + ys = sample_1hot(num_per_sheet, num_classes) + ys = G.shared(ys).view(num_per_sheet, 1, -1) + ys = ys.repeat(1, num_midpoints + 2, 1).view(num_per_sheet * (num_midpoints + 2), -1) + else: + ys = interp(G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1), + G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1), + num_midpoints).view(num_per_sheet * (num_midpoints + 2), -1) + # Run the net--note that we've already passed y through G.shared. + if G.fp16: + zs = zs.half() + with torch.no_grad(): + if parallel: + out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu() + else: + out_ims = G(zs, ys).data.cpu() + interp_style = '' + ('Z' if not fix_z else '') + ('Y' if not fix_y else '') + image_filename = '%s/%s/%d/interp%s%d.jpg' % (samples_root, experiment_name, + folder_number, interp_style, + sheet_number) + torchvision.utils.save_image(out_ims, image_filename, + nrow=num_midpoints + 2, normalize=True) + + +# Convenience debugging function to print out gradnorms and shape from each layer +# May need to rewrite this so we can actually see which parameter is which +def print_grad_norms(net): + gradsums = [[float(torch.norm(param.grad).item()), + float(torch.norm(param).item()), param.shape] + for param in net.parameters()] + order = np.argsort([item[0] for item in gradsums]) + print(['%3.3e,%3.3e, %s' % (gradsums[item_index][0], + gradsums[item_index][1], + str(gradsums[item_index][2])) + for item_index in order]) + + +# Get singular values to log. This will use the state dict to find them +# and substitute underscores for dots. 
+def get_SVs(net, prefix): + d = net.state_dict() + return {('%s_%s' % (prefix, key)).replace('.', '_') : + float(d[key].item()) + for key in d if 'sv' in key} + + +# Name an experiment based on its config +def name_from_config(config): + name = '_'.join([ + item for item in [ + 'Big%s' % config['which_train_fn'], + config['dataset'], + config['model'] if config['model'] != 'BigGAN' else None, + 'seed%d' % config['seed'], + 'Gch%d' % config['G_ch'], + 'Dch%d' % config['D_ch'], + 'Gd%d' % config['G_depth'] if config['G_depth'] > 1 else None, + 'Dd%d' % config['D_depth'] if config['D_depth'] > 1 else None, + 'bs%d' % config['batch_size'], + 'Gfp16' if config['G_fp16'] else None, + 'Dfp16' if config['D_fp16'] else None, + 'nDs%d' % config['num_D_steps'] if config['num_D_steps'] > 1 else None, + 'nDa%d' % config['num_D_accumulations'] if config['num_D_accumulations'] > 1 else None, + 'nGa%d' % config['num_G_accumulations'] if config['num_G_accumulations'] > 1 else None, + 'Glr%2.1e' % config['G_lr'], + 'Dlr%2.1e' % config['D_lr'], + 'GB%3.3f' % config['G_B1'] if config['G_B1'] !=0.0 else None, + 'GBB%3.3f' % config['G_B2'] if config['G_B2'] !=0.999 else None, + 'DB%3.3f' % config['D_B1'] if config['D_B1'] !=0.0 else None, + 'DBB%3.3f' % config['D_B2'] if config['D_B2'] !=0.999 else None, + 'Gnl%s' % config['G_nl'], + 'Dnl%s' % config['D_nl'], + 'Ginit%s' % config['G_init'], + 'Dinit%s' % config['D_init'], + 'G%s' % config['G_param'] if config['G_param'] != 'SN' else None, + 'D%s' % config['D_param'] if config['D_param'] != 'SN' else None, + 'Gattn%s' % config['G_attn'] if config['G_attn'] != '0' else None, + 'Dattn%s' % config['D_attn'] if config['D_attn'] != '0' else None, + 'Gortho%2.1e' % config['G_ortho'] if config['G_ortho'] > 0.0 else None, + 'Dortho%2.1e' % config['D_ortho'] if config['D_ortho'] > 0.0 else None, + config['norm_style'] if config['norm_style'] != 'bn' else None, + 'cr' if config['cross_replica'] else None, + 'Gshared' if config['G_shared'] else None, + 'hier' if config['hier'] else None, + 'ema' if config['ema'] else None, + config['name_suffix'] if config['name_suffix'] else None, + ] + if item is not None]) + # dogball + if config['hashname']: + return hashname(name) + else: + return name + + +# A simple function to produce a unique experiment name from the animal hashes. +def hashname(name): + h = hash(name) + a = h % len(animal_hash.a) + h = h // len(animal_hash.a) + b = h % len(animal_hash.b) + h = h // len(animal_hash.c) + c = h % len(animal_hash.c) + return animal_hash.a[a] + animal_hash.b[b] + animal_hash.c[c] + + +# Get GPU memory, -i is the index +def query_gpu(indices): + os.system('nvidia-smi -i 0 --query-gpu=memory.free --format=csv') + + +# Convenience function to count the number of parameters in a module +def count_parameters(module): + print('Number of parameters: {}'.format( + sum([p.data.nelement() for p in module.parameters()]))) + + +# Convenience function to sample an index, not actually a 1-hot +def sample_1hot(batch_size, num_classes, device='cuda'): + return torch.randint(low=0, high=num_classes, size=(batch_size,), + device=device, dtype=torch.int64, requires_grad=False) + + +# A highly simplified convenience class for sampling from distributions +# One could also use PyTorch's inbuilt distributions package. 
+# Note that this class requires initialization to proceed as +# x = Distribution(torch.randn(size)) +# x.init_distribution(dist_type, **dist_kwargs) +# x = x.to(device,dtype) +# This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2 +class Distribution(torch.Tensor): + # Init the params of the distribution + def init_distribution(self, dist_type, **kwargs): + self.dist_type = dist_type + self.dist_kwargs = kwargs + if self.dist_type == 'normal': + self.mean, self.var = kwargs['mean'], kwargs['var'] + elif self.dist_type == 'categorical': + self.num_categories = kwargs['num_categories'] + + def sample_(self): + if self.dist_type == 'normal': + self.normal_(self.mean, self.var) + elif self.dist_type == 'categorical': + self.random_(0, self.num_categories) + # return self.variable + + # Silly hack: overwrite the to() method to wrap the new object + # in a distribution as well + def to(self, *args, **kwargs): + new_obj = Distribution(self) + new_obj.init_distribution(self.dist_type, **self.dist_kwargs) + new_obj.data = super().to(*args, **kwargs) + return new_obj + + +# Convenience function to prepare a z and y vector +def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda', + fp16=False,z_var=1.0): + z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False)) + z_.init_distribution('normal', mean=0, var=z_var) + z_ = z_.to(device,torch.float16 if fp16 else torch.float32) + + if fp16: + z_ = z_.half() + + y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False)) + y_.init_distribution('categorical',num_categories=nclasses) + y_ = y_.to(device, torch.int64) + return z_, y_ + + +def initiate_standing_stats(net): + for module in net.modules(): + if hasattr(module, 'accumulate_standing'): + module.reset_stats() + module.accumulate_standing = True + + +def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16): + initiate_standing_stats(net) + net.train() + for i in range(num_accumulations): + with torch.no_grad(): + z.normal_() + y.random_(0, nclasses) + x = net(z, net.shared(y)) # No need to parallelize here unless using syncbn + # Set to eval mode + net.eval() + + +# This version of Adam keeps an fp32 copy of the parameters and +# does all of the parameter updates in fp32, while still doing the +# forwards and backwards passes using fp16 (i.e. fp16 copies of the +# parameters and fp16 activations). +# +# Note that this calls .float().cuda() on the params. +import math +from torch.optim.optimizer import Optimizer +class Adam16(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + params = list(params) + super(Adam16, self).__init__(params, defaults) + + # Safety modification to make sure we floatify our state + def load_state_dict(self, state_dict): + super(Adam16, self).load_state_dict(state_dict) + for group in self.param_groups: + for p in group['params']: + self.state[p]['exp_avg'] = self.state[p]['exp_avg'].float() + self.state[p]['exp_avg_sq'] = self.state[p]['exp_avg_sq'].float() + self.state[p]['fp32_p'] = self.state[p]['fp32_p'].float() + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad.data.float() + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = grad.new().resize_as_(grad).zero_() + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() + # Fp32 copy of the weights + state['fp32_p'] = p.data.float() + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], state['fp32_p']) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + state['fp32_p'].addcdiv_(-step_size, exp_avg, denom) + p.data = state['fp32_p'].half() + + return loss diff --git a/text2image/BigGAN_utils/weights/README.md b/text2image/BigGAN_utils/weights/README.md new file mode 100644 index 0000000..4440e11 --- /dev/null +++ b/text2image/BigGAN_utils/weights/README.md @@ -0,0 +1,2 @@ +Download pre-trained weights from +https://drive.google.com/drive/folders/1nJ3HmgYgeA9NZr-oU-enqbYeO7zBaANs?usp=sharing diff --git a/text2image/DiffAugment_pytorch.py b/text2image/DiffAugment_pytorch.py new file mode 100644 index 0000000..fa90feb --- /dev/null +++ b/text2image/DiffAugment_pytorch.py @@ -0,0 +1,102 @@ +# Differentiable Augmentation for Data-Efficient GAN Training +# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +# https://arxiv.org/pdf/2006.10738 + +import torch +import torch.nn.functional as F +import numpy as np + + +def DiffAugment(x, policy='', channels_first=True): + if policy: + if not channels_first: + x = x.permute(0, 3, 1, 2) + for p in policy.split(','): + for f in AUGMENT_FNS[p]: + x = f(x) + if not channels_first: + x = x.permute(0, 2, 3, 1) + x = x.contiguous() + return x + + +def rand_brightness(x): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) + return x + + +def rand_saturation(x): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + return x + + +def rand_contrast(x): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + return x + + +def rand_translation(x, ratio=0.125): ### ratio: org: 0.125 + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 
1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() + return x + +def rand_resize(x, min_ratio=0.8, max_ratio=1.2): ### ratio: org: 0.125 + resize_ratio = np.random.rand()*(max_ratio-min_ratio) + min_ratio + resized_img = F.interpolate(x, size=int(resize_ratio*x.shape[3]), mode='bilinear') + org_size = x.shape[3] + #print('ORG:', x.shape) + #print('RESIZED:', resized_img.shape) + if int(resize_ratio*x.shape[3]) < x.shape[3]: + left_pad = (x.shape[3]-int(resize_ratio*x.shape[3]))/2. + left_pad = int(left_pad) + right_pad = x.shape[3] - left_pad - resized_img.shape[3] + #print('PAD:', left_pad, right_pad) + x = F.pad(resized_img, (left_pad, right_pad, left_pad, right_pad), "constant", 0.) + #print('SMALL:', x.shape) + else: + left = (int(resize_ratio*x.shape[3])-x.shape[3])/2. + left = int(left) + #print('LEFT:', left) + x = resized_img[:, :, left:(left+x.shape[3]), left:(left+x.shape[3])] + #print('LARGE:', x.shape) + assert x.shape[2] == org_size + assert x.shape[3] == org_size + + return x + + +def rand_cutout(x, ratio=0.5): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'resize': [rand_resize], + 'cutout': [rand_cutout], +} diff --git a/text2image/LICENSE b/text2image/LICENSE new file mode 100644 index 0000000..04205d8 --- /dev/null +++ b/text2image/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 gnobitab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
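For reference, the `DiffAugment_pytorch.py` module added above exposes a single entry point, `DiffAugment(x, policy, channels_first)`, where `policy` is a comma-separated subset of the keys in `AUGMENT_FNS` ('color', 'translation', 'resize', 'cutout'). A minimal usage sketch follows; the batch tensor is illustrative, and FuseDream itself applies the same policy inside `AugmentLoss`:

```python
import torch
from DiffAugment_pytorch import DiffAugment

# Illustrative NCHW batch of generator outputs in [-1, 1].
x = torch.randn(8, 3, 256, 256)

# Same policy string FuseDream uses; every op is differentiable, so gradients
# flow through the augmentations back to the generator latents.
x_aug = DiffAugment(x, policy='color,translation,resize,cutout')
assert x_aug.shape == x.shape  # spatial size is preserved by all four ops
```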
diff --git a/text2image/README.md b/text2image/README.md new file mode 100644 index 0000000..c56644a --- /dev/null +++ b/text2image/README.md @@ -0,0 +1,65 @@ +# FuseDream + +This repo contains code for our paper ([paper link](https://arxiv.org/abs/2112.01573)): + +**FuseDream: Training-Free Text-to-Image Generation with Improved CLIP+GAN Space Optimization** + +by *Xingchao Liu, Chengyue Gong, Lemeng Wu, Shujian Zhang, Hao Su and Qiang Liu* from UCSD and UT Austin. + +![FuseDream](./imgs/header_img.png?raw=true "FuseDream") + +## Introduction +FuseDream uses pre-trained GANs (we support BigGAN-256 and BigGAN-512 for now) and CLIP to achieve high-fidelity text-to-image generation. + +## Requirements +Please use `pip` or `conda` to install the following packages: +`PyTorch==1.7.1, torchvision==0.8.2, lpips==0.1.4` and also the requirements from [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch). + +## Getting Started + +We transformed the pre-trained weights of BigGAN from TFHub to PyTorch. To save your time, you can download the transformed BigGAN checkpoints from: + +https://drive.google.com/drive/folders/1nJ3HmgYgeA9NZr-oU-enqbYeO7zBaANs?usp=sharing + +Put the checkpoints into `./BigGAN_utils/weights/` + +Run the following command to generate images from text query: + +`python fusedream_generator.py --text 'YOUR TEXT' --seed YOUR_SEED` + +For example, to get an image of a blue dog: + +`python fusedream_generator.py --text 'A photo of a blue dog.' --seed 1234` + +The generated image will be stored in `./samples` + +## Colab Notebook + +For a quick test of *FuseDream*, we provide Colab notebooks for [*FuseDream*(Single Image)](https://colab.research.google.com/drive/17qkzkoQQtzDRFaSCJQzIaNj88xjO9Rm9?usp=sharing) and *FuseDream-Composition*(TODO). Have fun! 
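For programmatic use (rather than the CLI shown in Getting Started above), the generation pipeline in `fusedream_generator.py` reduces to roughly the following sketch; the iteration counts and output path are illustrative, see that script for the exact defaults:

```python
import BigGAN_utils.utils as utils
from fusedream_utils import FuseDreamBaseGenerator, get_G, save_image

utils.seed_rng(1234)                               # plays the role of --seed
G, config = get_G(512)                             # choose 256 or 512 to match the downloaded weights
generator = FuseDreamBaseGenerator(G, config, 10)

sentence = 'A photo of a blue dog.'
z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=1000, num_basis=5)
img, z, y = generator.optimize_clip_score(z_cllt, y_cllt, sentence,
                                          latent_noise=True, augment=True,
                                          opt_iters=1000, optimize_y=True)
score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
save_image(img, 'samples/fusedream_blue_dog.png')  # illustrative output path
```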
+ +## Citations +If you use the code, please cite: + +```BibTex +@inproceedings{ +brock2018large, +title={Large Scale {GAN} Training for High Fidelity Natural Image Synthesis}, +author={Andrew Brock and Jeff Donahue and Karen Simonyan}, +booktitle={International Conference on Learning Representations}, +year={2019}, +url={https://openreview.net/forum?id=B1xsqj09Fm}, +} +``` + +and +```BibTex +@misc{ +liu2021fusedream, +title={FuseDream: Training-Free Text-to-Image Generation with Improved CLIP+GAN Space Optimization}, +author={Xingchao Liu and Chengyue Gong and Lemeng Wu and Shujian Zhang and Hao Su and Qiang Liu}, +year={2021}, +eprint={2112.01573}, +archivePrefix={arXiv}, +primaryClass={cs.CV} +} +``` diff --git a/text2image/__init__.py b/text2image/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text2image/fusedream_generator.py b/text2image/fusedream_generator.py new file mode 100644 index 0000000..8cf6091 --- /dev/null +++ b/text2image/fusedream_generator.py @@ -0,0 +1,35 @@ +import torch +from tqdm import tqdm +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import BigGAN_utils.utils as utils +import torch.nn.functional as F +from DiffAugment_pytorch import DiffAugment +import numpy as np +from fusedream_utils import FuseDreamBaseGenerator, get_G, save_image + +parser = utils.prepare_parser() +parser = utils.add_sample_parser(parser) +args = parser.parse_args() + +INIT_ITERS = 1000 +OPT_ITERS = 1000 + +utils.seed_rng(args.seed) + +sentence = args.text + +print('Generating:', sentence) +G, config = get_G(512) # Choose from 256 and 512 +generator = FuseDreamBaseGenerator(G, config, 10) +z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=INIT_ITERS, num_basis=5) + +z_cllt_save = torch.cat(z_cllt).cpu().numpy() +y_cllt_save = torch.cat(y_cllt).cpu().numpy() +img, z, y = generator.optimize_clip_score(z_cllt, y_cllt, sentence, latent_noise=True, augment=True, opt_iters=OPT_ITERS, optimize_y=True) +score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20) +print('AugCLIP score:', score) +import os +if not os.path.exists('./samples'): + os.mkdir('./samples') +save_image(img, 'samples/fusedream_%s_seed_%d_score_%.4f.png'%(sentence, args.seed, score)) + diff --git a/text2image/fusedream_utils.py b/text2image/fusedream_utils.py new file mode 100644 index 0000000..ade612b --- /dev/null +++ b/text2image/fusedream_utils.py @@ -0,0 +1,308 @@ +import torch +from tqdm import tqdm +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import torchvision +import BigGAN_utils.utils as utils +import clip +import torch.nn.functional as F +from DiffAugment_pytorch import DiffAugment +import numpy as np +import lpips +import os +current_path = os.path.dirname(__file__) + +LATENT_NOISE = 0.01 +Z_THRES = 2.0 +POLICY = 'color,translation,resize,cutout' +TEST_POLICY = 'color,translation,resize,cutout' +mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda() +std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda() + +def AugmentLoss(img, clip_model, text, replicate=10, interp_mode='bilinear', policy=POLICY): + + clip_c = clip_model.logit_scale.exp() + img_aug = DiffAugment(img.repeat(replicate, 1, 1, 1), policy=policy) + img_aug = (img_aug+1.)/2. 
+ img_aug = F.interpolate(img_aug, size=224, mode=interp_mode) + img_aug.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + logits_per_image, logits_per_text = clip_model(img_aug, text) + logits_per_image = logits_per_image / clip_c + concept_loss = (-1.) * logits_per_image + + return concept_loss.mean(dim=0, keepdim=False) + +def NaiveSemanticLoss(img, clip_model, text, interp_mode='bilinear'): + + clip_c = clip_model.logit_scale.exp() + img = (img+1.)/2. + img = F.interpolate(img, size=224, mode=interp_mode) + img.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + logits_per_image, logits_per_text = clip_model(img, text) + logits_per_image = logits_per_image / clip_c + concept_loss = (-1.) * logits_per_image + + return concept_loss.mean(dim=0, keepdim=False) + +def get_gaussian_mask(size=256): + x, y = np.meshgrid(np.linspace(-1,1, size), np.linspace(-1,1,size)) + dst = np.sqrt(x*x+y*y) + + # Intializing sigma and muu + sigma = 1 + muu = 0.000 + + # Calculating Gaussian array + gauss = np.exp(-( (dst-muu)**2 / ( 2.0 * sigma**2 ) ) ) + + return gauss + +def save_image(img, path, n_per_row=1): + with torch.no_grad(): + torchvision.utils.save_image( + torch.from_numpy(img.cpu().numpy()), ##hack, to turn Distribution back to tensor + path, + nrow=n_per_row, + normalize=True, + ) + +def get_G(resolution=256): + if resolution == 256: + parser = utils.prepare_parser() + parser = utils.add_sample_parser(parser) + config = vars(parser.parse_args()) + + # See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs256x8.sh. + config["resolution"] = utils.imsize_dict["I128_hdf5"] + config["n_classes"] = utils.nclass_dict["I128_hdf5"] + config["G_activation"] = utils.activation_dict["inplace_relu"] + config["D_activation"] = utils.activation_dict["inplace_relu"] + config["G_attn"] = "128" + config["D_attn"] = "128" + config["G_ch"] = 96 + config["D_ch"] = 96 + config["hier"] = True + config["dim_z"] = 140 + config["shared_dim"] = 128 + config["G_shared"] = True + config = utils.update_config_roots(config) + config["skip_init"] = True + config["no_optim"] = True + config["device"] = "cuda" + config["resolution"] = 256 + + # Set up cudnn.benchmark for free speed. + torch.backends.cudnn.benchmark = True + + # Import the model. + model = __import__(config["model"]) + G = model.Generator(**config).to(config["device"]) + utils.count_parameters(G) + + # Load weights. + weights_path = f"{current_path}/BigGAN_utils/weights/biggan-256.pth" # Change this. + G.load_state_dict(torch.load(weights_path), strict=False) + elif resolution == 512: + parser = utils.prepare_parser() + parser = utils.add_sample_parser(parser) + config = vars(parser.parse_args()) + + # See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs128x8.sh. + config["resolution"] = 512 + config["n_classes"] = utils.nclass_dict["I128_hdf5"] + config["G_activation"] = utils.activation_dict["inplace_relu"] + config["D_activation"] = utils.activation_dict["inplace_relu"] + config["G_attn"] = "64" + config["D_attn"] = "64" + config["G_ch"] = 96 + config["D_ch"] = 64 + config["hier"] = True + config["dim_z"] = 128 + config["shared_dim"] = 128 + config["G_shared"] = True + config = utils.update_config_roots(config) + config["skip_init"] = True + config["no_optim"] = True + config["device"] = "cuda" + + # Set up cudnn.benchmark for free speed. + torch.backends.cudnn.benchmark = True + + # Import the model. 
+ model = __import__(config["model"]) + #print(config["model"]) + G = model.Generator(**config).to(config["device"]) + utils.count_parameters(G) + #print('G parameters:') + #for p, m in G.named_parameters(): + # print(p) + # Load weights. + weights_path = f"{current_path}/BigGAN_utils/weights/biggan-512.pth" # Change this. + G.load_state_dict(torch.load(weights_path), strict=False) + + return G, config + +class FuseDreamBaseGenerator(): + def __init__(self, G, G_config, G_batch_size=10, clip_mode="ViT-B/32", interp_mode='bilinear'): + + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + self.G = G + self.clip_model, _ = clip.load(clip_mode, device=device) + + (self.z_, self.y_) = utils.prepare_z_y( + G_batch_size, + self.G.dim_z, + G_config["n_classes"], + device=G_config["device"], + fp16=G_config["G_fp16"], + z_var=G_config["z_var"], + ) + + self.G.eval() + + for p in self.G.parameters(): + p.requires_grad = False + for p in self.clip_model.parameters(): + p.requires_grad = False + + self.interp_mode = interp_mode + + def generate_basis(self, text, init_iters=500, num_basis=5): + text_tok = clip.tokenize([text]).to(self.device) + clip_c = self.clip_model.logit_scale.exp() + + z_init_cllt = [] + y_init_cllt = [] + z_init = None + y_init = None + score_init = None + with torch.no_grad(): + for i in tqdm(range(init_iters)): + self.z_.sample_() + self.y_.sample_() + + self.z_.data = torch.clamp(self.z_.data.detach().clone(), min=-Z_THRES, max=Z_THRES) + + image_tensors = self.G(self.z_, self.G.shared(self.y_)) + image_tensors = (image_tensors+1.) / 2. + image_tensors = F.interpolate(image_tensors, size=224, mode=self.interp_mode) + image_tensors.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + logits_per_image, logits_per_text = self.clip_model(image_tensors, text_tok) + logits_per_image = logits_per_image/clip_c + if z_init is None: + z_init = self.z_.data.detach().clone() + y_init = self.y_.data.detach().clone() + score_init = logits_per_image.squeeze() + else: + z_init = torch.cat([z_init, self.z_.data.detach().clone()], dim=0) + y_init = torch.cat([y_init, self.y_.data.detach().clone()], dim=0) + score_init = torch.cat([score_init, logits_per_image.squeeze()]) + + sorted, indices = torch.sort(score_init, descending=True) + z_init = z_init[indices] + y_init = y_init[indices] + score_init = score_init[indices] + z_init = z_init[:num_basis] + y_init = y_init[:num_basis] + score_init = score_init[:num_basis] + + #save_image(self.G(z_init, self.G.shared(y_init)), 'samples/init_%s.png'%text, 1) + + z_init_cllt.append(z_init.detach().clone()) + y_init_cllt.append(self.G.shared(y_init.detach().clone())) + + return z_init_cllt, y_init_cllt + + + def optimize_clip_score(self, z_init_cllt, y_init_cllt, text, latent_noise=False, augment=True, opt_iters=500, optimize_y=False): + + text_tok = clip.tokenize([text]).to(self.device) + clip_c = self.clip_model.logit_scale.exp() + + z_init_ans = torch.stack(z_init_cllt) + y_init_ans = torch.stack(y_init_cllt) + z_init_ans = z_init_ans.view(-1, z_init_ans.shape[-1]) + y_init_ans = y_init_ans.view(-1, y_init_ans.shape[-1]) + + w_z = torch.randn((z_init_ans.shape[0], z_init_ans.shape[1])).to(self.device) + w_y = torch.randn((y_init_ans.shape[0], y_init_ans.shape[1])).to(self.device) + w_z.requires_grad = True + w_y.requires_grad = True + + opt_y = torch.zeros(y_init_ans.shape).to(self.device) + opt_y.data = y_init_ans.data.detach().clone() + opt_z = torch.zeros(z_init_ans.shape).to(self.device) + opt_z.data = 
z_init_ans.data.detach().clone() + opt_z.requires_grad = True + + if not optimize_y: + optimizer = torch.optim.Adam([w_z, w_y, opt_z], lr=5e-3, weight_decay=0.0) + else: + opt_y.requires_grad = True + optimizer = torch.optim.Adam([w_z, w_y,opt_y,opt_z], lr=5e-3, weight_decay=0.0) + + for i in tqdm(range(opt_iters)): + #print(w_z.shape, w_y.shape) + optimizer.zero_grad() + + if not latent_noise: + s_z = torch.softmax(w_z, dim=0) + s_y = torch.softmax(w_y, dim=0) + #print(s_z) + + cur_z = s_z * opt_z + cur_y = s_y * opt_y + cur_z = cur_z.sum(dim=0, keepdim=True) + cur_y = cur_y.sum(dim=0, keepdim=True) + + image_tensors = self.G(cur_z, cur_y) + else: + s_z = torch.softmax(w_z, dim=0) + s_y = torch.softmax(w_y, dim=0) + + cur_z = s_z * opt_z + cur_y = s_y * opt_y + cur_z = cur_z.sum(dim=0, keepdim=True) + cur_y = cur_y.sum(dim=0, keepdim=True) + cur_z_aug = cur_z + torch.randn(cur_z.shape).to(cur_z.device) * LATENT_NOISE + cur_y_aug = cur_y + torch.randn(cur_y.shape).to(cur_y.device) * LATENT_NOISE + + image_tensors = self.G(cur_z_aug, cur_y_aug) + + loss = 0.0 + for j in range(image_tensors.shape[0]): + if augment: + loss = loss + AugmentLoss(image_tensors[j:(j+1)], self.clip_model, text_tok, replicate=50, interp_mode=self.interp_mode) + else: + loss = loss + NaiveSemanticLoss(image_tensors[j:(j+1)], self.clip_model, text_tok) + + loss.backward() + optimizer.step() + + opt_z.data = torch.clamp(opt_z.data.detach().clone(), min=-Z_THRES, max=Z_THRES) + + z_init_ans = cur_z.detach().clone() + y_init_ans = cur_y.detach().clone() + + # save_image(self.G(z_init_ans, y_init_ans), '/home/zhaojh/workspace/computer_vision/opt_%s.png'%text, 1) + return self.G(z_init_ans, y_init_ans), z_init_ans, y_init_ans + + def measureAugCLIP(self, z, y, text, augment=False, num_samples=20): + text_tok = clip.tokenize([text]).to(self.device) + avg_loss = 0.0 + for itr in range(num_samples): + image_tensors = self.G(z, y) + + for j in range(image_tensors.shape[0]): + if augment: + loss = AugmentLoss(image_tensors[j:(j+1)], self.clip_model, text_tok, replicate=50, interp_mode=self.interp_mode, policy=TEST_POLICY) + else: + loss = NaiveSemanticLoss(image_tensors[j:(j+1)], self.clip_model, text_tok) + avg_loss += loss.item() + + avg_loss /= num_samples + return avg_loss * (-1.) + diff --git a/text2image/run_text2img.py b/text2image/run_text2img.py new file mode 100644 index 0000000..3a160b6 --- /dev/null +++ b/text2image/run_text2img.py @@ -0,0 +1,30 @@ +import numpy as np +from logzero import logger +import torch +from torchvision.utils import make_grid +from BigGAN_utils import utils +from fusedream_utils import FuseDreamBaseGenerator, get_G +from PIL import Image + +INIT_ITERS = 1000 +OPT_ITERS = 1000 +NUM_BASIS = 5 +MODEL = "biggan-512" +SEED = 1884 + +def text2image(sentence:str): + utils.seed_rng(SEED) + logger.info(f'Generating: {sentence}') + G, config = get_G(512) + generator = FuseDreamBaseGenerator(G, config, 10) + z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=INIT_ITERS, num_basis=NUM_BASIS) + + z_cllt_save = torch.cat(z_cllt).cpu().numpy() + y_cllt_save = torch.cat(y_cllt).cpu().numpy() + img, z, y = generator.optimize_clip_score(z_cllt, y_cllt, sentence, latent_noise=False, augment=True, opt_iters=OPT_ITERS, optimize_y=True) + ## Set latent_noise = True yields slightly higher AugCLIP score, but slightly lower image quality. We set it to False for dogs. 
+ score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20) + grid = make_grid(img, nrow=1, normalize=True) + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + im = Image.fromarray(ndarr) + return im \ No newline at end of file
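Finally, a minimal sketch of how `run_text2img.py` above is meant to be driven; run it from inside `text2image/`, since its imports are relative to that directory, and note that the prompt and output filename here are illustrative:

```python
from run_text2img import text2image

# text2image() returns a PIL.Image; a single call runs 1000 init and 1000
# optimization iterations on GPU, so expect it to take a while.
im = text2image('A photo of a blue dog.')
im.save('text2img_demo.png')
```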