from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, LongTensor


class TopkRouting(nn.Module):
    """
    Differentiable top-k routing with scaling.

    Scores every query window against every key window with a scaled dot
    product, keeps the top-k key windows per query window, and returns a
    softmax over the kept logits together with their indices.

    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, number of key windows kept per query window
        qk_scale: float or None, multiplier applied to the query before the
            dot product; falls back to qk_dim ** -0.5 when falsy
        param_routing: bool, whether to apply a learnable nn.Linear to query
            and key before scoring (nn.Identity otherwise)
        diff_routing: bool, whether gradients flow through routing; when
            False, query and key are detached before scoring
    """

    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            query, key: (n, p^2, c) tensors of window-level features
        Return:
            r_weight: (n, p^2, topk) softmax over the kept routing logits
            topk_index: (n, p^2, topk) indices of the kept key windows
        """
        if not self.diff_routing:
            # non-differentiable routing: block gradients through the router
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)  # per-window pooling -> (n, p^2, c)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)

        return r_weight, topk_index


class KVGather(nn.Module):
    """
    Gather the key/value tokens of the routed windows for every query window.

    Args:
        mul_weight: 'none' returns the gathered k/v as is; 'soft' multiplies
            them by the routing weights; 'hard' is not implemented and raises
            NotImplementedError in forward.
    """

    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """
        Args:
            r_idx: (n, p^2, topk) tensor
            r_weight: (n, p^2, topk) tensor
            kv: (n, p^2, w^2, c_kq+c_v)

        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
        # torch.gather does not broadcast, so kv is expanded (no copy) along a
        # new query-window axis and the index is expanded over (w^2, c_kv).
        topk_kv = torch.gather(
            kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),  # (n, p^2, p^2, w^2, c_kv) without mem cpy
            dim=2,
            index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv),  # (n, p^2, k, w^2, c_kv)
        )

        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv  # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        # 'none': return the gathered k/v unchanged

        return topk_kv


class QKVLinear(nn.Module):
    """
    Single linear projection producing q (qk_dim channels) and a fused kv
    tensor (qk_dim + dim channels), split along the last dimension.
    """

    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
        return q, kv


