@ronzillia
Created May 14, 2018 14:25
def classification_loss(self, logit, input_y_classification, class_weight):
    labels = input_y_classification
    # Per-example weight, selected by each example's one-hot label.
    # labels: [batch_size, num_classes], class_weight: [1, num_classes]
    # -> weight_per_label: [1, batch_size]
    weight_per_label = tf.transpose(tf.matmul(labels, tf.transpose(class_weight)))
    # Unweighted per-example cross-entropy, shape [batch_size]
    entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=labels, name="xent_raw")
    xent = tf.multiply(weight_per_label, entropy)  # shape [1, batch_size]
    cost = tf.reduce_mean(xent)  # scalar
    self._summaries['classification_loss'] = tf.summary.scalar('classification_loss', cost)
    return cost
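The same weighting scheme can be sketched in plain NumPy to check what the graph computes: softmax the logits, take the per-example cross-entropy against the one-hot labels, scale each example by the weight of its true class, and average. The function and variable names below are illustrative, not from the gist; `class_weight` is assumed to be a flat vector of length `num_classes`.

```python
import numpy as np

def weighted_softmax_xent(logits, labels, class_weight):
    """NumPy sketch of the gist's weighted cross-entropy.

    logits:       [batch_size, num_classes] raw scores
    labels:       [batch_size, num_classes] one-hot targets
    class_weight: [num_classes] weight per class
    """
    # Numerically stable softmax over the class axis
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    # Per-example cross-entropy, shape [batch_size]
    entropy = -(labels * np.log(probs)).sum(axis=1)
    # Per-example weight: one-hot label picks out its class's weight
    weight_per_label = labels @ class_weight
    return float((weight_per_label * entropy).mean())
```

With uniform weights this reduces to the ordinary mean cross-entropy; doubling the weight of one class scales that class's contribution to the mean accordingly.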