15 lines
303 B
Python
15 lines
303 B
Python
|
from torch.autograd import Function
|
||
|
from torch.optim import SGD
|
||
|
|
||
|
|
||
|
class BinaryActivation(Function):
    """Binary step activation with a straight-through gradient.

    Forward maps each element of ``x`` to 1 where ``x >= 0`` and to 0 where
    ``x < 0``. Backward passes the incoming gradient through unchanged (a
    straight-through estimator), since the true derivative of a step
    function is zero almost everywhere.

    Use via ``BinaryActivation.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return the element-wise binary step of ``x``.

        Args:
            ctx: autograd context (unused; no tensors are needed in backward).
            x: input tensor.

        Returns:
            Tensor of the same shape as ``x`` with values in {0, 1}
            (NaN inputs stay NaN).
        """
        # sign() is 0 at x == 0, so (sign + 1) / 2 alone would emit 0.5 there.
        # ceil() maps that 0.5 to 1 and leaves 0 and 1 unchanged, so the output
        # is strictly binary. Nothing is saved on ctx because backward does not
        # use the input; saving it would only keep x alive in the graph.
        return ((x.sign() + 1.) / 2.).ceil()

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient straight through.

        Args:
            ctx: autograd context (unused).
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            A copy of ``grad_output``, used as the gradient w.r.t. ``x``.
        """
        return grad_output.clone()