@panchishin
Created November 23, 2020 05:23
Random Categorical Selection with Gradient capability
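import tensorflow as tf
from tensorflow.keras import layers as L

class RandomCategorical(L.Layer):
    """Samples a one-hot class from logits, blended with sigmoid(logits) so gradients can flow."""

    def __init__(self, num_classes=10, factor=0.999, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.factor = factor  # weight of the hard one-hot sample; (1 - factor) weights the sigmoid

    @tf.function
    def call(self, inputs):
        # Draw one class index per row from the logits (this op passes no gradient).
        sample = tf.random.categorical(logits=inputs, num_samples=1)
        # One-hot encode the sampled indices: shape [batch, num_classes].
        hot = tf.one_hot(sample[:, 0], depth=self.num_classes)
        # Mix in sigmoid(inputs) so gradients reach the logits, scaled by (1 - factor).
        return hot * self.factor + tf.nn.sigmoid(inputs) * (1.0 - self.factor)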
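The gradient claim in the title can be checked directly. The sketch below assumes TensorFlow 2.x in eager mode and restates the layer's forward pass as a standalone function (`blend_sample` is a hypothetical name, not part of the gist) so it runs on its own. It then uses `tf.GradientTape` to confirm that gradients reach the logits only through the sigmoid term, scaled by `1 - factor`.

```python
import tensorflow as tf


def blend_sample(logits, num_classes, factor=0.999):
    # Hypothetical helper mirroring RandomCategorical.call:
    # a hard one-hot sample blended with sigmoid(logits).
    sample = tf.random.categorical(logits=logits, num_samples=1)  # [batch, 1], int64
    hot = tf.one_hot(sample[:, 0], depth=num_classes)             # [batch, num_classes]
    return hot * factor + tf.nn.sigmoid(logits) * (1.0 - factor)


tf.random.set_seed(0)
factor = 0.999
logits = tf.Variable([[2.0, 0.5, -1.0], [0.0, 0.0, 3.0]])

with tf.GradientTape() as tape:
    out = blend_sample(logits, num_classes=3, factor=factor)
    loss = tf.reduce_sum(out)

# The sampling and one-hot ops contribute no gradient, so only the sigmoid term does:
# d(loss)/d(logits) = (1 - factor) * s * (1 - s), with s = sigmoid(logits).
grad = tape.gradient(loss, logits)
s = tf.nn.sigmoid(logits)
expected = (1.0 - factor) * s * (1.0 - s)

print(out.numpy().round(3))  # each row: one entry near 1, the rest near 0
print(float(tf.reduce_max(tf.abs(grad - expected))))  # should be ~0
```

With the default `factor=0.999`, the forward output is almost exactly the sampled one-hot vector, while the gradient is 1000 times smaller than that of a plain sigmoid. Increasing `1 - factor` strengthens the training signal at the cost of a less "hard" sample.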