Skip to content

Instantly share code, notes, and snippets.

@yaroslavvb2
Created October 29, 2017 02:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yaroslavvb2/2da7177b0081e6399aba4052d5f89024 to your computer and use it in GitHub Desktop.
ResNet (small)
@graph_callable.graph_callable([])
def resnet_loss():
    """Return the softmax cross-entropy loss of a CIFAR-10 ResNet v2 model.

    Builds the network with ``resnet_model.cifar10_resnet_v2_generator``,
    feeds it the module-level ``images`` reshaped into a
    ``[BATCH_SIZE, HEIGHT, WIDTH, DEPTH]`` batch, and scores the logits
    against the module-level one-hot ``labels``.

    NOTE(review): ``images``, ``labels`` and the RESNET_SIZE, NUM_CLASSES,
    BATCH_SIZE, HEIGHT, WIDTH and DEPTH constants are defined outside this
    snippet. Presumably ``images`` holds random data (the original summary
    said "from random input") — confirm against the defining code.

    Returns:
        The loss tensor produced by ``tf.losses.softmax_cross_entropy``
        (a scalar under that function's default reduction).
    """
    network = resnet_model.cifar10_resnet_v2_generator(RESNET_SIZE, NUM_CLASSES)
    # Give the input an explicit 4-D batch shape before feeding the network.
    inputs = tf.reshape(images, [BATCH_SIZE, HEIGHT, WIDTH, DEPTH])
    # NOTE(review): the second positional argument is presumably the
    # is_training flag of the generated network — confirm in resnet_model.
    logits = network(inputs, True)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=labels,
    )
    return cross_entropy
# Wrap resnet_loss so each call returns (loss_value, [(gradient, variable), ...])
# for the trainable variables it uses.
# NOTE(review): relies on the TF 1.x eager API (tf.contrib.eager as ``tfe``);
# confirm the TensorFlow version this runs under.
loss_and_grads_fn = tfe.implicit_value_and_gradients(resnet_loss)
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
losses = []  # per-step loss values, copied to host via .numpy()
for step in range(500):
    loss, grads_and_vars = loss_and_grads_fn()
    optimizer.apply_gradients(grads_and_vars)
    # Convert the loss tensor to a host value once; reuse it for both the
    # progress log and the recorded history.
    loss_value = loss.numpy()
    print(step, loss_value)
    losses.append(loss_value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment