# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Type from .common import LayerNorm2d, MLPBlock # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa # 这个代码定义了 SAM 的图像编码器 ImageEncoderViT。它包含以下主要部分: # 1. patch_embed: 这是 ViT 的 patch embedding 层,用于将输入图像划分为 patch,并获得 patch 的 embedding。 # 2. pos_embed: 这是 ViT的绝对位置 embedding,用于为每个patch提供位置信息。 # 3. blocks: 这是 ViT 的 transformer encoder 块的列表,每个块包含多头自注意力层和前馈神经网络。 # 4. neck: 这是图像编码器的“颈部”,包含几个卷积层和 LayerNorm 层,用于从 transformer encoder 块的输出中提取特征。 # 5. forward(): 这是图像编码器的前向传播过程。首先通过 patch_embed 层获得 patch embedding, 然后加上 pos_embed。 # 接着,patch embedding通过transformer encoder块。最后, neck 层从 transformer encoder 块的输出中提取特征。 # 所以,这个 ImageEncoderViT 类定义了 SAM 的图像编码器,它基于 ViT,包含 patch embedding、位置 embedding、 # transformer encoder块以及 neck, 可以从输入图像中提取特征。 # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa class ImageEncoderViT(nn.Module): def __init__( self, img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = (), ) -> None: super().__init__() self.img_size = img_size self.patch_embed = PatchEmbed( kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), in_chans=in_chans, embed_dim=embed_dim, ) self.pos_embed: Optional[nn.Parameter] = None if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) ) self.blocks = nn.ModuleList() for i in range(depth): block = Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer, act_layer=act_layer, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, window_size=window_size if i not in global_attn_indexes else 0, input_size=(img_size // patch_size, img_size // patch_size), ) self.blocks.append(block) self.neck = nn.Sequential( nn.Conv2d( embed_dim, out_chans, kernel_size=1, bias=False, ), LayerNorm2d(out_chans), nn.Conv2d( out_chans, out_chans, kernel_size=3, padding=1, bias=False, ), LayerNorm2d(out_chans), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) if self.pos_embed is not None: x = x + self.pos_embed for blk in self.blocks: x = blk(x) x = self.neck(x.permute(0, 3, 1, 2)) return x # 这个 Block 类实现了 transformer block, 可以选择使用全局注意力或局部窗口注意力,同时包含残差连接。它包含: # __init__方法: # 1. 输入参数: # - dim: 输入通道数 # - num_heads: 注意力头数 # - mlp_ratio: mlp 隐藏层与输入 embedding 维度的比例 # - qkv_bias: 是否为 query、key、value 添加偏置 # - norm_layer: 归一化层 # - act_layer: 激活层 # - use_rel_pos: 是否使用相对位置 embedding # - rel_pos_zero_init: 是否将相对位置 embedding 初始化为 0 # - window_size: 窗口注意力的窗口大小,如果为 0 则使用全局注意力 # - input_size: 计算相对位置 embedding 大小所需的输入分辨率 # 2. 实例化第 1 次和第 2 次归一化层 norm1 和 norm2。 # 3. 实例化 Attention 层和 MLPBlock 层。Attention 层的输入大小根据是否使用窗口注意力进行了调整。 # 4. 记录窗口注意力的窗口大小 window_size。 # forward方法: # 1. 提取 shortcut 并对 x 进行第 1 次归一化。 # 2. 如果使用窗口注意力, 则调用 window_partition 对 x 进行窗口划分。 # 3. 将 x 输入 Attention 层。 # 4. 如果使用窗口注意力,则调用 window_unpartition 对 x 进行窗口反划分。 # 5. x = shortcut + x,实现第 1 次残差连接。 # 6. x = x + mlp(norm2(x)),实现第 2 次残差连接和 MLPBlock。 # 7. 返回最终的 x。 # 所以,这个 Block 类实现了带有可选的窗口注意力和双残差连接的transformer block。 # 窗口注意力可以更好地建模局部结构,双残差连接可以提高梯度流动,都是transformer结构的重要改进。 # 这个 Block 类实现了 transformer 的关键组成部分,同时提供了窗口注意力和残差连接等重要变体,可以显著提高其表现力和泛化能力。 class Block(nn.Module): """Transformer blocks with support of window attention and residual propagation blocks""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, input_size=input_size if window_size == 0 else (window_size, window_size), ) self.norm2 = norm_layer(dim) self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: x = window_unpartition(x, self.window_size, pad_hw, (H, W)) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x # 这个Attention类实现了多头注意力机制,可以加入相对位置 embedding。它包含: # __init__方法: # 1. 输入参数: # - dim: 输入通道数 # - num_heads: 注意力头数 # - qkv_bias: 是否为查询、键、值添加偏置 # - use_rel_pos: 是否使用相对位置 embedding # - rel_pos_zero_init: 是否将相对位置 embedding 初始化为0 # - input_size: 计算相对位置 embedding 大小所需的输入分辨率 # 2. 计算每个注意力头的维度 head_dim。 # 3. 实例化 self.qkv和 输出投影 self.proj。 # 4. 如果使用相对位置 embedding, 则初始化 rel_pos_h 和 rel_pos_w。 # forward方法: # 1. 从输入 x 中提取批次大小 B、高度 H、宽度 W 和通道数 C。 # 2. 计算 qkv,形状为 (3, B, nHead, H * W, C), 包含 query、key 和 value。 # 3. 提取 q、 k 和 v, 形状为 (B * nHead, H * W, C)。 # 4. 计算注意力图 attn,形状为 (B * nHead, H * W, H * W)。 # 5. 如果使用相对位置 embedding, 则调用 add_decomposed_rel_pos 函数将其加入 attn。 # 6. 对 attn 进行 softmax 归一化。 # 7. 计算输出 x , (attn @ v), 形状为 (B, nHead, H, W, C), 然后合并注意力头, 形状为(B, H, W, C)。 # 8. 对 x 进行投影, 返回最终的输出。 # 所以,这个 Attention 类实现了带有相对位置 embedding 的多头注意力机制。 # 它可以高效地建模图像和视频等二维结构数据,是 transformer 在这些领域得到广泛应用的关键。 # 这个 Attention 类提供了相对位置 embedding 和多头注意力机制的实现, # 是理解 transformer 在图像和视频建模中的重要组成部分。 class Attention(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." # initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) attn = attn.softmax(dim=-1) x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = self.proj(x) return x # 这个 window_partition 函数的作用是将输入张量划分为非重叠的窗口。它包含: # 1. 输入参数: # - x: 输入的张量,形状为 [B, H, W, C] # - window_size: 窗口大小 # 2. 首先计算输入需要 padding 的高度和宽度,将x进行padding。 # 3. 然后将 x 的形状变化为 [B, Hp//window_size, window_size, Wp//window_size, window_size, C], # 表示将图像划分为 Hp//window_size * Wp//window_size 个 window_size * window_size 的 patch。 # 4. 最后,通过 permute 和 view 操作,得到 windows 的形状为 [B * num_windows, window_size, window_size, C], # 表示将所有 patch 打平, num_windows 是 patch 的总数 # 5. 返回windows和原来的高度和宽度(包含padding)Hp和Wp。 # 所以,这个 window_partition 函数的作用是,将输入的图像划分为 window_size * window_size 的 patch, # 并将所有的 patch 打平, 输出可以输入到 transformer encoder 中的 token 序列。 # 这个函数实现了将二维图像转化为一维 token 序列的过程,是 transformer 用于处理图像的一个关键步骤。 # 通过这个函数,图像可以被 transformer encoder 所处理,就像处理文本序列一样。 def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows, (Hp, Wp) # 这个 window_unpartition 函数的作用是将 window_partition 函数的输出进行反划分, 恢复成原始的图像形状。它包含: # 1. 输入参数: # - windows: window_partition的输出,形状为 [B * num_windows, window_size, window_size, C] # - window_size: 窗口大小 # - pad_hw: padding后的高度和宽度 (Hp, Wp) # - hw: padding前的原始高度和宽度 (H, W) # 2. 首先根据窗口大小和 padding 后的 hw 计算原始的 batch_size B。 # 3. 然后将 windows 的形状变回 [B, Hp//window_size, Wp//window_size, window_size, window_size, C], 表示每个patch的位置。 # 4. 接着通过permute和view操作,得到x的形状为 [B, Hp, Wp, C], 恢复成图像的形状。 # 5. 最后,如果进行了padding,则截取x到原始的高度H和宽度W。 # 6. 返回恢复后的图像x。 # 所以,这个 window_unpartition 函数的作用是将通过 window_partition 函数得到的 patch 序列恢复成原始的图像。 # 它实现了从一维 patch token 序列到二维图像的反过程。 # 这个函数与 window_partition 函数相反,使得 transformer 能够最终从 patch token 序列恢复成图像,完成对图像的建模。 # 总的来说,这个 window_unpartition 函数实现了从 patch token 序列恢复成原始图像的过程,与 window_partition 函数相对应, # 是使得 transformer 可以处理图像的另一个关键步骤 def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x # 这个 get_rel_pos 函数的作用是根据 query 和 key 的相对位置获取相对位置 embedding。它包含: # 1. 输入参数: # - q_size: query 的大小 # - k_size: key 的大小 # - rel_pos: 相对位置 embedding, 形状为[L, C] # 2. 首先计算最大的相对距离 max_rel_dist, 它等于 query 和 key 大小的 2 倍减 1。 # 3. 如果相对位置 embedding 的长度小于 max_rel_dist, 则通过线性插值将其调整到 max_rel_dist 的长度。 # 4. 如果 q_size 和 k_size 不同, 则将 q_size 和 k_size 的坐标按比例缩放,使它们之间的相对距离保持不变。 # 5. 根据调整后的 q_size 和 k_size 坐标计算相对坐标 relative_coords。 # 6. 根据 relative_coords 从 rel_pos_resized 中提取相对位置 embedding。 # 7. 返回提取出的相对位置 embedding。 # 所以,这个 get_rel_pos 函数的主要作用是,当 query 和 key 的大小不同时,根据它们的相对位置关系提取相应的相对位置 embedding。 # 它实现了相对位置 embedding 的可变长度和可缩放性。 # 这个函数使得相对位置 embedding 可以用于 query 和 key 大小不同的 attention 中,是相对位置表示的一个关键步骤。 # 总的来说,这个 get_rel_pos 函数实现了根据 query 和 key 的相对位置关系提取相应相对位置 embedding 的过程。 # 它提供了相对位置 embedding 的可变长度和可缩放性,使其可以支持不同的 query 和 key 大小,从而应用到更加灵活的 attention 机制中。 def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] # 这个 add_decomposed_rel_pos 函数的作用是根据 query q 和 key k 的空间尺寸, 添加分解的相对位置 embedding 到注意力图 attn 中。它包含: # 1. 输入参数: # - attn: 注意力图,形状为 [B, q_h * q_w, k_h * k_w] # - q: 查询 q,形状为 [B, q_h * q_w, C] # - rel_pos_h: 高度轴的相对位置 embedding, 形状为[Lh, C] # - rel_pos_w: 宽度轴的相对位置 embedding, 形状为[Lw, C] # - q_size: 查询 q的空间尺寸 (q_h, q_w) # - k_size: 键 k的空间尺寸 (k_h, k_w) # 2. 从 q_size 和 k_size 中提取高度 q_h、宽度 q_w 以及高度 k_h、宽度 k_w。 # 3. 调用 get_rel_pos 函数获取高度轴 Rh 和宽度轴 Rw 的相对位置 embedding。 # 4. 重塑 q 为 [B, q_h, q_w, C]。 # 5. 计算高度轴 rel_h 和宽度轴 rel_w 的相对位置图, 形状为 [B, q_h, q_w, k_h] 和 [B, q_h, q_w, k_w]。 # 6. 将 attn 的形状变为 [B, q_h, q_w, k_h, k_w], 并加上 rel_h 和 rel_w。 # 7. 将 attn 的形状变回 [B, q_h * q_w, k_h * k_w]。 # 8. 返回加了相对位置 embedding 的 attn。 def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int], ) -> torch.Tensor: q_h, q_w = q_size k_h, k_w = k_size Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = ( attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] ).view(B, q_h * q_w, k_h * k_w) return attn # 这个 PatchEmbed 类定义了 ViT 的 patch embedding 层。它包含: # 1. __init__: 初始化,设置卷积层的 kernel size、stride、padding以 及输入通道数和 embedding 维度。 # 2. proj: 这是一个卷积层,用于将输入图像划分为 patch, 并获得每个 patch 的 embedding。 # 3. forward: 前向传播过程。首先通过 proj 卷积层获得 patch embedding ,然后将维度从 [B, C, H, W] 转置成 [B, H, W, C]。 class PatchEmbed(nn.Module): """ Image to Patch Embedding. """ def __init__( self, kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), padding: Tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768, ) -> None: super().__init__() self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C H W -> B H W C x = x.permute(0, 2, 3, 1) return x