class BiLevelRoutingAttention(nn.Module):
    """
    Bi-level routing attention (legacy implementation).

    Args:
        dim: channel dimension of the input; also used as the value dimension
        n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
        num_heads: number of attention heads; must divide both qk_dim and dim
        qk_dim: channel dimension of queries/keys, defaults to dim
        qk_scale: scale applied to queries in both routing and attention,
            defaults to qk_dim ** -0.5
        kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window.
            Similar to n_win, the actual number is kv_per_win*kv_per_win.
        kv_downsample_ratio: pooling kernel size for kv_downsample_mode='maxpool'/'avgpool'
        kv_downsample_kernel: stored on the module but not used by any implemented mode
        kv_downsample_mode: one of 'identity', 'ada_avgpool', 'ada_maxpool', 'maxpool',
            'avgpool'; 'fracpool' and 'conv' raise NotImplementedError
        topk: topk for window filtering
        param_attention: 'qkvo': linear for q, k, v and o; 'qkv': linear for q, k, v only
            (output projection is identity); anything else raises ValueError
        param_routing: extra linear for routing (requires diff_routing=True)
        diff_routing: whether to set routing differentiable
        soft_routing: whether to multiply the gathered k/v by the soft routing weights
        side_dwconv: kernel size of the depth-wise conv (LePE) applied to values; 0 disables it
        auto_pad: pad inputs whose H/W are not divisible by n_win and crop the output back

    NOTE: diff_routing=True with soft_routing=False selects KVGather(mul_weight='hard'),
    whose forward raises NotImplementedError.
    """

    def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False,
                 side_dwconv=3, auto_pad=True):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, \
            'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5

        ################ side_dwconv (i.e. LCE in ShuntedTransformer) ################
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2,
                              groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        # cannot be param_routing=True and diff_routing=False
        assert not (self.param_routing and not self.diff_routing)
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing:  # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing:  # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')

        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        # attribute name kept as-is (misspelled) for compatibility with existing code
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity':  # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            # Fixed: previously referenced the misspelled `self.kv_downsaple_mode`,
            # which raised AttributeError instead of this ValueError.
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not surpported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad = auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        Args:
            x: NCHW tensor (rearranged to NHWC internally)
            ret_attn_mask: if True, also return routing weights/indices and
                the token-level attention weights

        Return:
            NCHW tensor when ret_attn_mask is False; otherwise the tuple
            (out, r_weight, r_idx, attn_weight) where `out` is NHWC.
        """
        x = rearrange(x, "n c h w -> n h w c")
        # NOTE: use padding for semantic segmentation
        ###################################################
        if self.auto_pad:
            N, H_in, W_in, C = x.size()

            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0,  # dim=-1
                          pad_l, pad_r,  # dim=-2
                          pad_t, pad_b))  # dim=-3
            _, H, W, _ = x.size()  # padded size
        else:
            N, H, W, C = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0
        ###################################################

        # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        ################# qkv projection ###################
        # q: (n, p^2, w, w, c_qk)
        # kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separate kv if there were memory leak issue caused by gather
        q, kv = self.qkv(x)

        # pixel-wise qkv
        # q_pix: (n, p^2, w^2, c_qk)
        # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)
        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3])

        ################## side_dwconv (lepe) ##################
        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)',
                                   j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        ############ gather q dependent k/v #################
        r_weight, r_idx = self.router(q_win, k_win)  # both are (n, p^2, topk) tensors

        # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # k_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
        # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)

        ######### do attention as normal ####################
        # keys are laid out pre-transposed: (n*p^2, m, c_qk//m, topk*h_kv*w_kv)
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads)
        # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads)
        # to BMLC tensor (n*p^2, m, w^2, c_qk//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads)

        # param-free multihead attention
        # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = (q_pix * self.scale) @ k_pix_sel
        attn_weight = self.attn_act(attn_weight)
        # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = attn_weight @ v_pix_sel
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H // self.n_win, w=W // self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # NOTE: use padding for semantic segmentation
        # crop padded region (pad_r / pad_b / H_in / W_in only exist when auto_pad is set)
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            # NOTE(review): this branch returns `out` in NHWC while the default
            # branch returns NCHW; kept as-is for backward compatibility.
            return out, r_weight, r_idx, attn_weight
        return rearrange(out, "n h w c -> n c h w")


class Attention(nn.Module):
    """
    Vanilla multi-head self-attention over all H*W tokens of an NCHW map.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Args:
            x: NCHW tensor
        Return:
            NCHW tensor
        """
        _, _, H, W = x.size()
        x = rearrange(x, 'n c h w -> n (h w) c')

        #######################################
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        #######################################

        x = rearrange(x, 'n (h w) c -> n c h w', h=H, w=W)
        return x


class AttentionLePE(nn.Module):
    """
    Vanilla multi-head self-attention plus a depth-wise conv (LePE) branch
    applied to the input tokens and added before the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 side_dwconv=5):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2,
                              groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

    def forward(self, x):
        """
        Args:
            x: NCHW tensor
        Return:
            NCHW tensor
        """
        _, _, H, W = x.size()
        x = rearrange(x, 'n c h w -> n (h w) c')

        #######################################
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # LePE is computed from the (pre-projection) input tokens
        lepe = self.lepe(rearrange(x, 'n (h w) c -> n c h w', h=H, w=W))
        lepe = rearrange(lepe, 'n c h w -> n (h w) c')

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = x + lepe

        x = self.proj(x)
        x = self.proj_drop(x)
        #######################################

        x = rearrange(x, 'n (h w) c -> n c h w', h=H, w=W)
        return x


