Tan_pytorch_segmentation/pytorch_segmentation/PV_Model/1.py

94 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

class GlobalAttention(nn.Module):
def __init__(self,
dim=256,
num_heads=16,
qkv_bias=False,
window_size=8,
relative_pos_embedding=True
):
super().__init__()
self.num_heads = num_heads # 初始化注意力头的数量
head_dim = dim // self.num_heads # 计算每个注意力头的特征维度
self.scale = head_dim ** -0.5 # 计算缩放因子,用于注意力计算中的点积。
self.ws = window_size # 初始化局部窗口的大小。
self.qkv = Conv(dim, 3 * dim, kernel_size=1, bias=qkv_bias) # 初始化一个卷积层用于生成Query、Key和Value。
self.proj = SeparableConvBN(dim, dim, kernel_size=window_size) # 初始化一个可分离卷积层,用于投影输出。
self.attn_x = nn.AvgPool2d(kernel_size=(window_size, 1), stride=1,
padding=(window_size // 2 - 1, 0)) # 初始化水平方向的平均池化层,用于整合全局信息。
self.attn_y = nn.AvgPool2d(kernel_size=(1, window_size), stride=1,
padding=(0, window_size // 2 - 1)) # 初始化垂直方向的平均池化层,用于整合全局信息。
self.relative_pos_embedding = relative_pos_embedding
# 初始化是否使用相对位置嵌入的标志。
if self.relative_pos_embedding: # 如果使用了相对位置嵌入,会定义一个相对位置偏置表
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
coords_h = torch.arange(self.ws) # 创建一个包含窗口大小ws内所有水平坐标的张量。
coords_w = torch.arange(self.ws) # 创建一个包含窗口大小ws内所有垂直坐标的张量。
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww #使用meshgrid函数创建一个包含所有水平和垂直坐标的张量。
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww将三维坐标张量展平为一维张量。
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
# 2, Wh*Ww, Wh*Ww 计算所有坐标对之间的相对位置,即每个坐标相对于其他所有坐标的差
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.ws - 1 # shift to start from 0将相对坐标的第一个维度增加ws - 1以确保坐标从0开始。
relative_coords[:, :, 1] += self.ws - 1 # 将相对坐标的第二个维度增加ws - 1以确保坐标从0开始。
relative_coords[:, :, 0] *= 2 * self.ws - 1 # 调整相对坐标的第一个维度,使其范围变为[-2*ws+1, 2*ws-1]。
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww将相对坐标的两个维度合并为一个单一的索引用于访问相对位置偏置表。
self.register_buffer("relative_position_index", relative_position_index) # 将相对位置索引注册为一个缓冲区,以便在模型训练过程中重复使用。
trunc_normal_(self.relative_position_bias_table, std=.02)
# 使用trunc_normal_函数初始化相对位置偏置表这是一种常用的初始化技术用于生成服从截断正态分布的参数。
def pad(self, x, ps): # 定义一个函数接受一个特征图x和一个填充大小ps作为参数。
_, _, H, W = x.size() # 获取特征图x的形状并提取高度H和宽度W。
if W % ps != 0: # 如果特征图的宽度W不能被填充大小ps整除则需要进行填充。
x = F.pad(x, (0, ps - W % ps),
mode='reflect') # 使用F.pad函数在特征图的右侧添加填充填充大小为ps - W % ps填充模式为'reflect',这意味着新的像素值将反映原始像素值。
if H % ps != 0: # 如果特征图的高度H不能被填充大小ps整除则需要进行额外的填充。
x = F.pad(x, (0, 0, 0, ps - H % ps),
mode='reflect') # 使用F.pad函数在特征图的下方添加填充填充大小为ps - H % ps填充模式为'reflect'。
return x # 返回填充后的特征图。
def pad_out(self, x): # 定义一个函数接受一个特征图x作为参数。
x = F.pad(x, pad=(0, 1, 0, 1), mode='reflect') # 使用F.pad函数在特征图的右侧和下方添加填充填充大小为1填充模式为'reflect'。
return x
def forward(self, x):
B, C, H, W = x.shape
x = self.pad(x, self.ws) # 填充输入特征图以适应窗口大小。
B, C, Hp, Wp = x.shape # 获取填充后的特征图的形状。
qkv = self.qkv(x) # 生成Query、Key和Value。
q, k, v = rearrange(qkv, 'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d', h=self.num_heads,
d=C // self.num_heads, hh=Hp // self.ws, ww=Wp // self.ws, qkv=3, ws1=self.ws,
ws2=self.ws) # 重新排列Query、Key和Value以适应注意力机制的计算。
dots = (q @ k.transpose(-2, -1)) * self.scale # 计算点积,并应用缩放因子。
# 如果使用了相对位置嵌入,将相对位置偏置加到点积上。
if self.relative_pos_embedding: # 如果启用了相对位置嵌入
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.ws * self.ws, self.ws * self.ws, -1) # Wh*Ww,Wh*Ww,nH获取相对位置偏置表并根据相对位置索引进行调整。
relative_position_bias = relative_position_bias.permute(2, 0,
1).contiguous() # nH, Wh*Ww, Wh*Ww重新排列相对位置偏置以便与点积的形状匹配
dots += relative_position_bias.unsqueeze(0) # 将相对位置偏置加到点积上。
attn = dots.softmax(dim=-1) # 应用softmax函数计算注意力权重。
attn = attn @ v # 注意力权重应用于Value。
attn = rearrange(attn, '(b hh ww) h (ws1 ws2) d -> b (h d) (hh ws1) (ww ws2)', h=self.num_heads,
d=C // self.num_heads, hh=Hp // self.ws, ww=Wp // self.ws, ws1=self.ws, ws2=self.ws)
attn = attn[:, :, :H, :W] # 裁剪注意力权重,使其与原始输入特征图的形状匹配。
out = self.attn_x(F.pad(attn, pad=(0, 0, 0, 1), mode='reflect')) + \
self.attn_y(F.pad(attn, pad=(0, 1, 0, 0), mode='reflect'))
out = self.pad_out(out) # 添加额外的填充,以适应输出特征图的尺寸。
out = self.proj(out)
# print(out.size())
out = out[:, :, :H, :W]
return out