"""Text-to-image generation using FuseDream (BigGAN generator + CLIP guidance)."""

import numpy as np
import torch
from logzero import logger
from PIL import Image
from torchvision.utils import make_grid

from BigGAN_utils import utils
from fusedream_utils import FuseDreamBaseGenerator, get_G

# FuseDream optimization hyper-parameters.
INIT_ITERS = 1000     # iterations for the basis-initialization stage
OPT_ITERS = 1000      # iterations for the CLIP-score optimization stage
NUM_BASIS = 5         # number of latent basis vectors to search over
MODEL = "biggan-512"  # NOTE(review): unused — resolution 512 is hard-coded in get_G() below
SEED = 1884           # fixed RNG seed for reproducible generations


def text2image(sentence: str) -> Image.Image:
    """Generate one image matching *sentence* with FuseDream.

    Args:
        sentence: Text prompt describing the desired image.

    Returns:
        A PIL ``Image`` containing the best sample found by the
        CLIP-score optimization.
    """
    utils.seed_rng(SEED)
    logger.info(f'Generating: {sentence}')

    # Build the BigGAN-512 generator and wrap it in the FuseDream optimizer.
    G, config = get_G(512)
    generator = FuseDreamBaseGenerator(G, config, 10)

    z_cllt, y_cllt = generator.generate_basis(
        sentence, init_iters=INIT_ITERS, num_basis=NUM_BASIS
    )

    # latent_noise=True yields a slightly higher AugCLIP score but slightly
    # lower image quality; kept False here (original author's choice for dogs).
    img, z, y = generator.optimize_clip_score(
        z_cllt, y_cllt, sentence,
        latent_noise=False, augment=True,
        opt_iters=OPT_ITERS, optimize_y=True,
    )

    # The score was previously computed and discarded; log it so the
    # measurement is not wasted work.
    score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
    logger.info(f'AugCLIP score: {score}')

    # Convert the normalized tensor grid to an 8-bit HWC array, then to PIL.
    grid = make_grid(img, nrow=1, normalize=True)
    ndarr = (
        grid.mul(255).add_(0.5).clamp_(0, 255)
        .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    )
    return Image.fromarray(ndarr)