121 lines
5.0 KiB
Python
121 lines
5.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def position(H, W, is_cuda=True):
    """Build a normalized 2-D coordinate grid.

    Args:
        H: Number of rows in the grid.
        W: Number of columns in the grid.
        is_cuda: When true, the coordinate vectors are moved with ``.cuda()``
            before the grid is assembled; otherwise they stay on the device
            where ``torch.linspace`` created them.

    Returns:
        Tensor of shape ``(1, 2, H, W)``. Channel 0 holds the column
        coordinate and channel 1 the row coordinate; both run linearly
        from -1.0 to 1.0.
    """
    cols = torch.linspace(-1.0, 1.0, W)
    rows = torch.linspace(-1.0, 1.0, H)
    if is_cuda:
        cols = cols.cuda()
        rows = rows.cuda()

    # Tile each 1-D coordinate vector across the other axis.
    grid_w = cols.unsqueeze(0).repeat(H, 1)
    grid_h = rows.unsqueeze(1).repeat(1, W)

    # Stack as channels, then add a leading batch axis of size 1.
    return torch.stack((grid_w, grid_h), dim=0).unsqueeze(0)
|
|
|
|
|
|
def stride(x, stride):
    """Subsample the two trailing spatial axes of a feature map.

    Args:
        x: Tensor indexed as ``(batch, channels, height, width)``.
        stride: Positive integer step. Every ``stride``-th element along
            axes 2 and 3 is kept, starting at index 0. (Inside this function
            the parameter name shadows the function itself.)

    Returns:
        ``x[:, :, ::stride, ::stride]``; each subsampled axis of length ``n``
        ends up with ``ceil(n / stride)`` elements.
    """
    return x[:, :, ::stride, ::stride]
|
|
|
|
|
|
def init_rate_half(tensor):
    """Fill ``tensor`` in place with 0.5 and return it.

    Args:
        tensor: Tensor or parameter to initialise, or ``None`` (no-op).

    Returns:
        The same object that was passed in (``None`` stays ``None``).
        Returning the tensor makes the ``param = init_rate_half(param)``
        calling pattern safe instead of replacing the parameter with ``None``.
    """
    if tensor is not None:
        # In-place write that autograd must not record.
        with torch.no_grad():
            tensor.fill_(0.5)
    return tensor
|
|
|
|
|
|
def init_rate_0(tensor):
    """Fill ``tensor`` in place with zeros and return it.

    Args:
        tensor: Tensor or parameter to initialise, or ``None`` (no-op).

    Returns:
        The same object that was passed in (``None`` stays ``None``).
        Call sites of the form ``module.bias = init_rate_0(module.bias)``
        rely on getting the parameter back; returning ``None`` there would
        silently remove the bias from the module.
    """
    if tensor is not None:
        # In-place write that autograd must not record.
        with torch.no_grad():
            tensor.fill_(0.)
    return tensor
