Skip to content

Instantly share code, notes, and snippets.

@wangkuiyi
Created April 6, 2019 22:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wangkuiyi/b3d6383077e94f08185d3b9dfd3f5b64 to your computer and use it in GitHub Desktop.
Save wangkuiyi/b3d6383077e94f08185d3b9dfd3f5b64 to your computer and use it in GitHub Desktop.
Corrects the example at https://www.tensorflow.org/guide/eager#variables_and_optimizers so that it works with TensorFlow 2.0 alpha.
#!/usr/bin/env python
import tensorflow as tf
class Model(tf.keras.Model):
    """Scalar linear model ``y = W * x + B`` with two trainable variables.

    ``W`` starts at 5.0 and ``B`` at 10.0; training is expected to move
    them toward the parameters that generated the data.
    """

    def __init__(self):
        super().__init__()
        # Scalar tf.Variables; the names only label them in TF output.
        self.W = tf.Variable(5., name='weight')
        self.B = tf.Variable(10., name='bias')

    def call(self, inputs):
        """Apply the affine map element-wise to ``inputs``."""
        scaled = inputs * self.W
        return scaled + self.B
# Synthetic 1-D regression data: targets are 3 * x + 2 plus
# standard-normal noise, one target per input sample.
NUM_EXAMPLES = 2000
training_inputs = tf.random.normal(shape=[NUM_EXAMPLES])
noise = tf.random.normal(shape=[NUM_EXAMPLES])
training_outputs = 3 * training_inputs + 2 + noise
# Objective minimized by the training loop.
def loss(model, inputs, targets):
    """Return the mean squared error of ``model(inputs)`` against ``targets``."""
    predictions = model(inputs)
    residuals = predictions - targets
    return tf.reduce_mean(tf.square(residuals))
def grad(model, inputs, targets):
    """Return d(loss)/dW and d(loss)/dB, in that order, as a list.

    The ordering must match the variable list the caller pairs the
    gradients with (``[model.W, model.B]``).
    """
    params = [model.W, model.B]
    with tf.GradientTape() as tape:
        # Trainable variables read inside the context are recorded automatically.
        loss_value = loss(model, inputs, targets)
    return tape.gradient(loss_value, params)
# Define:
# 1. A model.
# 2. Derivatives of a loss function with respect to model parameters.
# 3. A strategy for updating the variables based on the derivatives.
model = Model()
# The original guide used tf.train.GradientDescentOptimizer, a TF 1.x API
# that TF 2.0 no longer exposes under tf.train; tf.optimizers.SGD replaces it.
optimizer = tf.optimizers.SGD(learning_rate=0.01)

NUM_STEPS = 300  # number of full-batch gradient-descent steps
LOG_EVERY = 20   # print the loss every LOG_EVERY steps

print("Initial loss: {:.3f}".format(
    loss(model, training_inputs, training_outputs)))

# Training loop: each step uses the whole dataset.
for i in range(NUM_STEPS):
    grads = grad(model, training_inputs, training_outputs)
    # Pair each gradient with its variable; the order must match the
    # [W, B] order used inside grad().
    optimizer.apply_gradients(zip(grads, [model.W, model.B]))
    if i % LOG_EVERY == 0:
        # This is the loss *after* step i's update has been applied.
        print("Loss at step {:03d}: {:.3f}".format(
            i, loss(model, training_inputs, training_outputs)))

print("Final loss: {:.3f}".format(
    loss(model, training_inputs, training_outputs)))
print("W = {}, B = {}".format(model.W.numpy(), model.B.numpy()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment