Skip to content

Instantly share code, notes, and snippets.

@inoryy
Last active January 9, 2019 09:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save inoryy/de233ac13ecb97484e950c48d87aa688 to your computer and use it in GitHub Desktop.
Save inoryy/de233ac13ecb97484e950c48d87aa688 to your computer and use it in GitHub Desktop.
import tensorflow as tf
class TfCategorical:
    """Categorical distribution parameterized by unnormalized `logits`
    over the last axis (TensorFlow 1.x graph ops)."""

    def __init__(self, logits):
        self.logits = logits
        self.probs = tf.nn.softmax(logits)

    def sample(self):
        """Draw samples via the Gumbel-max trick:
        argmax(logits + Gumbel noise) ~ Categorical(softmax(logits))."""
        u = tf.random_uniform(tf.shape(self.logits))
        return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)

    def masked_sample(self, mask=None):
        """Gumbel-max sampling restricted to entries where `mask` is nonzero.

        `mask` is presumably a 0/1 tensor broadcastable to `logits` —
        masked-out entries get u ~= 1e-12, pushing -log(-log(u)) toward
        -inf so they effectively never win the argmax. TODO confirm with caller.
        """
        # BUG FIX: original used `if not mask:`, which raises TypeError for
        # tensor masks (tensor truthiness is undefined in TF graph mode) and
        # would also wrongly skip masking for falsy non-None values.
        if mask is None:
            return self.sample()
        u = tf.random_uniform(tf.shape(self.logits)) * mask + 1e-12
        return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)

    def entropy(self):
        """Entropy H(p) = -sum(p * log p), via cross-entropy of probs
        against their own logits."""
        return tf_cross_entropy(self.logits, self.probs)

    def logli(self, indices):
        """Log-likelihood of `indices` under this distribution."""
        return -self.neglogli(indices)

    def kl_div(self, other):
        """KL(self || other) = H(self.probs, other) - H(self.probs)."""
        return tf_cross_entropy(other.logits, self.probs) - tf_cross_entropy(self.logits, self.probs)

    def neglogli(self, actions):
        """Negative log-likelihood of integer `actions` (sparse labels)."""
        return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=actions)
class TfMultiCategorical:
    """Product of independent categorical distributions, one per logits
    tensor. Per-factor quantities combine by summation (independence)."""

    def __init__(self, logits):
        self.dists = [TfCategorical(factor_logits) for factor_logits in logits]

    def sample(self):
        """Sample every factor independently; one tensor per factor."""
        return [dist.sample() for dist in self.dists]

    def masked_sample(self, masks):
        """Masked sampling per factor; `masks` aligns index-wise with the factors."""
        return [dist.masked_sample(factor_mask)
                for dist, factor_mask in zip(self.dists, masks)]

    def entropy(self):
        """Joint entropy: sum of the factors' entropies."""
        return sum([dist.entropy() for dist in self.dists])

    def logli(self, indices):
        """Joint log-likelihood of per-factor `indices`."""
        return -self.neglogli(indices)

    def kl_div(self, others):
        """Summed per-factor KL divergence.

        NOTE(review): `others` is zipped directly against `self.dists`, so it
        looks like an iterable of per-factor distributions (e.g. `other.dists`),
        not a TfMultiCategorical instance — verify against callers.
        """
        return sum([dist.kl_div(other_dist)
                    for dist, other_dist in zip(self.dists, others)])

    def neglogli(self, actions):
        """Summed per-factor negative log-likelihood of `actions`."""
        return sum([dist.neglogli(action)
                    for dist, action in zip(self.dists, actions)])
def tf_cross_entropy(logits, probs):
    """Cross-entropy H(probs, softmax(logits)) over the last axis.

    Alias function to reduce text clutter.
    """
    # BUG FIX: the original computed the op but never returned it, so every
    # caller (TfCategorical.entropy / kl_div) silently received None.
    return tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=probs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment