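# External Attention (EA) in PyTorch, after Guo et al., "Beyond Self-Attention:
# External Attention Using Two Linear Layers and Softmax for Visual Tasks"
# (arXiv:2105.02358). Attention is computed against two small learnable
# memories shared across the dataset rather than against the input itself.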
import torch
from torch import nn
from torch.nn import init


class ExternalAttention(nn.Module):
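    """External attention over a sequence of n tokens.

    Self-attention's keys and values are replaced by two external,
    input-independent memory units M_k and M_v of size S, so the cost is
    linear in n. Expects input of shape (bs, n, d_model).
    """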

    def __init__(self, d_model, S=64):
        super().__init__()
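        # M_k and M_v: the external memory units from the paper, implemented
        # as bias-free linear layers whose weights are shared across samples.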
        self.mk = nn.Linear(d_model, S, bias=False)
        self.mv = nn.Linear(S, d_model, bias=False)
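        # Softmax over dim=1 (the token axis) is the first half of the
        # paper's double normalization; the second half is in forward().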
        self.softmax = nn.Softmax(dim=1)
        self.init_weights()

    def init_weights(self):
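        # The Conv2d/BatchNorm2d branches never fire here (the module only
        # contains Linear layers) but are kept from the reference code so
        # the initializer can be reused in larger models.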
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries):
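        # Project queries onto the external key memory, then apply the
        # paper's double normalization before reading from the value memory.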
        attn = self.mk(queries)                             # (bs, n, S)
        attn = self.softmax(attn)                           # softmax over the token axis (dim=1)
        attn = attn / torch.sum(attn, dim=2, keepdim=True)  # L1-normalize over the memory axis S
        out = self.mv(attn)                                 # (bs, n, d_model)
        return out


# Input: (B, N, d_model); output: (B, N, d_model)
if __name__ == '__main__':
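    # Smoke test: output shape should equal input shape.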
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = ExternalAttention(d_model=64, S=8).to(device)
    x = torch.rand(64, 64, 64, device=device)
    out = block(x)
    print(x.size(), out.size())