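"""FuseDream text-to-image sampling script.

Searches BigGAN latent/class vectors under CLIP guidance for a text prompt and
saves the resulting image together with its AugCLIP score. Typical invocation
(script and flag names are assumptions and may differ in your checkout):

    python fusedream_generate.py --text "a misty mountain lake" --seed 1234
"""
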
import os

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

import BigGAN_utils.utils as utils
from DiffAugment_pytorch import DiffAugment
from fusedream_utils import FuseDreamBaseGenerator, get_G, save_image

# Build BigGAN's argument parser, add its sampling options, and parse the CLI arguments.
parser = utils.prepare_parser()
parser = utils.add_sample_parser(parser)
args = parser.parse_args()

INIT_ITERS = 1000  # iterations for the basis-initialization stage
OPT_ITERS = 1000   # iterations for the CLIP-score optimization stage

utils.seed_rng(args.seed)

sentence = args.text
print('Generating:', sentence)

# Load a pretrained BigGAN generator and wrap it in the FuseDream generator.
G, config = get_G(512)  # Choose from 256 and 512
generator = FuseDreamBaseGenerator(G, config, 10)

# Stage 1: build a basis of candidate latent (z) and class (y) vectors for the prompt.
z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=INIT_ITERS, num_basis=5)

# CPU copies of the basis latent and class vectors.
z_cllt_save = torch.cat(z_cllt).cpu().numpy()
y_cllt_save = torch.cat(y_cllt).cpu().numpy()

# Stage 2: optimize the latent and class vectors to maximize the augmented CLIP score.
img, z, y = generator.optimize_clip_score(z_cllt, y_cllt, sentence, latent_noise=True, augment=True, opt_iters=OPT_ITERS, optimize_y=True)

# Evaluate the final image with the AugCLIP metric.
score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
print('AugCLIP score:', score)

# Save the result, tagging the filename with the prompt, seed, and AugCLIP score.
if not os.path.exists('./samples'):
    os.mkdir('./samples')
save_image(img, 'samples/fusedream_%s_seed_%d_score_%.4f.png' % (sentence, args.seed, score))