@ronzillia
Created May 14, 2018 14:29
import numpy as np

# Balance the two classes via the loss, recomputing the weights for each batch.
x_train, y_train, index = batch_data
label = np.argmax(y_train, axis=1)  # one-hot -> class index (1 = positive)
num_pos = np.count_nonzero(label)
num_neg = len(label) - num_pos
# Add one to numerator and denominator so that neither weight becomes zero
# when a batch is all-negative or all-positive.
pos_weight = np.true_divide(num_neg + 1, len(label) + 1)
neg_weight = np.true_divide(num_pos + 1, len(label) + 1)
class_weight = np.array([[neg_weight, pos_weight]])  # shape (1, 2): [neg, pos]
_, train_cost = sess.run(
    [train_op, classification_loss_op],
    feed_dict={input_x: x_train,
               input_y_classification: y_train,
               bs_holder: batch_size,
               training_flag: True,
               tf_class_weight: class_weight})
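The snippet does not show how `classification_loss_op` consumes `tf_class_weight`. Below is a minimal NumPy sketch for checking the arithmetic, assuming the loss scales each example's cross-entropy by the weight of its true class. Both `batch_class_weight` and `weighted_softmax_cross_entropy` are hypothetical names introduced here for illustration; they are not part of the original gist.

```python
import numpy as np


def batch_class_weight(y_onehot):
    """Hypothetical helper: per-batch class weights with the same +1 smoothing.

    Returns an array of shape (1, 2) laid out as [neg_weight, pos_weight].
    """
    label = np.argmax(y_onehot, axis=1)  # column 1 is assumed to be the positive class
    num_pos = np.count_nonzero(label)
    num_neg = len(label) - num_pos
    pos_weight = (num_neg + 1) / (len(label) + 1)
    neg_weight = (num_pos + 1) / (len(label) + 1)
    return np.array([[neg_weight, pos_weight]])


def weighted_softmax_cross_entropy(logits, y_onehot, class_weight):
    """Hypothetical stand-in for classification_loss_op (an assumption, not the gist's code).

    Each example's cross-entropy is scaled by the weight of its true class,
    selected by multiplying the one-hot labels with the (1, 2) weight row,
    which NumPy broadcasts across the batch.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_example_ce = -(y_onehot * log_probs).sum(axis=1)
    per_example_w = (y_onehot * class_weight).sum(axis=1)
    return float((per_example_w * per_example_ce).mean())


if __name__ == "__main__":
    # Imbalanced toy batch: 3 negatives, 1 positive.
    y = np.array([[1, 0], [1, 0], [1, 0], [0, 1]], dtype=float)
    w = batch_class_weight(y)
    print(w)  # neg = 2/5 = 0.4, pos = 4/5 = 0.8
    print(weighted_softmax_cross_entropy(np.zeros((4, 2)), y, w))
```

The rarer class gets the larger weight, and an all-positive batch still gives the positive class a nonzero weight (1/(N + 1)). Note that the two weights sum to (N + 2)/(N + 1) rather than exactly 1; dividing the numerators by N + 2 instead would normalize them, if the loss scale matters for your learning rate. In TensorFlow 1.x, the per-example weights could be passed as the `weights` argument of `tf.losses.softmax_cross_entropy`, assuming that is how the original graph builds its loss.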