# Listing metadata: 580 lines, 25 KiB, Python
from typing import Tuple, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
from torch import Tensor, LongTensor
|
|
|
|
|
|
class TopkRouting(nn.Module):
    """
    Top-k region routing with scaled dot-product affinities.

    A region-to-region affinity matrix is computed from window-level queries
    and keys, the ``topk`` highest-scoring key regions are kept for every
    query region, and the kept logits are normalised with a softmax.

    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, number of key regions kept per query region
        qk_scale: float or None, temperature (multiplied onto the query) of
            the softmax activation; any falsy value falls back to
            ``qk_dim ** -0.5``
        param_routing: bool, whether to pass query and key through a shared
            learnable linear embedding before computing the affinities
        diff_routing: bool, whether to keep the routing branch differentiable;
            when False, query and key are detached first
    """

    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        # The same embedding is applied to both query and key.
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation, applied over the kept top-k logits only
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            query, key: (n, p^2, c) tensors
        Return:
            r_weight: (n, p^2, topk) softmax-normalised routing weights
            topk_index: (n, p^2, topk) indices of the selected key regions
        """
        if not self.diff_routing:
            # Cut the graph so no gradient reaches the caller through routing.
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)  # (n, p^2, c)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)

        return r_weight, topk_index
|
|
|
|
|
|
class KVGather(nn.Module):
    """
    Collect, for every query region, the key/value tokens of its routed regions.

    ``mul_weight`` selects what is done with the gathered tokens:
        'none': returned untouched
        'soft': multiplied by the routing weights
        'hard': reserved; ``forward`` raises NotImplementedError
    """

    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """
        r_idx: (n, p^2, topk) tensor
        r_weight: (n, p^2, topk) tensor
        kv: (n, p^2, w^2, c_kq+c_v)

        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        batch, n_region, n_token, n_chan = kv.size()
        n_sel = r_idx.size(-1)

        # Every query region sees all regions: (n, p^2, p^2, w^2, c_kv), a
        # broadcast view, so nothing is copied here.
        candidates = kv.view(batch, 1, n_region, n_token, n_chan).expand(-1, n_region, -1, -1, -1)
        # The region index is repeated over the token and channel axes because
        # torch.gather needs an index of the full output shape.
        picker = r_idx.view(batch, n_region, n_sel, 1, 1).expand(-1, -1, -1, n_token, n_chan)
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
        gathered = torch.gather(candidates, dim=2, index=picker)  # (n, p^2, k, w^2, c_kv)

        if self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        if self.mul_weight == 'soft':
            gathered = r_weight.view(batch, n_region, n_sel, 1, 1) * gathered  # (n, p^2, k, w^2, c_kv)
        # 'none': gathered tokens are returned as they are

        return gathered
|
|
|
|
|
|
class QKVLinear(nn.Module):
    """
    Fused linear projection producing the query and the stacked key/value.

    A single ``nn.Linear`` emits ``qk_dim`` query channels followed by
    ``qk_dim`` key channels and ``dim`` value channels.
    """

    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        """Return ``(q, kv)`` split along the last dimension of the projection."""
        projected = self.qkv(x)
        kv_width = self.qk_dim + self.dim
        # Key and value are deliberately left concatenated; callers split them
        # themselves when they need k and v separately.
        q, kv = torch.split(projected, [self.qk_dim, kv_width], dim=-1)
        return q, kv
