
@arunmallya
Created February 20, 2018 20:41
Autograd snippet for Binarizer
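import torch

DEFAULT_THRESHOLD = 5e-3


class Binarizer(torch.autograd.Function):
    """Binarizes a real-valued tensor to {0, 1}.

    Values <= threshold become 0 and values > threshold become 1. The backward
    pass is a straight-through estimator: the gradient is passed on unchanged.
    Call as Binarizer.apply(inputs, threshold).
    """

    @staticmethod
    def forward(ctx, inputs, threshold=DEFAULT_THRESHOLD):
        outputs = inputs.clone()
        outputs[inputs.le(threshold)] = 0
        outputs[inputs.gt(threshold)] = 1
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient for `inputs`; `threshold` is a plain float, so None.
        return grad_output, None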
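The same behavior as `Binarizer` (a hard {0, 1} threshold in the forward pass, the gradient passed through unchanged in the backward pass) can also be sketched without a custom `autograd.Function`, using the common `detach` form of the straight-through estimator. This is a swapped-in technique, not the gist author's code; the function name `binarize_ste` is hypothetical, and the sketch assumes a floating-point input tensor.

```python
import torch

DEFAULT_THRESHOLD = 5e-3


def binarize_ste(inputs: torch.Tensor, threshold: float = DEFAULT_THRESHOLD) -> torch.Tensor:
    """Hypothetical helper: {0, 1} threshold forward, identity gradient backward."""
    # Hard binarization: 1 where inputs > threshold, else 0 (same dtype as inputs).
    hard = (inputs > threshold).to(inputs.dtype)
    # Straight-through trick: the returned value equals `hard`, but since
    # (hard - inputs) is detached, autograd only sees the `inputs` term,
    # so d(output)/d(inputs) is the identity.
    return inputs + (hard - inputs).detach()


# Usage example: two values fall at or below the threshold, two above it.
x = torch.tensor([-1.0, 0.001, 0.01, 2.0], requires_grad=True)
y = binarize_ste(x, threshold=5e-3)
y.sum().backward()
print(y.detach())
print(x.grad)
```

Here `y` should hold 0, 0, 1, 1 and `x.grad` should be all ones, matching what `Binarizer.backward` does by returning the incoming gradient as-is. The custom `Function` keeps the forward and backward rules explicit in one place; the `detach` form avoids defining a class but relies on autograd ignoring the detached term.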