# BigGAN V1:
# This is now deprecated code used for porting the TFHub modules to pytorch,
# included here for reference only.
import numpy as np
import torch
from scipy.stats import truncnorm
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F


def l2normalize(v, eps=1e-4):
    # Safe L2 normalization, used by the spectral-norm power iteration below.
    return v / (v.norm() + eps)


def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None):
    # Sample from a standard normal truncated to [-2, 2], then scale by the
    # truncation factor (the BigGAN "truncation trick").
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state)
    return truncation * values
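

def _demo_truncated_z_sample():
    # Illustrative sketch, not part of the original port: draw a truncated
    # noise batch and check its shape and bounds.
    z = truncated_z_sample(batch_size=4, z_dim=120, truncation=0.5, seed=0)
    assert z.shape == (4, 120)
    assert np.all(np.abs(z) <= 2 * 0.5)  # samples lie in truncation * [-2, 2]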


def denorm(x):
    # Map images from the generator's tanh range [-1, 1] to [0, 1].
    out = (x + 1) / 2
    return out.clamp_(0, 1)


class SpectralNorm(nn.Module):
    """Spectral normalization via power iteration (Miyato et al., 2018)."""

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        _w = w.view(height, -1)
        for _ in range(self.power_iterations):
            # Persist the power-iteration state in the registered buffers so
            # the singular-vector estimate improves across forward passes
            # (the original code rebound locals and discarded the updates).
            v.data = l2normalize(torch.matmul(_w.data.t(), u.data))
            u.data = l2normalize(torch.matmul(_w.data, v.data))

        sigma = u.dot(_w.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            getattr(self.module, self.name + "_u")
            getattr(self.module, self.name + "_v")
            getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        # The right singular vector has `width` entries; the original code
        # mistakenly allocated `height` entries here.
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
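

def _demo_spectral_norm():
    # Illustrative sketch, not part of the original port: wrap a conv layer
    # and check that the normalized weight's top singular value approaches 1
    # (each forward pass runs one power iteration).
    conv = SpectralNorm(nn.Conv2d(3, 8, 3, padding=1))
    x = torch.randn(2, 3, 16, 16)
    with torch.no_grad():
        for _ in range(50):
            conv(x)
    w = getattr(conv.module, 'weight').view(8, -1)
    print(f"top singular value: {torch.linalg.svdvals(w)[0].item():.3f}")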


class SelfAttention(nn.Module):
    """Self-attention layer (SAGAN-style, with 2x2-pooled keys and values)."""

    def __init__(self, in_dim, activation=F.relu):
        super().__init__()
        self.channel_in = in_dim
        self.activation = activation

        self.theta = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False))
        self.phi = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False))
        self.pool = nn.MaxPool2d(2, 2)
        self.g = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1, bias=False))
        self.o_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim // 2, out_channels=in_dim, kernel_size=1, bias=False))
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x is NCHW, so the last two dims are (height, width); the original
        # code unpacked them in the wrong order (harmless for square maps).
        m_batchsize, C, height, width = x.size()
        N = height * width

        theta = self.theta(x)  # queries, one per spatial location
        phi = self.phi(x)
        phi = self.pool(phi)  # keys are pooled, so there are N // 4 of them
        phi = phi.view(m_batchsize, -1, N // 4)
        theta = theta.view(m_batchsize, -1, N)
        theta = theta.permute(0, 2, 1)
        attention = self.softmax(torch.bmm(theta, phi))
        g = self.pool(self.g(x)).view(m_batchsize, -1, N // 4)  # pooled values
        attn_g = torch.bmm(g, attention.permute(0, 2, 1)).view(m_batchsize, -1, height, width)
        out = self.o_conv(attn_g)
        # gamma starts at 0, so the block is initially an identity mapping
        return self.gamma * out + x
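

def _demo_self_attention():
    # Illustrative sketch, not part of the original port: the attention block
    # is shape-preserving, so it can be dropped between GBlocks.
    attn = SelfAttention(in_dim=64)
    x = torch.randn(2, 64, 32, 32)
    assert attn(x).shape == x.shape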


class ConditionalBatchNorm2d(nn.Module):
    def __init__(self, num_features, num_classes, eps=1e-4, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
        # Per-channel scale and shift are linear functions of the condition.
        self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
        self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))

    def forward(self, x, y):
        out = self.bn(x)
        gamma = self.gamma_embed(y) + 1  # center the scale at 1
        beta = self.beta_embed(y)
        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
        return out
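

def _demo_conditional_bn():
    # Illustrative sketch, not part of the original port: the conditioning
    # vector stands in for cat([z_chunk, class_emb], 1) in the generators.
    cbn = ConditionalBatchNorm2d(num_features=16, num_classes=148)
    x = torch.randn(4, 16, 8, 8)
    y = torch.randn(4, 148)
    assert cbn(x, y).shape == x.shape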


class GBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=[3, 3],
        padding=1,
        stride=1,
        n_class=None,  # unused; kept for signature compatibility with callers
        bn=True,
        activation=F.relu,
        upsample=True,
        downsample=False,
        z_dim=148,
    ):
        super().__init__()

        # The original expression `bias=True if bn else True` was redundant;
        # both branches evaluate to True, so the convolutions always carry a
        # bias term.
        self.conv0 = SpectralNorm(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=True)
        )
        self.conv1 = SpectralNorm(
            nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding, bias=True)
        )

        self.skip_proj = False
        if in_channel != out_channel or upsample or downsample:
            self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
            self.skip_proj = True

        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.bn = bn
        if bn:
            self.HyperBN = ConditionalBatchNorm2d(in_channel, z_dim)
            self.HyperBN_1 = ConditionalBatchNorm2d(out_channel, z_dim)

    def forward(self, input, condition=None):
        out = input

        if self.bn:
            out = self.HyperBN(out, condition)
        out = self.activation(out)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2)
        out = self.conv0(out)
        if self.bn:
            out = self.HyperBN_1(out, condition)
        out = self.activation(out)
        out = self.conv1(out)

        if self.downsample:
            out = F.avg_pool2d(out, 2)

        if self.skip_proj:
            skip = input
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2)
            skip = self.conv_sc(skip)
            if self.downsample:
                skip = F.avg_pool2d(skip, 2)
        else:
            skip = input
        return out + skip
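

def _demo_gblock():
    # Illustrative sketch, not part of the original port: with upsample=True
    # the block doubles the spatial resolution, and the condition must have
    # z_dim entries (148 by default).
    block = GBlock(32, 16, z_dim=148)
    x = torch.randn(2, 32, 8, 8)
    condition = torch.randn(2, 148)
    assert block(x, condition).shape == (2, 16, 16, 16)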


class Generator128(nn.Module):
    def __init__(self, code_dim=120, n_class=1000, chn=96, debug=False):
        super().__init__()

        self.linear = nn.Linear(n_class, 128, bias=False)

        if debug:
            chn = 8

        self.first_view = 16 * chn

        self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn))

        # Each GBlock is conditioned on a 20-dim noise chunk concatenated with
        # the 128-dim class embedding: code_dim + 28 == 20 + 128 == 148.
        z_dim = code_dim + 28

        self.GBlock = nn.ModuleList([
            GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
        ])

        self.sa_id = 4
        self.num_split = len(self.GBlock) + 1
        self.attention = SelfAttention(2 * chn)
        self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4)
        self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))

    def forward(self, input, class_id):
        # Split the noise vector: one chunk feeds the first linear layer, the
        # rest condition the GBlocks (BigGAN's hierarchical latent space).
        codes = torch.chunk(input, self.num_split, 1)
        class_emb = self.linear(class_id)  # 128

        out = self.G_linear(codes[0])
        # TFHub weights are NHWC; reshape, then permute to NCHW.
        out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
        for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
            if i == self.sa_id:
                out = self.attention(out)
            condition = torch.cat([code, class_emb], 1)
            out = GBlock(out, condition)

        out = self.ScaledCrossReplicaBN(out)
        out = F.relu(out)
        out = self.colorize(out)
        return torch.tanh(out)
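

def _demo_generator128():
    # Illustrative sketch, not part of the original port: run the 128x128
    # generator end to end in debug mode (tiny channel width) with truncated
    # noise and a one-hot class vector (class 207 is an arbitrary choice).
    gen = Generator128(debug=True)
    z = torch.from_numpy(truncated_z_sample(2, 120, truncation=0.5, seed=0)).float()
    class_id = torch.zeros(2, 1000)
    class_id[:, 207] = 1.0
    with torch.no_grad():
        imgs = gen(z, class_id)
    assert imgs.shape == (2, 3, 128, 128)
    imgs = denorm(imgs)  # map from tanh range [-1, 1] to [0, 1] for display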


class Generator256(nn.Module):
    def __init__(self, code_dim=140, n_class=1000, chn=96, debug=False):
        super().__init__()

        self.linear = nn.Linear(n_class, 128, bias=False)

        if debug:
            chn = 8

        self.first_view = 16 * chn

        self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn))

        # Seven splits of a 140-dim code give 20-dim chunks; 20 + 128 matches
        # the GBlock default z_dim of 148, so no z_dim is passed here.
        self.GBlock = nn.ModuleList([
            GBlock(16 * chn, 16 * chn, n_class=n_class),
            GBlock(16 * chn, 8 * chn, n_class=n_class),
            GBlock(8 * chn, 8 * chn, n_class=n_class),
            GBlock(8 * chn, 4 * chn, n_class=n_class),
            GBlock(4 * chn, 2 * chn, n_class=n_class),
            GBlock(2 * chn, 1 * chn, n_class=n_class),
        ])

        self.sa_id = 5
        self.num_split = len(self.GBlock) + 1
        self.attention = SelfAttention(2 * chn)
        self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4)
        self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))

    def forward(self, input, class_id):
        codes = torch.chunk(input, self.num_split, 1)
        class_emb = self.linear(class_id)  # 128

        out = self.G_linear(codes[0])
        out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
        for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
            if i == self.sa_id:
                out = self.attention(out)
            condition = torch.cat([code, class_emb], 1)
            out = GBlock(out, condition)

        out = self.ScaledCrossReplicaBN(out)
        out = F.relu(out)
        out = self.colorize(out)
        return torch.tanh(out)


class Generator512(nn.Module):
    def __init__(self, code_dim=128, n_class=1000, chn=96, debug=False):
        super().__init__()

        self.linear = nn.Linear(n_class, 128, bias=False)

        if debug:
            chn = 8

        self.first_view = 16 * chn

        self.G_linear = SpectralNorm(nn.Linear(16, 4 * 4 * 16 * chn))

        # Eight splits of a 128-dim code give 16-dim chunks, so the condition
        # is 16 + 128 == 144 dimensional.
        z_dim = code_dim + 16

        self.GBlock = nn.ModuleList([
            GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(8 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
            GBlock(1 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
        ])

        self.sa_id = 4
        self.num_split = len(self.GBlock) + 1
        self.attention = SelfAttention(4 * chn)
        self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn)
        self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))

    def forward(self, input, class_id):
        codes = torch.chunk(input, self.num_split, 1)
        class_emb = self.linear(class_id)  # 128

        out = self.G_linear(codes[0])
        out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
        for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
            if i == self.sa_id:
                out = self.attention(out)
            condition = torch.cat([code, class_emb], 1)
            out = GBlock(out, condition)

        out = self.ScaledCrossReplicaBN(out)
        out = F.relu(out)
        out = self.colorize(out)
        return torch.tanh(out)
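

def _demo_generator512():
    # Illustrative sketch, not part of the original port: the 512x512 variant
    # chunks a 128-dim code into eight 16-dim pieces (condition dim 144).
    gen = Generator512(debug=True)
    z = torch.from_numpy(truncated_z_sample(1, 128, truncation=0.5, seed=0)).float()
    class_id = torch.zeros(1, 1000)
    class_id[:, 207] = 1.0
    with torch.no_grad():
        imgs = gen(z, class_id)
    assert imgs.shape == (1, 3, 512, 512)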


class Discriminator(nn.Module):
    def __init__(self, n_class=1000, chn=96, debug=False):
        super().__init__()

        def conv(in_channel, out_channel, downsample=True):
            return GBlock(in_channel, out_channel, bn=False, upsample=False, downsample=downsample)

        if debug:
            chn = 8
        self.debug = debug

        self.pre_conv = nn.Sequential(
            SpectralNorm(nn.Conv2d(3, 1 * chn, 3, padding=1)),
            nn.ReLU(),
            SpectralNorm(nn.Conv2d(1 * chn, 1 * chn, 3, padding=1)),
            nn.AvgPool2d(2),
        )
        self.pre_skip = SpectralNorm(nn.Conv2d(3, 1 * chn, 1))

        self.conv = nn.Sequential(
            conv(1 * chn, 1 * chn, downsample=True),
            conv(1 * chn, 2 * chn, downsample=True),
            SelfAttention(2 * chn),
            conv(2 * chn, 2 * chn, downsample=True),
            conv(2 * chn, 4 * chn, downsample=True),
            conv(4 * chn, 8 * chn, downsample=True),
            conv(8 * chn, 8 * chn, downsample=True),
            conv(8 * chn, 16 * chn, downsample=True),
            conv(16 * chn, 16 * chn, downsample=False),
        )

        self.linear = SpectralNorm(nn.Linear(16 * chn, 1))

        # Projection discriminator: class information enters through a dot
        # product between the class embedding and the pooled features.
        self.embed = nn.Embedding(n_class, 16 * chn)
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.embed = SpectralNorm(self.embed)

    def forward(self, input, class_id):
        out = self.pre_conv(input)
        out += self.pre_skip(F.avg_pool2d(input, 2))
        out = self.conv(out)
        out = F.relu(out)
        out = out.view(out.size(0), out.size(1), -1)
        out = out.sum(2)  # global sum pooling over spatial positions
        out_linear = self.linear(out).squeeze(1)
        embed = self.embed(class_id)

        prod = (out * embed).sum(1)

        return out_linear + prod
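

def _demo_discriminator():
    # Illustrative sketch, not part of the original port: the block stack
    # halves the resolution eight times in total, so feed 256x256 images.
    # Classes are given as integer ids (the embedding layer does the lookup).
    disc = Discriminator(debug=True)
    imgs = torch.randn(2, 3, 256, 256)
    class_id = torch.tensor([207, 8])
    logits = disc(imgs, class_id)
    assert logits.shape == (2,)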