Tan_pytorch_segmentation/pytorch_segmentation/Plug-and-Play/ExternalAttention.py

import torch
from torch import nn
from torch.nn import init


class ExternalAttention(nn.Module):
    """External Attention (Guo et al. 2021, "Beyond Self-Attention"):
    replaces the query-key-value interaction of self-attention with two
    small learnable memories M_k and M_v, giving complexity linear in
    the number of tokens N."""

    def __init__(self, d_model, S=64):
        super().__init__()
        self.mk = nn.Linear(d_model, S, bias=False)  # memory key M_k
        self.mv = nn.Linear(S, d_model, bias=False)  # memory value M_v
        self.softmax = nn.Softmax(dim=1)             # normalize over the token axis
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries):
        attn = self.mk(queries)                             # (bs, n, S)
        attn = self.softmax(attn)                           # softmax over n
        attn = attn / torch.sum(attn, dim=2, keepdim=True)  # double normalization over S
        out = self.mv(attn)                                 # (bs, n, d_model)
        return out


# Input: (B, N, d_model), output: (B, N, d_model)
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block = ExternalAttention(d_model=64, S=8).to(device)
    input = torch.rand(64, 64, 64, device=device)  # (B, N, d_model)
    output = block(input)
    print(input.size(), output.size())
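

# ---------------------------------------------------------------------
# Hedged sketch, not part of the original file: segmentation backbones
# typically emit (B, C, H, W) feature maps, while ExternalAttention
# expects (B, N, d_model). One plausible plug-and-play adapter is to
# flatten the spatial dimensions into the token axis and restore them
# afterwards. The wrapper name and interface below are illustrative
# assumptions, not code from this repository.
class ExternalAttention2d(nn.Module):
    def __init__(self, channels, S=64):
        super().__init__()
        self.attn = ExternalAttention(d_model=channels, S=S)

    def forward(self, x):                      # x: (B, C, H, W)
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        tokens = self.attn(tokens)             # (B, H*W, C)
        return tokens.transpose(1, 2).reshape(b, c, h, w)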