|
|
|
|
|
|
class BiLevelRoutingAttention(nn.Module):
    """
    Bi-level routing attention: window-level top-k routing followed by
    token-level attention restricted to the routed windows.

    Args:
        dim: number of input/output channels; also the value dimension
        n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
        num_heads: number of attention heads; must divide both dim and qk_dim
        qk_dim: query/key channels, defaults to dim when falsy
        qk_scale: softmax temperature, defaults to qk_dim ** -0.5 when falsy
        kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window.
            Similar to n_win, the actual number is kv_per_win*kv_per_win.
        kv_downsample_ratio: pooling kernel size for the 'maxpool' / 'avgpool' modes
        kv_downsample_kernel: only stored (as ``kv_downsample_kenel``, spelling kept for
            compatibility); no implemented mode reads it
        kv_downsample_mode: one of 'identity', 'ada_avgpool', 'ada_maxpool', 'maxpool',
            'avgpool'; 'fracpool' and 'conv' raise NotImplementedError
        topk: topk for window filtering
        param_attention: 'qkvo': linear for q,k,v and o; 'qkv': linear for q,k,v only.
            Any other value raises ValueError.
        param_routing: extra linear for routing
        diff_routing: whether to set routing differentiable
        soft_routing: whether to multiply soft routing weights
        side_dwconv: kernel size of the depth-wise conv applied to v (LePE); <= 0 disables it
        auto_pad: zero-pad the right/bottom so H and W become multiples of n_win

    NOTE: diff_routing=True with soft_routing=False selects KVGather(mul_weight='hard'),
    whose forward raises NotImplementedError.
    """

    def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False,
                 side_dwconv=3,
                 auto_pad=True):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5

        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2,
                              groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing)  # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing:  # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing:  # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')

        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity':  # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # assert self.kv_downsample_ratio is not None
            # assert self.kv_downsample_kenel is not None
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            # Fixed: this used to read the misspelled attribute `kv_downsaple_mode`,
            # which raised AttributeError instead of the intended ValueError.
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not surpported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad = auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        Args:
            x: NCHW tensor (permuted to NHWC internally)
            ret_attn_mask: when True, also return the routing weights, the routing
                indices and the token-level attention matrix

        Return:
            NCHW tensor when ret_attn_mask is False.
            Otherwise the tuple (out, r_weight, r_idx, attn_weight); note that in
            this case ``out`` is returned in NHWC layout, i.e. it is NOT permuted
            back to NCHW.
        """
        x = rearrange(x, "n c h w -> n h w c")
        # NOTE: use padding for semantic segmentation
        ###################################################
        if self.auto_pad:
            _, H_in, W_in, _ = x.size()

            # Pad only on the right/bottom so the crop at the end is a plain slice.
            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0,  # dim=-1
                          pad_l, pad_r,  # dim=-2
                          pad_t, pad_b))  # dim=-3
            _, H, W, _ = x.size()  # padded size
        else:
            _, H, W, _ = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0  #
        ###################################################

        # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        #################qkv projection###################
        # q: (n, p^2, w, w, c_qk)
        # kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separate kv if there were memory leak issue caused by gather
        q, kv = self.qkv(x)

        # pixel-wise qkv
        # q_pix: (n, p^2, w^2, c_qk)
        # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)
        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3])

        ##################side_dwconv(lepe)##################
        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win,
                                   i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        ############ gather q dependent k/v #################

        r_weight, r_idx = self.router(q_win, k_win)  # both are (n, p^2, topk) tensors

        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # k_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
        # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)

        ######### do attention as normal ####################
        # Keys are laid out already transposed, (n*p^2, m, c_kq//m, topk*h_kv*w_kv),
        # so the matmul below needs no extra transpose.
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)',
                              m=self.num_heads)
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c',
                              m=self.num_heads)  # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c',
                          m=self.num_heads)  # to BMLC tensor (n*p^2, m, w^2, c_qk//m)

        # param-free multihead attention
        # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = (q_pix * self.scale) @ k_pix_sel
        attn_weight = self.attn_act(attn_weight)
        # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = attn_weight @ v_pix_sel
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H // self.n_win, w=W // self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # NOTE: use padding for semantic segmentation
        # crop padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return rearrange(out, "n h w c -> n c h w")
|
|
|
|
|
|
class Attention(nn.Module):
    """
    Plain multi-head self-attention over all spatial positions of a feature map.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5

        # One fused projection for q, k and v, followed by the output projection.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        args:
            x: NCHW tensor
        return:
            NCHW tensor
        """
        _, _, height, width = x.size()
        tokens = rearrange(x, 'n c h w -> n (h w) c')

        batch, length, channels = tokens.shape
        head_dim = channels // self.num_heads
        # (3, B, heads, N, head_dim)
        packed = self.qkv(tokens).reshape(batch, length, 3, self.num_heads, head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = packed[0], packed[1], packed[2]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        probs = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel axis.
        merged = (probs @ v).transpose(1, 2).reshape(batch, length, channels)
        merged = self.proj_drop(self.proj(merged))

        return rearrange(merged, 'n (h w) c -> n c h w', h=height, w=width)
|
|
|
|
|
|
class AttentionLePE(nn.Module):
    """
    Multi-head self-attention with an additive depth-wise convolution branch (LePE).

    The convolution branch is computed from the layer input and added to the
    attention output before the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):
        super().__init__()
        self.num_heads = num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        if side_dwconv > 0:
            # depth-wise conv, "same" padding for odd kernel sizes
            self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1,
                                  padding=side_dwconv // 2, groups=dim)
        else:
            # branch disabled: contributes zeros of the input's shape
            self.lepe = torch.zeros_like

    def forward(self, x):
        """
        args:
            x: NCHW tensor
        return:
            NCHW tensor
        """
        _, _, height, width = x.size()
        tokens = rearrange(x, 'n c h w -> n (h w) c')

        batch, length, channels = tokens.shape
        head_dim = channels // self.num_heads
        # (3, B, heads, N, head_dim)
        packed = self.qkv(tokens).reshape(batch, length, 3, self.num_heads, head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = packed[0], packed[1], packed[2]

        # The positional branch works on the 2-D layout of the input tokens.
        pos = self.lepe(rearrange(tokens, 'n (h w) c -> n c h w', h=height, w=width))
        pos = rearrange(pos, 'n c h w -> n (h w) c')

        scores = (q @ k.transpose(-2, -1)) * self.scale
        probs = self.attn_drop(scores.softmax(dim=-1))

        merged = (probs @ v).transpose(1, 2).reshape(batch, length, channels)
        merged = merged + pos

        merged = self.proj_drop(self.proj(merged))

        return rearrange(merged, 'n (h w) c -> n c h w', h=height, w=width)
|
|
|
|
|
|
def _grid2seq(x: Tensor, region_size: Tuple[int], num_heads: int):
    """
    Cut a feature map into regions and lay every region out as a token sequence.

    Args:
        x: BCHW tensor
        region_size: (region height, region width) in pixels
        num_heads: number of attention heads
    Return:
        out: rearranged x, has a shape of (bs, nhead, nregion, reg_size, head_dim)
        region_h, region_w: number of regions per col/row
    """
    batch, channels, height, width = x.size()
    reg_h, reg_w = region_size[0], region_size[1]
    region_h, region_w = height // reg_h, width // reg_w
    head_dim = channels // num_heads

    # (b, m, d, h, p, w, q): channels -> (head, head_dim),
    # rows -> (region row, row offset), cols -> (region col, col offset)
    x = x.view(batch, num_heads, head_dim, region_h, reg_h, region_w, reg_w)
    # -> (b, m, h, w, p, q, d)
    x = x.permute(0, 1, 3, 5, 4, 6, 2)
    # merge (h, w) into the region axis and (p, q) into the token axis
    x = x.reshape(batch, num_heads, region_h * region_w, reg_h * reg_w, head_dim)
    return x, region_h, region_w
|
|
|
|
|
|
def _seq2grid(x: Tensor, region_h: int, region_w: int, region_size: Tuple[int]):
    """
    Fold per-region token sequences back into a feature map.

    Args:
        x: (bs, nhead, nregion, reg_size^2, head_dim)
        region_h, region_w: number of regions per col/row
        region_size: (region height, region width) in pixels
    Return:
        x: (bs, C, H, W)
    """
    bs, nhead, _, _, head_dim = x.size()
    reg_h, reg_w = region_size[0], region_size[1]

    # (b, m, h, w, p, q, d): region axis -> (h, w), token axis -> (p, q)
    x = x.view(bs, nhead, region_h, region_w, reg_h, reg_w, head_dim)
    # -> (b, m, d, h, p, w, q)
    x = x.permute(0, 1, 6, 2, 4, 3, 5)
    # merge (m, d) into channels, (h, p) into rows and (w, q) into columns
    return x.reshape(bs, nhead * head_dim, region_h * reg_h, region_w * reg_w)
|
|
|
|
|
|
def regional_routing_attention_torch(
        query: Tensor, key: Tensor, value: Tensor, scale: float,
        region_graph: LongTensor, region_size: Tuple[int],
        kv_region_size: Optional[Tuple[int]] = None,
        auto_pad=True) -> Tuple[Tensor, Tensor]:
    """
    Token-to-token attention restricted to the key/value regions that
    ``region_graph`` assigns to every query region.

    Args:
        query, key, value: (B, C, H, W) tensor
        scale: the scale/temperature for dot product attention
        region_graph: (B, nhead, h_q*w_q, topk) tensor, topk <= h_k*w_k
        region_size: region/window size for queries, (rh, rw)
        kv_region_size: optional, if None, kv_region_size=region_size
        auto_pad: required to be true if the input sizes are not divisible by the region_size
    Return:
        output: (B, C, H, W) tensor
        attn: (bs, nhead, q_nregion, reg_size, topk*kv_region_size) attention matrix
    """
    kv_region_size = kv_region_size or region_size
    bs, nhead, q_nregion, topk = region_graph.size()

    # Auto pad to deal with any input size
    q_pad_b, q_pad_r, kv_pad_b, kv_pad_r = 0, 0, 0, 0
    if auto_pad:
        _, _, Hq, Wq = query.size()
        q_pad_b = (region_size[0] - Hq % region_size[0]) % region_size[0]
        q_pad_r = (region_size[1] - Wq % region_size[1]) % region_size[1]
        if q_pad_b > 0 or q_pad_r > 0:
            query = F.pad(query, (0, q_pad_r, 0, q_pad_b))  # zero padding

        _, _, Hk, Wk = key.size()
        kv_pad_b = (kv_region_size[0] - Hk % kv_region_size[0]) % kv_region_size[0]
        kv_pad_r = (kv_region_size[1] - Wk % kv_region_size[1]) % kv_region_size[1]
        if kv_pad_r > 0 or kv_pad_b > 0:
            key = F.pad(key, (0, kv_pad_r, 0, kv_pad_b))  # zero padding
            value = F.pad(value, (0, kv_pad_r, 0, kv_pad_b))  # zero padding

    # to sequence format, i.e. (bs, nhead, nregion, reg_size, head_dim)
    query, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=nhead)
    key, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=nhead)
    value, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=nhead)

    # gather key and values.
    # TODO: is separate gathering slower than fused one (our old version) ?
    # torch.gather does not support broadcasting, hence we do it manually
    # NOTE: the token count gets its own name; the original rebound the
    # `kv_region_size` tuple to this int.
    bs, nhead, kv_nregion, kv_reg_tokens, head_dim = key.size()
    broadcasted_region_graph = region_graph.view(bs, nhead, q_nregion, topk, 1, 1). \
        expand(-1, -1, -1, -1, kv_reg_tokens, head_dim)
    key_g = torch.gather(key.view(bs, nhead, 1, kv_nregion, kv_reg_tokens, head_dim).
                         expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
                         index=broadcasted_region_graph)  # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)
    value_g = torch.gather(value.view(bs, nhead, 1, kv_nregion, kv_reg_tokens, head_dim).
                           expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
                           index=broadcasted_region_graph)  # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)

    # token-to-token attention
    # (bs, nhead, q_nregion, reg_size, head_dim) @ (bs, nhead, q_nregion, head_dim, topk*kv_region_size)
    # -> (bs, nhead, q_nregion, reg_size, topk*kv_region_size)
    # TODO: mask padding region (zero-padded keys currently take part in the softmax)
    attn = (query * scale) @ key_g.flatten(-3, -2).transpose(-1, -2)
    attn = torch.softmax(attn, dim=-1)
    # (bs, nhead, q_nregion, reg_size, topk*kv_region_size) @ (bs, nhead, q_nregion, topk*kv_region_size, head_dim)
    # -> (bs, nhead, q_nregion, reg_size, head_dim)
    output = attn @ value_g.flatten(-3, -2)

    # to BCHW format
    output = _seq2grid(output, region_h=q_region_h, region_w=q_region_w, region_size=region_size)

    # remove paddings if needed
    if auto_pad and (q_pad_b > 0 or q_pad_r > 0):
        output = output[:, :, :Hq, :Wq]

    return output, attn
|
|
|
|
|
|
class BiLevelRoutingAttention_nchw(nn.Module):
    """Bi-Level Routing Attention that takes nchw input

    Compared to legacy version, this implementation:
    * removes unused args and components
    * uses nchw input format to avoid frequent permutation

    When the size of inputs is not divisible by the region size, there is also a numerical difference
    than legacy implementation, due to:
    * different way to pad the input feature map (padding after linear projection)
    * different pooling behavior (count_include_pad=False)

    Current implementation is more reasonable, hence we do not keep backward numerical compatibility

    Args:
        dim: number of input/output channels
        num_heads: number of attention heads, must divide dim
        n_win: number of windows per row/col
        qk_scale: attention temperature; falls back to ``dim ** -0.5`` when falsy
        topk: number of key/value regions routed to every query region
        side_dwconv: kernel size of the depth-wise conv applied to v; <= 0 disables it
        auto_pad: accepted but neither stored nor read; ``forward`` calls the
            attention function without it, so that function's own default applies
        attn_backend: only 'torch' is available, anything else raises ValueError
    """

    def __init__(self, dim, num_heads=8, n_win=7, qk_scale=None, topk=4, side_dwconv=3, auto_pad=False,
                 attn_backend='torch'):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim must be divisible by num_heads!'
        self.head_dim = self.dim // self.num_heads
        self.scale = qk_scale or self.dim ** -0.5  # NOTE: to be consistent with old models.

        ################side_dwconv (i.e. LCE in Shunted Transformer)###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2,
                              groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

        ################ regional routing setting #################
        self.topk = topk
        self.n_win = n_win  # number of windows per row/col

        ##########################################

        # 1x1 convolutions act as per-pixel linear layers on NCHW input.
        self.qkv_linear = nn.Conv2d(self.dim, 3 * self.dim, kernel_size=1)
        self.output_linear = nn.Conv2d(self.dim, self.dim, kernel_size=1)

        if attn_backend == 'torch':
            self.attn_fn = regional_routing_attention_torch
        else:
            raise ValueError('CUDA implementation is not available yet. Please stay tuned.')

    def forward(self, x: Tensor, ret_attn_mask=False):
        """
        Args:
            x: NCHW tensor, better to be channel_last (https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html)
            ret_attn_mask: when True, also return the token-level attention matrix
        Return:
            NCHW tensor, or the tuple (NCHW tensor, attention matrix) when
            ret_attn_mask is True
        """
        N, C, H, W = x.size()
        region_size = (H // self.n_win, W // self.n_win)

        # STEP 1: linear projection
        # Call the module, not `.forward`, so registered hooks run.
        qkv = self.qkv_linear(x)  # ncHW
        q, k, v = qkv.chunk(3, dim=1)  # ncHW

        # STEP 2: region-to-region routing
        # NOTE: ceil_mode=True, count_include_pad=False = auto padding
        # NOTE: gradients backward through token-to-token attention. See Appendix A for the intuition.
        q_r = F.avg_pool2d(q.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)
        k_r = F.avg_pool2d(k.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)  # nchw
        q_r: Tensor = q_r.permute(0, 2, 3, 1).flatten(1, 2)  # n(hw)c
        k_r: Tensor = k_r.flatten(2, 3)  # nc(hw)
        a_r = q_r @ k_r  # n(hw)(hw), adj matrix of regional graph
        _, idx_r = torch.topk(a_r, k=self.topk, dim=-1)  # n(hw)k long tensor
        # The same routing graph is shared by all heads.
        idx_r: LongTensor = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)

        # STEP 3: token to token attention (non-parametric function)
        output, attn_mat = self.attn_fn(query=q, key=k, value=v, scale=self.scale,
                                        region_graph=idx_r, region_size=region_size
                                        )

        output = output + self.lepe(v)  # ncHW
        output = self.output_linear(output)  # ncHW

        if ret_attn_mask:
            return output, attn_mat

        return output
|
|
|
|
|
|
# Smoke test. Input: N C H W, output: N C H W
if __name__ == '__main__':
    # Fall back to the CPU so the check also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block = BiLevelRoutingAttention_nchw(64).to(device)
    # Named `sample` rather than `input` to avoid shadowing the builtin.
    sample = torch.rand(1, 64, 64, 64).to(device)
    result = block(sample)
    print(sample.size(), result.size())
|