import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

# relative positional embedding

def to(x):
    return {'device': x.device, 'dtype': x.dtype}

def pair(x):
    return (x, x) if not isinstance(x, tuple) else x

def expand_dim(t, dim, k):
    t = t.unsqueeze(dim=dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

def rel_to_abs(x):
    b, l, m = x.shape
    r = (m + 1) // 2

    col_pad = torch.zeros((b, l, 1), **to(x))
    x = torch.cat((x, col_pad), dim=2)
    flat_x = rearrange(x, 'b l c -> b (l c)')
    flat_pad = torch.zeros((b, m - l), **to(x))
    flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)
    final_x = flat_x_padded.reshape(b, l + 1, m)
    final_x = final_x[:, :l, -r:]
    return final_x

def relative_logits_1d(q, rel_k):
    b, h, w, _ = q.shape
    r = (rel_k.shape[0] + 1) // 2

    logits = einsum('b x y d, r d -> b x y r', q, rel_k)
    logits = rearrange(logits, 'b x y r -> (b x) y r')
    logits = rel_to_abs(logits)

    logits = logits.reshape(b, h, w, r)
    logits = expand_dim(logits, dim=2, k=r)
    return logits

class RelPosEmb(nn.Module):
    def __init__(
        self,
        block_size,
        rel_size,
        dim_head
    ):
        super().__init__()
        height = width = rel_size
        scale = dim_head ** -0.5

        self.block_size = block_size
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        block = self.block_size

        q = rearrange(q, 'b (x y) c -> b x y c', x=block)
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, 'b x i y j -> b (x y) (i j)')

        q = rearrange(q, 'b x y d -> b y x d')
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')
        return rel_logits_w + rel_logits_h

# classes

class HaloAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        block_size,
        halo_size,
        dim_head=64,
        heads=8
    ):
        super().__init__()
        assert halo_size > 0, 'halo size must be greater than 0'

        self.dim = dim
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.block_size = block_size
        self.halo_size = halo_size

        inner_dim = dim_head * heads

        self.rel_pos_emb = RelPosEmb(
            block_size=block_size,
            rel_size=block_size + (halo_size * 2),
            dim_head=dim_head
        )

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.device
        assert h % block == 0 and w % block == 0, 'fmap dimensions must be divisible by the block size'
        assert c == self.dim, f'channels for input ({c}) does not equal the expected dimension ({self.dim})'

        # get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key values

        q_inp = rearrange(x, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=block, p2=block)

        kv_inp = F.unfold(x, kernel_size=block + halo * 2, stride=block, padding=halo)
        kv_inp = rearrange(kv_inp, 'b (c j) i -> (b i) j c', c=c)

        # derive queries, keys, values

        q = self.to_q(q_inp)
        k, v = self.to_kv(kv_inp).chunk(2, dim=-1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v))

        # scale

        q *= self.scale

        # attention

        sim = einsum('b i d, b j d -> b i j', q, k)

        # add relative positional bias

        sim += self.rel_pos_emb(q)

        # mask out padding (in the paper, they claim to not need masks, but what about padding?)
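        # F.unfold zero-pads the feature map by `halo` on each side, so blocks at the
        # border pick up key/value positions that lie outside the image; build a boolean
        # mask of those positions so they can be excluded from the softmax below.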
        mask = torch.ones(1, 1, h, w, device=device)
        mask = F.unfold(mask, kernel_size=block + (halo * 2), stride=block, padding=halo)
        mask = repeat(mask, '() j i -> (b i h) () j', b=b, h=heads)
        mask = mask.bool()  # True at real positions, False where the halo falls into zero padding

        max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(~mask, max_neg_value)  # suppress attention to the padded positions

        # attention

        attn = sim.softmax(dim=-1)

        # aggregate

        out = einsum('b i j, b j d -> b i d', attn, v)

        # merge and combine heads

        out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)
        out = self.to_out(out)

        # merge blocks back to original feature map

        out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b=b, h=(h // block), w=(w // block), p1=block, p2=block)
        return out

# input: N C H W, output: N C H W
if __name__ == '__main__':
    block = HaloAttention(dim=512, block_size=2, halo_size=1).cuda()
    input = torch.rand(1, 512, 64, 64).cuda()
    output = block(input)
    print(input.size(), output.size())
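
    # Illustrative only (a small sketch, not part of the module): rel_to_abs maps
    # logits indexed by relative offset (length 2*l - 1) to logits indexed by
    # absolute position (length l). The `toy` tensor below is a made-up example.
    toy = torch.randn(1, 4, 7)        # b=1, l=4 positions, 2*4 - 1 = 7 relative offsets
    print(rel_to_abs(toy).shape)      # torch.Size([1, 4, 4])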