import inspect
import os
import sys

import torch
from torch import nn

class Partial_conv3(nn.Module):
    """Partial 3x3 convolution (FasterNet-style PConv).

    Applies a 3x3 convolution to only the first ``dim // n_div`` channels of
    the input and passes the remaining channels through untouched. Spatial
    size is preserved (kernel 3, stride 1, padding 1).

    Args:
        dim: total number of input (and output) channels.
        n_div: channel-split divisor; the first ``dim // n_div`` channels are
            convolved, the other ``dim - dim // n_div`` are left as-is.
        forward: implementation to bind as this instance's ``forward``:
            ``'slicing'`` (inference-only, in-place on a clone) or
            ``'split_cat'`` (safe for training and inference).

    Raises:
        NotImplementedError: if ``forward`` is neither ``'slicing'`` nor
            ``'split_cat'``.
    """

    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
        # kernel 3, stride 1, padding 1 -> output spatial size == input size.
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

        # Bind the requested implementation as this instance's forward().
        if forward == 'slicing':
            self.forward = self.forward_slicing
        elif forward == 'split_cat':
            self.forward = self.forward_split_cat
        else:
            raise NotImplementedError(
                f"forward must be 'slicing' or 'split_cat', got {forward!r}"
            )

    def forward_slicing(self, x):
        """Inference-only path: convolve the leading channels in place on a clone."""
        # Clone keeps the caller's tensor intact (e.g. for a residual connection)
        # since the channel slice below is written in place.
        x = x.clone()
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x

    def forward_split_cat(self, x):
        """Training/inference path: split channels, convolve the first part, re-concat."""
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x


if __name__ == '__main__':
    # Smoke test: run the block on a random feature map and check the shapes.
    # Fall back to CPU so the script also runs on machines without CUDA
    # (the original unconditional .cuda() crashed on CPU-only hosts).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = Partial_conv3(64, 2, 'split_cat').to(device)
    # Renamed from `input` to avoid shadowing the builtin.
    x = torch.rand(1, 64, 64, 64, device=device)
    output = block(x)
    print(x.size(), output.size())