from torch.autograd import Function
from torch.optim import SGD


class BinaryActivation(Function):
    """Hard 0/1 activation with a pass-through gradient.

    Forward maps negative inputs to 0 and non-negative inputs to 1.
    Backward returns the incoming gradient unchanged, i.e. the step is
    treated as the identity for gradient purposes.

    Use as ``BinaryActivation.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return the binarized tensor for ``x``.

        Args:
            ctx: Autograd context. Nothing is stored on it because
                ``backward`` does not need the input.
            x: Input tensor.

        Returns:
            A tensor shaped like ``x`` holding 0 where ``x < 0`` and 1 where
            ``x >= 0``. NaN inputs yield NaN.
        """
        # sign() yields -1 / 0 / +1, so the rescaled value is 0 / 0.5 / 1.
        # ceil() folds the x == 0 case (0.5) up to 1 so the output is
        # strictly binary, while leaving 0, 1 and NaN untouched.
        return ((x.sign() + 1.) / 2.).ceil()

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the upstream gradient through unchanged.

        Args:
            ctx: Autograd context (unused).
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            A copy of ``grad_output`` as the gradient w.r.t. ``x``.
        """
        return grad_output.clone()