# Source: Tan_pytorch_segmentation/pytorch_segmentation/Plug-and-Play/(arxiv2023.4)SpectFormer.py
# (39 lines, 1.0 KiB, Python; last updated 2025-05-19 20:48:24 +08:00)
import math
import torch.fft
import torch
import torch.nn as nn
class SpectralGatingNetwork(nn.Module):
    """Global spectral gating filter (GFNet-style) used by SpectFormer.

    Tokens are laid out on an ``a x b`` grid, transformed with a 2-D real FFT
    over the grid axes, multiplied element-wise by a learnable complex filter,
    and transformed back with the inverse real FFT.

    Args:
        dim: Channel count ``C`` of the input tokens.
        h: Height of the learnable frequency filter. It must equal the grid
            height ``a`` (or be broadcast-compatible with it).
        w: Width of the learnable frequency filter. ``rfft2`` keeps only
            ``b // 2 + 1`` frequencies along the last grid axis, so ``w`` must
            equal ``b // 2 + 1`` (or be broadcast-compatible with it). The
            defaults ``h=8, w=5`` therefore fit an 8x8 grid (N = 64).
    """

    def __init__(self, dim, h=8, w=5):
        super().__init__()
        # Real and imaginary parts live in the trailing axis of size 2, so the
        # parameter is a real tensor; forward() views it as complex.
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        """Apply the spectral gating filter to a token sequence.

        Args:
            x: Token tensor of shape ``(B, N, C)``.
            spatial_size: Optional ``(a, b)`` grid with ``a * b == N``. When
                omitted, ``N`` must be a perfect square and
                ``a = b = sqrt(N)``.

        Returns:
            Tensor of shape ``(B, N, C)``. It is float32 regardless of the
            input dtype, because the input is cast to float32 before the FFT.

        Raises:
            ValueError: If the ``N`` tokens cannot be arranged on the
                ``a x b`` grid.
        """
        B, N, C = x.shape
        if spatial_size is None:
            # Exact integer square root. A non-square N used to be floored
            # silently and only failed later with an opaque view() error.
            a = b = math.isqrt(N)
        else:
            a, b = spatial_size
        if a * b != N:
            raise ValueError(
                f"cannot arrange N={N} tokens on a {a}x{b} grid; "
                "pass spatial_size=(a, b) with a * b == N"
            )
        x = x.view(B, a, b, C)
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        # s=(a, b) restores the exact grid size. Without it, an odd b would
        # come back as b - 1 (irfft's default length is 2 * (m - 1)).
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')
        return x.reshape(B, N, C)
# Smoke test. Input: (B, N, C); output: (B, N, C).
if __name__ == '__main__':
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = SpectralGatingNetwork(64).to(device)
    # N = 64 tokens form the 8x8 grid that the default filter size
    # (h=8, w=5) is built for.
    x = torch.rand(1, 64, 64, device=device)
    output = block(x)
    print(x.size(), output.size())