import torch
from torch import nn
from operator import itemgetter
# from axial_attention.reversible import ReversibleSequence
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states


# following the example for saving and setting RNG state here:
# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        # snapshot the CPU (and, if in use, CUDA) RNG state so the wrapped
        # module can be re-run deterministically in the backward pass
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)


# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args={}, g_args={}):
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=1)

    def backward_pass(self, y, dy, f_args={}, g_args={}):
        # reconstruct the block's inputs from its outputs and accumulate gradients,
        # so intermediate activations never need to be stored during the forward pass
        y1, y2 = torch.chunk(y, 2, dim=1)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx


class IrreversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g

    def forward(self, x, f_args, g_args):
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1 = x1 + self.f(x2, **f_args)
        y2 = x2 + self.g(y1, **g_args)
        return torch.cat([y1, y2], dim=1)


class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        ctx.kwargs = kwargs
        for block in blocks:
            x = block(x, **kwargs)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None


class ReversibleSequence(nn.Module):
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])

    def forward(self, x, arg_route=(True, True), **kwargs):
        f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
        block_kwargs = {'f_args': f_args, 'g_args': g_args}
        x = torch.cat((x, x), dim=1)
        x = _ReversibleFunction.apply(x, self.blocks, block_kwargs)
        return torch.stack(x.chunk(2, dim=1)).mean(dim=0)
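
# Example (a minimal sketch, not part of the original code): ReversibleSequence takes a
# list of (f, g) module pairs, duplicates the input along dim=1, and reconstructs block
# inputs from block outputs in the backward pass instead of storing activations. The
# helper name below is hypothetical, only illustrates the expected shapes, and is never
# called anywhere in this file.
def _reversible_sequence_demo():
    dim = 32
    blocks = nn.ModuleList([
        nn.ModuleList([
            nn.Conv2d(dim, dim, 3, padding=1),   # f: applied to one half of the channels
            nn.Conv2d(dim, dim, 3, padding=1),   # g: applied to the other half
        ])
        for _ in range(2)
    ])
    seq = ReversibleSequence(blocks)
    x = torch.randn(2, dim, 8, 8)
    out = seq(x)                                 # same shape as x: (2, 32, 8, 8)
    assert out.shape == x.shape
    return out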
# helper functions

def exists(val):
    return val is not None


def map_el_ind(arr, ind):
    return list(map(itemgetter(ind), arr))


def sort_and_return_indices(arr):
    indices = [ind for ind in range(len(arr))]
    arr = zip(arr, indices)
    arr = sorted(arr)
    return map_el_ind(arr, 0), map_el_ind(arr, 1)


# calculates the permutation to bring the input tensor to something attend-able
# also calculates the inverse permutation to bring the tensor back to its original shape
def calculate_permutations(num_dimensions, emb_dim):
    total_dimensions = num_dimensions + 2
    emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
    axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]

    permutations = []

    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, emb_dim]
        dims_rest = set(range(0, total_dimensions)) - set(last_two_dims)
        permutation = [*dims_rest, *last_two_dims]
        permutations.append(permutation)

    return permutations


# helper classes

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (std + self.eps) * self.g + self.b


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class Sequential(nn.Module):
    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, x):
        for f, g in self.blocks:
            x = x + f(x)
            x = x + g(x)
        return x


class PermuteToFrom(nn.Module):
    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        _, inv_permutation = sort_and_return_indices(permutation)
        self.permutation = permutation
        self.inv_permutation = inv_permutation

    def forward(self, x, **kwargs):
        axial = x.permute(*self.permutation).contiguous()

        shape = axial.shape
        *_, t, d = shape

        # merge all but the axial dimension
        axial = axial.reshape(-1, t, d)

        # attention
        axial = self.fn(axial, **kwargs)

        # restore to original shape and permutation
        axial = axial.reshape(*shape)
        axial = axial.permute(*self.inv_permutation).contiguous()
        return axial


# axial pos emb

class AxialPositionalEmbedding(nn.Module):
    def __init__(self, dim, shape, emb_dim_index=1):
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        self.num_axials = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            shape = [1] * total_dimensions
            shape[emb_dim_index] = dim
            shape[axial_dim_index] = axial_dim
            parameter = nn.Parameter(torch.randn(*shape))
            setattr(self, f'param_{i}', parameter)

    def forward(self, x):
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x
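
# Example (a minimal sketch, not part of the original code): for a (batch, dim, height,
# width) tensor with the embedding at index 1, calculate_permutations(2, 1) returns one
# permutation per spatial axis, each moving that axis and the embedding axis into the
# last two positions; PermuteToFrom then folds every remaining axis into the batch so
# plain SelfAttention can run along a single axis. The helper name below is hypothetical
# and is never called anywhere in this file.
def _permutation_demo():
    x = torch.randn(2, 16, 8, 10)                     # (batch, dim, height, width)
    for permutation in calculate_permutations(2, 1):  # one permutation per spatial axis
        moved = x.permute(*permutation)
        # last dim is the embedding dim, second-to-last is the axis being attended over
        assert moved.shape[-1] == 16
    return calculate_permutations(2, 1)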
# attention

class SelfAttention(nn.Module):
    def __init__(self, dim, heads, dim_heads=None):
        super().__init__()
        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
        dim_hidden = self.dim_heads * heads

        self.heads = heads
        self.to_q = nn.Linear(dim, dim_hidden, bias=False)
        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False)
        self.to_out = nn.Linear(dim_hidden, dim)

    def forward(self, x, kv=None):
        kv = x if kv is None else kv
        q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))

        b, t, d, h, e = *q.shape, self.heads, self.dim_heads

        merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
        q, k, v = map(merge_heads, (q, k, v))

        dots = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)
        dots = dots.softmax(dim=-1)
        out = torch.einsum('bij,bje->bie', dots, v)

        out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
        out = self.to_out(out)
        return out


# axial attention class

class AxialAttention(nn.Module):
    def __init__(self, dim, num_dimensions=2, heads=8, dim_heads=None, dim_index=-1, sum_axial_out=True):
        assert (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        super().__init__()
        self.dim = dim
        self.total_dimensions = num_dimensions + 2
        self.dim_index = dim_index if dim_index > 0 else (dim_index + self.total_dimensions)

        attentions = []
        for permutation in calculate_permutations(num_dimensions, dim_index):
            attentions.append(PermuteToFrom(permutation, SelfAttention(dim, heads, dim_heads)))

        self.axial_attentions = nn.ModuleList(attentions)
        self.sum_axial_out = sum_axial_out

    def forward(self, x):
        assert len(x.shape) == self.total_dimensions, 'input tensor does not have the correct number of dimensions'
        assert x.shape[self.dim_index] == self.dim, 'input tensor does not have the correct input dimension'

        if self.sum_axial_out:
            return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))

        out = x
        for axial_attn in self.axial_attentions:
            out = axial_attn(out)
        return out


# axial image transformer

class AxialImageTransformer(nn.Module):
    def __init__(self, dim, depth, heads=8, dim_heads=None, dim_index=1, reversible=True, axial_pos_emb_shape=None):
        super().__init__()
        permutations = calculate_permutations(2, dim_index)

        get_ff = lambda: nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, dim * 4, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(dim * 4, dim, 3, padding=1)
        )

        self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index) if exists(axial_pos_emb_shape) else nn.Identity()

        layers = nn.ModuleList([])
        for _ in range(depth):
            attn_functions = nn.ModuleList([
                PermuteToFrom(permutation, PreNorm(dim, SelfAttention(dim, heads, dim_heads)))
                for permutation in permutations
            ])
            conv_functions = nn.ModuleList([get_ff(), get_ff()])
            layers.append(attn_functions)
            layers.append(conv_functions)

        execute_type = ReversibleSequence if reversible else Sequential
        self.layers = execute_type(layers)

    def forward(self, x):
        x = self.pos_emb(x)
        return self.layers(x)


# input: (N, C, H, W) -> output: (N, C, H, W)
if __name__ == '__main__':
    block = AxialImageTransformer(
        dim=64,
        depth=12,
        reversible=True
    ).cuda()
    input = torch.rand(1, 64, 64, 64).cuda()
    output = block(input)
    print(input.size(), output.size())
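
# Usage note (a sketch, assuming only the classes defined above): AxialAttention can
# also be used on its own. With the default dim_index=-1 it expects the embedding
# dimension last and returns a tensor of the same shape, e.g.
#
#     attn = AxialAttention(dim=64, num_dimensions=2, heads=8, dim_index=-1)
#     x = torch.randn(1, 32, 32, 64)    # (batch, height, width, dim)
#     out = attn(x)                     # same shape: (1, 32, 32, 64)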