

@geoffreyvd
Last active March 10, 2021 15:51
Straight-through estimator for TensorFlow/Keras 2+
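Written as equations, the estimator in this snippet works roughly as follows. The symbol z for the Dense pre-activation and L for the loss are names introduced here for illustration. The forward pass rounds the sigmoid output to a hard 0/1. The backward pass treats the rounding as the identity, so the gradient reaching z is just the sigmoid's own derivative scaled by the upstream gradient:

```latex
\begin{aligned}
p &= \sigma(z), \qquad y = \operatorname{round}(p) \in \{0, 1\} \\
\frac{\partial L}{\partial p} &:= \frac{\partial L}{\partial y} \quad \text{(straight-through: round treated as identity)} \\
\frac{\partial L}{\partial z} &= \frac{\partial L}{\partial y}\,\sigma(z)\bigl(1 - \sigma(z)\bigr)
\end{aligned}
```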
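import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.models import Model

@tf.custom_gradient
def binairy_STE_after_sigmoid(x):
    # Backward pass: pass the upstream gradient through unchanged (straight-through).
    def grad(dy):
        return dy
    # Forward pass: round sigmoid outputs in [0, 1] to hard binary values {0, 1}.
    result = tf.round(x)
    return result, grad

def model(input_dim):
    output_dim = 1
    inputs = Input(shape=(input_dim,))  # input_dim: assumed parameter giving the input feature size
    x = Dense(output_dim, activation='sigmoid')(inputs)
    x = Lambda(binairy_STE_after_sigmoid)(x)
    return Model(inputs, x)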
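One way to check that the straight-through trick does what it claims is to compare gradients through a plain `tf.round` with gradients through the custom-gradient version. The following is a minimal sketch, not part of the original gist. It assumes TensorFlow 2.x in eager mode and reuses the same logic under a hypothetical shorter name, `binary_ste`. The helper `forward_and_grads` is also hypothetical. Without the STE, no gradient reaches the logits. With it, the gradient should equal the sigmoid derivative σ(z)(1 − σ(z)).

```python
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def binary_ste(x):
    # Same logic as binairy_STE_after_sigmoid (hypothetical shorter name for this sketch).
    result = tf.round(x)  # forward: hard 0/1 for inputs in [0, 1]
    def grad(dy):
        return dy  # backward: identity, i.e. straight-through
    return result, grad

def forward_and_grads(logits):
    """Return (STE outputs, gradient via plain round, gradient via STE) w.r.t. the logits."""
    z = tf.constant(logits, dtype=tf.float32)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(z)
        p = tf.sigmoid(z)
        loss_plain = tf.reduce_sum(tf.round(p))  # round() blocks useful gradients
        y_ste = binary_ste(p)
        loss_ste = tf.reduce_sum(y_ste)
    g_plain = tape.gradient(loss_plain, z)  # no useful gradient flows through tf.round
    g_ste = tape.gradient(loss_ste, z)      # equals sigmoid'(z) thanks to the STE
    del tape
    return y_ste.numpy(), g_plain, g_ste.numpy()

if __name__ == "__main__":
    y, g_plain, g_ste = forward_and_grads([-2.0, 3.0])
    print("binary outputs:", y)
    print("gradient via plain round:", g_plain)
    print("gradient via STE:", g_ste)
```

The rounding is kept inside the custom-gradient function so that only the non-differentiable step is overridden. The sigmoid before it still contributes its real derivative during backpropagation.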