|
|
|
|
|
|
class ACmix(nn.Module):
    """Layer that mixes a windowed self-attention branch with a conv branch.

    Three 1x1 convolutions project the input to ``q``, ``k`` and ``v``; both
    branches reuse those projections:

    * attention branch: every output position attends over a
      ``kernel_att x kernel_att`` window of ``k``/``v`` (reflection padded),
      with a learned encoding of relative position added to the keys;
    * convolution branch: ``q``, ``k`` and ``v`` are mixed by a 1x1 conv
      (``fc``) and passed through a grouped ``kernel_conv x kernel_conv``
      conv whose weights start as one-hot "shift" kernels.

    The result is ``rate1 * attention + rate2 * convolution`` where both
    rates are learnable scalars initialised to 0.5.

    Input is ``(N, in_planes, H, W)``; output is
    ``(N, out_planes, ceil(H / stride), ceil(W / stride))``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels; must be divisible by ``head``.
        kernel_att: Side of the attention window. The padding arithmetic
            assumes ``dilation * (kernel_att - 1)`` is even (e.g. odd
            ``kernel_att``).
        head: Number of attention heads.
        kernel_conv: Side of the convolution-branch kernel (assumed odd so
            that ``kernel_conv // 2`` padding preserves the spatial size).
        stride: Spatial stride applied by both branches.
        dilation: Dilation of the attention window.

    Raises:
        ValueError: If ``out_planes`` is not divisible by ``head``.

    NOTE(review): ``dep_conv`` keeps its (zero-initialised) bias, so
    ``dep_conv.bias`` is part of the state dict. Checkpoints produced by code
    that dropped this bias will presumably need ``strict=False`` to load.
    """

    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
        super().__init__()
        if out_planes % head != 0:
            # forward() reshapes the channel axis into (head, head_dim).
            raise ValueError(
                f"out_planes ({out_planes}) must be divisible by head ({head})"
            )

        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = stride
        self.dilation = dilation
        # Learnable mixing weights of the two branches; set in reset_parameters().
        self.rate1 = nn.Parameter(torch.empty(1))
        self.rate2 = nn.Parameter(torch.empty(1))
        self.head_dim = self.out_planes // self.head

        # Shared q / k / v projections and the position-encoding projection.
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)

        # Attention branch: pad by half the dilated window, then unfold it.
        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = nn.ReflectionPad2d(self.padding_att)
        # The dilation must be given to Unfold as well, otherwise the padded
        # map yields more windows than there are output positions.
        self.unfold = nn.Unfold(kernel_size=self.kernel_att, dilation=self.dilation,
                                padding=0, stride=self.stride)
        self.softmax = nn.Softmax(dim=1)

        # Convolution branch.
        taps = self.kernel_conv * self.kernel_conv
        self.fc = nn.Conv2d(3 * self.head, taps, kernel_size=1, bias=False)
        # "Same" padding for any odd kernel_conv, not only 3.
        self.dep_conv = nn.Conv2d(taps * self.head_dim, out_planes,
                                  kernel_size=self.kernel_conv, bias=True, groups=self.head_dim,
                                  padding=self.kernel_conv // 2, stride=stride)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the branch rates and the convolution-branch kernel.

        ``rate1`` / ``rate2`` become 0.5. Input channel ``i`` of every
        ``dep_conv`` filter gets a one-hot kernel with its single 1 at
        ``(i // kernel_conv, i % kernel_conv)``; the bias is zeroed.
        """
        nn.init.constant_(self.rate1, 0.5)
        nn.init.constant_(self.rate2, 0.5)

        k = self.kernel_conv
        # Row i of the identity, viewed as (k, k), is exactly that one-hot kernel.
        shift_kernel = torch.eye(k * k).view(k * k, k, k)
        with torch.no_grad():
            # Broadcasts over the out_planes axis; keeps device and dtype.
            self.dep_conv.weight.copy_(shift_kernel)
            if self.dep_conv.bias is not None:
                self.dep_conv.bias.zero_()

    @staticmethod
    def _position_grid(h, w, device, dtype):
        """Return a ``(1, 2, h, w)`` grid of x / y coordinates in [-1, 1]."""
        # Built here rather than via the module-level ``position`` helper,
        # whose ``is_cuda`` flag cannot express a target device or dtype.
        xs = torch.linspace(-1.0, 1.0, w).unsqueeze(0).expand(h, w)
        ys = torch.linspace(-1.0, 1.0, h).unsqueeze(1).expand(h, w)
        grid = torch.stack((xs, ys), dim=0).unsqueeze(0)
        return grid.to(device=device, dtype=dtype)

    def forward(self, x):
        """Apply both branches to ``x`` and return their weighted sum.

        Args:
            x: Tensor of shape ``(N, in_planes, H, W)``.

        Returns:
            Tensor of shape
            ``(N, out_planes, ceil(H / stride), ceil(W / stride))``.
        """
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        scaling = float(self.head_dim) ** -0.5
        b, _, h, w = q.shape
        s = self.stride
        # Strided slicing, Unfold and dep_conv all yield ceil(n / stride)
        # positions, so the reshapes below must use the same rounding.
        h_out, w_out = (h - 1) // s + 1, (w - 1) // s + 1
        window = self.kernel_att * self.kernel_att

        pe = self.conv_p(self._position_grid(h, w, x.device, self.conv_p.weight.dtype))

        # ---- attention branch -------------------------------------------
        q_att = q.view(b * self.head, self.head_dim, h, w) * scaling
        k_att = k.view(b * self.head, self.head_dim, h, w)
        v_att = v.view(b * self.head, self.head_dim, h, w)

        # Queries (and their position encoding) live on the strided output
        # grid; a step of 1 leaves them unchanged.
        q_att = q_att[:, :, ::s, ::s]
        q_pe = pe[:, :, ::s, ::s]

        # b*head, head_dim, k_att^2, h_out, w_out
        unfold_k = self.unfold(self.pad_att(k_att)).view(
            b * self.head, self.head_dim, window, h_out, w_out)
        # 1, head_dim, k_att^2, h_out, w_out
        unfold_rpe = self.unfold(self.pad_att(pe)).view(
            1, self.head_dim, window, h_out, w_out)

        # Dot product over head_dim between each query and the keys of its
        # window, keys being offset by the relative position encoding.
        # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out)
        #   -> (b*head, k_att^2, h_out, w_out)
        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)
        att = self.softmax(att)

        unfold_v = self.unfold(self.pad_att(v_att)).view(
            b * self.head, self.head_dim, window, h_out, w_out)
        out_att = (att.unsqueeze(1) * unfold_v).sum(2).view(b, self.out_planes, h_out, w_out)

        # ---- convolution branch -----------------------------------------
        # Treat the 3*head projections as channels so ``fc`` mixes them into
        # kernel_conv^2 maps per head_dim channel.
        f_all = self.fc(torch.cat(
            [q.view(b, self.head, self.head_dim, h * w),
             k.view(b, self.head, self.head_dim, h * w),
             v.view(b, self.head, self.head_dim, h * w)], 1))
        f_conv = f_all.permute(0, 2, 1, 3).reshape(b, -1, h, w)

        out_conv = self.dep_conv(f_conv)

        return self.rate1 * out_att + self.rate2 * out_conv
|
|
|
|
|
|
# Input: N C H W, output: N C H W
|
|
if __name__ == '__main__':
    # Smoke test: push one random (1, 64, 64, 64) tensor through the layer
    # and print the input and output sizes.
    block = ACmix(in_planes=64, out_planes=64)
    # Named ``x`` / ``y`` so the builtin ``input`` is not shadowed.
    x = torch.rand(1, 64, 64, 64)
    y = block(x)
    print(x.size(), y.size())
|