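"""FuseDream utilities: CLIP-guided latent optimization over a pretrained BigGAN.

Defines the CLIP-based losses (with and without DiffAugment), the BigGAN
generator constructor get_G(), and FuseDreamBaseGenerator, which first samples
a basis of high-scoring latents and then optimizes their combination to match
a text prompt.
"""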
import torch
from tqdm import tqdm
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import torchvision
import BigGAN_utils.utils as utils
import clip
import torch.nn.functional as F
from DiffAugment_pytorch import DiffAugment
import numpy as np
import lpips
import os

current_path = os.path.dirname(__file__)

# Magnitude of the Gaussian noise added to the latents when latent_noise=True.
LATENT_NOISE = 0.01
# BigGAN latents are clamped to [-Z_THRES, Z_THRES] (truncation).
Z_THRES = 2.0
# DiffAugment policies used during optimization and evaluation.
POLICY = 'color,translation,resize,cutout'
TEST_POLICY = 'color,translation,resize,cutout'
# CLIP's image normalization statistics.
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()

def AugmentLoss(img, clip_model, text, replicate=10, interp_mode='bilinear', policy=POLICY):
    """Negative CLIP score, averaged over `replicate` DiffAugment views of the image."""
    clip_c = clip_model.logit_scale.exp()
    img_aug = DiffAugment(img.repeat(replicate, 1, 1, 1), policy=policy)
    img_aug = (img_aug + 1.) / 2.
    img_aug = F.interpolate(img_aug, size=224, mode=interp_mode)
    img_aug.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])

    logits_per_image, logits_per_text = clip_model(img_aug, text)
    logits_per_image = logits_per_image / clip_c
    concept_loss = (-1.) * logits_per_image

    return concept_loss.mean(dim=0, keepdim=False)

def NaiveSemanticLoss(img, clip_model, text, interp_mode='bilinear'):
    """Negative CLIP score of the raw image, without any augmentation."""
    clip_c = clip_model.logit_scale.exp()
    img = (img + 1.) / 2.
    img = F.interpolate(img, size=224, mode=interp_mode)
    img.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])

    logits_per_image, logits_per_text = clip_model(img, text)
    logits_per_image = logits_per_image / clip_c
    concept_loss = (-1.) * logits_per_image

    return concept_loss.mean(dim=0, keepdim=False)

def get_gaussian_mask(size=256):
    x, y = np.meshgrid(np.linspace(-1, 1, size), np.linspace(-1, 1, size))
    dst = np.sqrt(x * x + y * y)

    # Initializing sigma and muu
    sigma = 1
    muu = 0.000

    # Calculating the Gaussian array
    gauss = np.exp(-((dst - muu) ** 2 / (2.0 * sigma ** 2)))

    return gauss

def save_image(img, path, n_per_row=1):
    with torch.no_grad():
        torchvision.utils.save_image(
            torch.from_numpy(img.cpu().numpy()),  # hack: turn the Distribution back into a plain tensor
            path,
            nrow=n_per_row,
            normalize=True,
        )

def get_G(resolution=256):
    if resolution == 256:
        parser = utils.prepare_parser()
        parser = utils.add_sample_parser(parser)
        config = vars(parser.parse_args())

        # See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs256x8.sh.
        config["resolution"] = utils.imsize_dict["I128_hdf5"]
        config["n_classes"] = utils.nclass_dict["I128_hdf5"]
        config["G_activation"] = utils.activation_dict["inplace_relu"]
        config["D_activation"] = utils.activation_dict["inplace_relu"]
        config["G_attn"] = "128"
        config["D_attn"] = "128"
        config["G_ch"] = 96
        config["D_ch"] = 96
        config["hier"] = True
        config["dim_z"] = 140
        config["shared_dim"] = 128
        config["G_shared"] = True
        config = utils.update_config_roots(config)
        config["skip_init"] = True
        config["no_optim"] = True
        config["device"] = "cuda"
        config["resolution"] = 256

        # Set up cudnn.benchmark for free speed.
        torch.backends.cudnn.benchmark = True

        # Import the model.
        model = __import__(config["model"])
        G = model.Generator(**config).to(config["device"])
        utils.count_parameters(G)

        # Load weights.
        weights_path = f"{current_path}/BigGAN_utils/weights/biggan-256.pth"  # Change this.
        G.load_state_dict(torch.load(weights_path), strict=False)
    elif resolution == 512:
        parser = utils.prepare_parser()
        parser = utils.add_sample_parser(parser)
        config = vars(parser.parse_args())

        # See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs128x8.sh.
        config["resolution"] = 512
        config["n_classes"] = utils.nclass_dict["I128_hdf5"]
        config["G_activation"] = utils.activation_dict["inplace_relu"]
        config["D_activation"] = utils.activation_dict["inplace_relu"]
        config["G_attn"] = "64"
        config["D_attn"] = "64"
        config["G_ch"] = 96
        config["D_ch"] = 64
        config["hier"] = True
        config["dim_z"] = 128
        config["shared_dim"] = 128
        config["G_shared"] = True
        config = utils.update_config_roots(config)
        config["skip_init"] = True
        config["no_optim"] = True
        config["device"] = "cuda"

        # Set up cudnn.benchmark for free speed.
        torch.backends.cudnn.benchmark = True

        # Import the model.
        model = __import__(config["model"])
        G = model.Generator(**config).to(config["device"])
        utils.count_parameters(G)

        # Load weights.
        weights_path = f"{current_path}/BigGAN_utils/weights/biggan-512.pth"  # Change this.
        G.load_state_dict(torch.load(weights_path), strict=False)
    else:
        raise ValueError(f"Unsupported resolution: {resolution} (expected 256 or 512)")

    return G, config

class FuseDreamBaseGenerator():
    def __init__(self, G, G_config, G_batch_size=10, clip_mode="ViT-B/32", interp_mode='bilinear'):

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.G = G
        self.clip_model, _ = clip.load(clip_mode, device=device)

        (self.z_, self.y_) = utils.prepare_z_y(
            G_batch_size,
            self.G.dim_z,
            G_config["n_classes"],
            device=G_config["device"],
            fp16=G_config["G_fp16"],
            z_var=G_config["z_var"],
        )

        self.G.eval()

        # Freeze both the generator and CLIP; only the latent variables are optimized.
        for p in self.G.parameters():
            p.requires_grad = False
        for p in self.clip_model.parameters():
            p.requires_grad = False

        self.interp_mode = interp_mode

    def generate_basis(self, text, init_iters=500, num_basis=5):
        text_tok = clip.tokenize([text]).to(self.device)
        clip_c = self.clip_model.logit_scale.exp()

        z_init_cllt = []
        y_init_cllt = []
        z_init = None
        y_init = None
        score_init = None
        with torch.no_grad():
            for i in tqdm(range(init_iters)):
                self.z_.sample_()
                self.y_.sample_()

                self.z_.data = torch.clamp(self.z_.data.detach().clone(), min=-Z_THRES, max=Z_THRES)

                image_tensors = self.G(self.z_, self.G.shared(self.y_))
                image_tensors = (image_tensors + 1.) / 2.
                image_tensors = F.interpolate(image_tensors, size=224, mode=self.interp_mode)
                image_tensors.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])

                logits_per_image, logits_per_text = self.clip_model(image_tensors, text_tok)
                logits_per_image = logits_per_image / clip_c
                if z_init is None:
                    z_init = self.z_.data.detach().clone()
                    y_init = self.y_.data.detach().clone()
                    score_init = logits_per_image.squeeze()
                else:
                    z_init = torch.cat([z_init, self.z_.data.detach().clone()], dim=0)
                    y_init = torch.cat([y_init, self.y_.data.detach().clone()], dim=0)
                    score_init = torch.cat([score_init, logits_per_image.squeeze()])

                # Keep only the num_basis latents with the highest CLIP scores seen so far.
                sorted, indices = torch.sort(score_init, descending=True)
                z_init = z_init[indices]
                y_init = y_init[indices]
                score_init = score_init[indices]
                z_init = z_init[:num_basis]
                y_init = y_init[:num_basis]
                score_init = score_init[:num_basis]

            #save_image(self.G(z_init, self.G.shared(y_init)), 'samples/init_%s.png'%text, 1)

        z_init_cllt.append(z_init.detach().clone())
        y_init_cllt.append(self.G.shared(y_init.detach().clone()))

        return z_init_cllt, y_init_cllt

    def optimize_clip_score(self, z_init_cllt, y_init_cllt, text, latent_noise=False, augment=True, opt_iters=500, optimize_y=False):

        text_tok = clip.tokenize([text]).to(self.device)
        clip_c = self.clip_model.logit_scale.exp()

        z_init_ans = torch.stack(z_init_cllt)
        y_init_ans = torch.stack(y_init_cllt)
        z_init_ans = z_init_ans.view(-1, z_init_ans.shape[-1])
        y_init_ans = y_init_ans.view(-1, y_init_ans.shape[-1])

        # Mixing weights over the basis latents (passed through a softmax below).
        w_z = torch.randn((z_init_ans.shape[0], z_init_ans.shape[1])).to(self.device)
        w_y = torch.randn((y_init_ans.shape[0], y_init_ans.shape[1])).to(self.device)
        w_z.requires_grad = True
        w_y.requires_grad = True

        opt_y = torch.zeros(y_init_ans.shape).to(self.device)
        opt_y.data = y_init_ans.data.detach().clone()
        opt_z = torch.zeros(z_init_ans.shape).to(self.device)
        opt_z.data = z_init_ans.data.detach().clone()
        opt_z.requires_grad = True

        if not optimize_y:
            optimizer = torch.optim.Adam([w_z, w_y, opt_z], lr=5e-3, weight_decay=0.0)
        else:
            opt_y.requires_grad = True
            optimizer = torch.optim.Adam([w_z, w_y, opt_y, opt_z], lr=5e-3, weight_decay=0.0)

        for i in tqdm(range(opt_iters)):
            optimizer.zero_grad()

            if not latent_noise:
                s_z = torch.softmax(w_z, dim=0)
                s_y = torch.softmax(w_y, dim=0)

                # Current latent is a softmax-weighted combination of the basis latents.
                cur_z = s_z * opt_z
                cur_y = s_y * opt_y
                cur_z = cur_z.sum(dim=0, keepdim=True)
                cur_y = cur_y.sum(dim=0, keepdim=True)

                image_tensors = self.G(cur_z, cur_y)
            else:
                s_z = torch.softmax(w_z, dim=0)
                s_y = torch.softmax(w_y, dim=0)

                cur_z = s_z * opt_z
                cur_y = s_y * opt_y
                cur_z = cur_z.sum(dim=0, keepdim=True)
                cur_y = cur_y.sum(dim=0, keepdim=True)
                # Perturb the combined latents with Gaussian noise before decoding.
                cur_z_aug = cur_z + torch.randn(cur_z.shape).to(cur_z.device) * LATENT_NOISE
                cur_y_aug = cur_y + torch.randn(cur_y.shape).to(cur_y.device) * LATENT_NOISE

                image_tensors = self.G(cur_z_aug, cur_y_aug)

            loss = 0.0
            for j in range(image_tensors.shape[0]):
                if augment:
                    loss = loss + AugmentLoss(image_tensors[j:(j+1)], self.clip_model, text_tok, replicate=50, interp_mode=self.interp_mode)
                else:
                    loss = loss + NaiveSemanticLoss(image_tensors[j:(j+1)], self.clip_model, text_tok)

            loss.backward()
            optimizer.step()

            opt_z.data = torch.clamp(opt_z.data.detach().clone(), min=-Z_THRES, max=Z_THRES)

        z_init_ans = cur_z.detach().clone()
        y_init_ans = cur_y.detach().clone()

        #save_image(self.G(z_init_ans, y_init_ans), '/home/zhaojh/workspace/computer_vision/opt_%s.png'%text, 1)
        return self.G(z_init_ans, y_init_ans), z_init_ans, y_init_ans

    def measureAugCLIP(self, z, y, text, augment=False, num_samples=20):
        text_tok = clip.tokenize([text]).to(self.device)
        avg_loss = 0.0
        for itr in range(num_samples):
            image_tensors = self.G(z, y)

            for j in range(image_tensors.shape[0]):
                if augment:
                    loss = AugmentLoss(image_tensors[j:(j+1)], self.clip_model, text_tok, replicate=50, interp_mode=self.interp_mode, policy=TEST_POLICY)
                else:
                    loss = NaiveSemanticLoss(image_tensors[j:(j+1)], self.clip_model, text_tok)
                avg_loss += loss.item()

        avg_loss /= num_samples
        # The losses are negative CLIP scores, so flip the sign to report a score.
        return avg_loss * (-1.)
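

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original pipeline).
# Assumptions: pretrained BigGAN weights are present under BigGAN_utils/weights/,
# a CUDA device is available, and get_G()'s argparse call is satisfied by its
# defaults when this file is run directly with no extra flags. The prompt and
# output filename below are hypothetical examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sentence = "A photo of a blue dog."

    G, config = get_G(resolution=256)
    generator = FuseDreamBaseGenerator(G, config, G_batch_size=10)

    # Stage 1: sample random latents and keep the top CLIP-scoring basis.
    z_cllt, y_cllt = generator.generate_basis(sentence, init_iters=500, num_basis=5)

    # Stage 2: optimize a softmax-weighted combination of the basis latents.
    img, z, y = generator.optimize_clip_score(
        z_cllt, y_cllt, sentence,
        latent_noise=False, augment=True, opt_iters=500, optimize_y=True,
    )

    # Evaluate with the augmented CLIP score and save the result.
    score = generator.measureAugCLIP(z, y, sentence, augment=True, num_samples=20)
    print("AugCLIP score: %.4f" % score)
    save_image(img, "fusedream_%s.png" % sentence.replace(" ", "_"))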