ai_platform_cv/text2image/run_text2img.py

import torch
from logzero import logger
from PIL import Image
from torchvision.utils import make_grid

from BigGAN_utils import utils
from fusedream_utils import FuseDreamBaseGenerator, get_G

INIT_ITERS = 1000     # iterations for basis initialization
OPT_ITERS = 1000      # iterations for CLIP-score optimization
NUM_BASIS = 5         # number of basis latent vectors
MODEL = "biggan-512"  # BigGAN variant (512x512 output)
SEED = 1884           # RNG seed for reproducibility


def text2image(sentence: str) -> Image.Image:
    """Generate an image for `sentence` with FuseDream (BigGAN + CLIP)."""
    utils.seed_rng(SEED)
    logger.info(f'Generating: {sentence}')

    # Load the pretrained 512x512 BigGAN generator and wrap it in FuseDream.
    G, config = get_G(512)
    generator = FuseDreamBaseGenerator(G, config, 10)

    # Stage 1: initialize a basis of CLIP-guided latent candidates.
    z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=INIT_ITERS, num_basis=NUM_BASIS)
    # CPU copies of the basis latents, kept in case they need to be persisted.
    z_cllt_save = torch.cat(z_cllt).cpu().numpy()
    y_cllt_save = torch.cat(y_cllt).cpu().numpy()

    # Stage 2: optimize the AugCLIP score over the basis. Setting
    # latent_noise=True yields a slightly higher AugCLIP score but slightly
    # lower image quality; we set it to False for dog images.
    img, z, y = generator.optimize_clip_score(z_cllt, y_cllt, sentence, latent_noise=False,
                                              augment=True, opt_iters=OPT_ITERS, optimize_y=True)
    score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
    logger.info(f'AugCLIP score: {score}')

    # Convert the image tensor to a uint8 HWC array and wrap it in a PIL image.
    grid = make_grid(img, nrow=1, normalize=True)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(ndarr)
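

# Example usage: a minimal sketch, not part of the original script. The prompt
# and output filename below are illustrative assumptions; running this requires
# the FuseDream/BigGAN checkpoints that get_G(512) loads to be available.
if __name__ == "__main__":
    im = text2image("A photo of a corgi running on the beach.")
    im.save("text2img_output.png")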