import numpy as np
import torch
from logzero import logger
from PIL import Image
from torchvision.utils import make_grid

from BigGAN_utils import utils
from fusedream_utils import FuseDreamBaseGenerator, get_G
# Hyperparameters for the FuseDream (BigGAN + CLIP) text-to-image pipeline.
INIT_ITERS = 1000     # iterations for the initial latent-basis search
OPT_ITERS = 1000      # iterations for the CLIP-score optimization stage
NUM_BASIS = 5         # number of latent basis vectors kept from the search
MODEL = "biggan-512"  # BigGAN backbone identifier (512x512 output)
SEED = 1884           # RNG seed for reproducible generation
def text2image(sentence: str) -> Image.Image:
    """Generate an image matching *sentence* via FuseDream (BigGAN + CLIP).

    Args:
        sentence: Text prompt to optimize the generated image toward.

    Returns:
        A PIL ``Image`` of the optimized generation (presumably 512x512,
        matching the ``get_G(512)`` backbone — confirm against ``get_G``).
    """
    utils.seed_rng(SEED)
    logger.info(f'Generating: {sentence}')

    G, config = get_G(512)  # 512-px BigGAN; matches MODEL = "biggan-512"
    generator = FuseDreamBaseGenerator(G, config, 10)
    z_cllt, y_cllt = generator.generate_basis(
        sentence, init_iters=INIT_ITERS, num_basis=NUM_BASIS)

    # Setting latent_noise=True yields a slightly higher AugCLIP score but
    # slightly lower image quality; False is the better trade-off here.
    img, z, y = generator.optimize_clip_score(
        z_cllt, y_cllt, sentence, latent_noise=False, augment=True,
        opt_iters=OPT_ITERS, optimize_y=True)

    # Report the AugCLIP score of the final latents; previously this was
    # computed at full cost and then silently discarded.
    score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
    logger.info(f'AugCLIP score: {score}')

    # Convert the [0,1]-normalized CHW tensor grid to an 8-bit HWC array
    # and wrap it as a PIL image.
    grid = make_grid(img, nrow=1, normalize=True)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(ndarr)