def _grid2seq(x: Tensor, region_size: Tuple[int], num_heads: int):
    """
    Split a BCHW map into per-head, per-region token sequences.

    Args:
        x: BCHW tensor; H and W must be divisible by region_size
        region_size: (rh, rw) region size
        num_heads: number of attention heads (must divide C)
    Return:
        out: rearranged x, has a shape of (bs, nhead, nregion, reg_size, head_dim)
        region_h, region_w: number of regions per col/row
    """
    B, C, H, W = x.size()
    region_h, region_w = H // region_size[0], W // region_size[1]
    x = x.view(B, num_heads, C // num_heads, region_h, region_size[0], region_w, region_size[1])
    x = torch.einsum('bmdhpwq->bmhwpqd', x).flatten(2, 3).flatten(-3, -2)  # (bs, nhead, nregion, reg_size, head_dim)
    return x, region_h, region_w


def _seq2grid(x: Tensor, region_h: int, region_w: int, region_size: Tuple[int]):
    """
    Inverse of _grid2seq: merge per-head region sequences back into a BCHW map.

    Args:
        x: (bs, nhead, nregion, reg_size, head_dim) with reg_size = rh*rw
        region_h, region_w: number of regions per col/row
        region_size: (rh, rw) region size
    Return:
        x: (bs, C, H, W) with C = nhead*head_dim, H = region_h*rh, W = region_w*rw
    """
    bs, nhead, _, _, head_dim = x.size()
    x = x.view(bs, nhead, region_h, region_w, region_size[0], region_size[1], head_dim)
    x = torch.einsum('bmhwpqd->bmdhpwq', x).reshape(bs, nhead * head_dim,
                                                    region_h * region_size[0], region_w * region_size[1])
    return x


def regional_routing_attention_torch(
        query: Tensor, key: Tensor, value: Tensor, scale: float,
        region_graph: LongTensor, region_size: Tuple[int],
        kv_region_size: Optional[Tuple[int]] = None,
        auto_pad=True) -> Tuple[Tensor, Tensor]:
    """
    Token-to-token attention restricted to the regions selected by region_graph.

    Args:
        query, key, value: (B, C, H, W) tensor
        scale: the scale/temperature for dot product attention
        region_graph: (B, nhead, h_q*w_q, topk) tensor, topk <= h_k*w_k
        region_size: region/window size for queries, (rh, rw)
        kv_region_size: optional, if None, kv_region_size=region_size
        auto_pad: required to be true if the input sizes are not divisible by the region_size
    Return:
        output: (B, C, H, W) tensor
        attn: (bs, nhead, q_nregion, reg_size, topk*kv_region_numel) attention matrix
    """
    kv_region_size = kv_region_size or region_size
    bs, nhead, q_nregion, topk = region_graph.size()

    # Auto pad to deal with any input size
    q_pad_b, q_pad_r, kv_pad_b, kv_pad_r = 0, 0, 0, 0
    if auto_pad:
        _, _, Hq, Wq = query.size()
        q_pad_b = (region_size[0] - Hq % region_size[0]) % region_size[0]
        q_pad_r = (region_size[1] - Wq % region_size[1]) % region_size[1]
        if q_pad_b > 0 or q_pad_r > 0:
            query = F.pad(query, (0, q_pad_r, 0, q_pad_b))  # zero padding

        _, _, Hk, Wk = key.size()
        kv_pad_b = (kv_region_size[0] - Hk % kv_region_size[0]) % kv_region_size[0]
        kv_pad_r = (kv_region_size[1] - Wk % kv_region_size[1]) % kv_region_size[1]
        if kv_pad_r > 0 or kv_pad_b > 0:
            key = F.pad(key, (0, kv_pad_r, 0, kv_pad_b))  # zero padding
            value = F.pad(value, (0, kv_pad_r, 0, kv_pad_b))  # zero padding

    # to sequence format, i.e. (bs, nhead, nregion, reg_size, head_dim)
    query, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=nhead)
    key, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=nhead)
    value, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=nhead)

    # gather key and values.
    # TODO: is separate gathering slower than fused one (our old version) ?
    # torch.gather does not support broadcasting, hence we do it manually
    # kv_region_numel: number of tokens per kv region (rh*rw)
    bs, nhead, kv_nregion, kv_region_numel, head_dim = key.size()
    broadcasted_region_graph = region_graph.view(bs, nhead, q_nregion, topk, 1, 1). \
        expand(-1, -1, -1, -1, kv_region_numel, head_dim)
    key_g = torch.gather(key.view(bs, nhead, 1, kv_nregion, kv_region_numel, head_dim).
                         expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
                         index=broadcasted_region_graph)  # (bs, nhead, q_nregion, topk, kv_region_numel, head_dim)
    value_g = torch.gather(value.view(bs, nhead, 1, kv_nregion, kv_region_numel, head_dim).
                           expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
                           index=broadcasted_region_graph)  # (bs, nhead, q_nregion, topk, kv_region_numel, head_dim)

    # token-to-token attention
    # (bs, nhead, q_nregion, reg_size, head_dim) @ (bs, nhead, q_nregion, head_dim, topk*kv_region_numel)
    # -> (bs, nhead, q_nregion, reg_size, topk*kv_region_numel)
    # TODO: mask padding region
    attn = (query * scale) @ key_g.flatten(-3, -2).transpose(-1, -2)
    attn = torch.softmax(attn, dim=-1)
    # (bs, nhead, q_nregion, reg_size, topk*kv_region_numel) @ (bs, nhead, q_nregion, topk*kv_region_numel, head_dim)
    # -> (bs, nhead, q_nregion, reg_size, head_dim)
    output = attn @ value_g.flatten(-3, -2)

    # to BCHW format
    output = _seq2grid(output, region_h=q_region_h, region_w=q_region_w, region_size=region_size)

    # remove paddings if needed
    if auto_pad and (q_pad_b > 0 or q_pad_r > 0):
        output = output[:, :, :Hq, :Wq]

    return output, attn


class BiLevelRoutingAttention_nchw(nn.Module):
    """Bi-Level Routing Attention that takes nchw input

    Compared to legacy version, this implementation:
    * removes unused args and components
    * uses nchw input format to avoid frequent permutation

    When the size of inputs is not divisible by the region size, there is also a numerical difference
    than legacy implementation, due to:
    * different way to pad the input feature map (padding after linear projection)
    * different pooling behavior (count_include_pad=False)

    Current implementation is more reasonable, hence we do not keep backward numerical compatibility

    NOTE(review): `auto_pad` is accepted but not used; the attention function is always
    called with its own default (auto_pad=True).
    """

    def __init__(self, dim, num_heads=8, n_win=7, qk_scale=None, topk=4, side_dwconv=3, auto_pad=False,
                 attn_backend='torch'):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim must be divisible by num_heads!'
        self.head_dim = self.dim // self.num_heads
        self.scale = qk_scale or self.dim ** -0.5  # NOTE: to be consistent with old models.

        ################ side_dwconv (i.e. LCE in Shunted Transformer) ################
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2,
                              groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

        ################ regional routing setting #################
        self.topk = topk
        self.n_win = n_win  # number of windows per row/col

        ##########################################

        self.qkv_linear = nn.Conv2d(self.dim, 3 * self.dim, kernel_size=1)
        self.output_linear = nn.Conv2d(self.dim, self.dim, kernel_size=1)

        if attn_backend == 'torch':
            self.attn_fn = regional_routing_attention_torch
        else:
            raise ValueError('CUDA implementation is not available yet. Please stay tuned.')

    def forward(self, x: Tensor, ret_attn_mask=False):
        """
        Args:
            x: NCHW tensor, better to be channel_last
               (https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html)
            ret_attn_mask: if True, also return the token-level attention matrix
        Return:
            NCHW tensor, or (output, attn_mat) when ret_attn_mask is True
        """
        N, C, H, W = x.size()
        region_size = (H // self.n_win, W // self.n_win)

        # STEP 1: linear projection
        # call the module (not .forward) so registered hooks run
        qkv = self.qkv_linear(x)  # ncHW
        q, k, v = qkv.chunk(3, dim=1)  # ncHW

        # STEP 2: region-to-region routing
        # NOTE: ceil_mode=True, count_include_pad=False = auto padding
        # NOTE: gradients backward through token-to-token attention. See Appendix A for the intuition.
        q_r = F.avg_pool2d(q.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)
        k_r = F.avg_pool2d(k.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)  # nchw
        q_r: Tensor = q_r.permute(0, 2, 3, 1).flatten(1, 2)  # n(hw)c
        k_r: Tensor = k_r.flatten(2, 3)  # nc(hw)
        a_r = q_r @ k_r  # n(hw)(hw), adj matrix of regional graph
        _, idx_r = torch.topk(a_r, k=self.topk, dim=-1)  # n(hw)k long tensor
        # every head shares the same region graph
        idx_r: LongTensor = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)

        # STEP 3: token to token attention (non-parametric function)
        output, attn_mat = self.attn_fn(query=q, key=k, value=v, scale=self.scale,
                                        region_graph=idx_r, region_size=region_size)

        output = output + self.lepe(v)  # ncHW
        output = self.output_linear(output)  # ncHW

        if ret_attn_mask:
            return output, attn_mat

        return output


# Input: N C H W, output: N C H W
if __name__ == '__main__':
    # fall back to CPU so the demo also runs on machines without a GPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = BiLevelRoutingAttention_nchw(64).to(device)
    inp = torch.rand(1, 64, 64, 64, device=device)
    output = block(inp)
    print(inp.size(), output.size())