@cedricconol
Created May 23, 2020 03:10
Linear regression TF2 update params
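# Method of a linear-regression class; assumes `import tensorflow as tf`
# and that self.m and self.b are tf.Variable objects.
def update(self, X, y, learning_rate):
    # Record the forward pass so the loss can be differentiated.
    # persistent=True allows gradient() to be called more than once.
    with tf.GradientTape(persistent=True) as g:
        loss = self.mse(y, self.predict(X))
    print("Loss: ", loss)
    # Gradients of the loss with respect to slope m and intercept b.
    dy_dm = g.gradient(loss, self.m)
    dy_db = g.gradient(loss, self.b)
    del g  # release the persistent tape once both gradients are taken
    # Gradient-descent step: parameter -= learning_rate * gradient.
    self.m.assign_sub(learning_rate * dy_dm)
    self.b.assign_sub(learning_rate * dy_db)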
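
To make the update step concrete, here is a minimal dependency-free sketch of what the tape computes. It assumes `predict` returns `m * x + b` and `mse` is the mean squared error; neither is shown in this snippet, so both are assumptions. The helper `mse_gradients` and the toy data are hypothetical, introduced only for illustration.

```python
def mse_gradients(m, b, xs, ys):
    """Analytic gradients of mean((y - (m*x + b))**2) w.r.t. m and b.

    Assumed model and loss; the gist itself gets these from tf.GradientTape.
    """
    n = len(xs)
    residuals = [y - (m * x + b) for x, y in zip(xs, ys)]
    grad_m = -2.0 / n * sum(r * x for r, x in zip(residuals, xs))
    grad_b = -2.0 / n * sum(residuals)
    return grad_m, grad_b


def update(m, b, xs, ys, learning_rate):
    """One gradient-descent step, mirroring assign_sub(learning_rate * grad)."""
    grad_m, grad_b = mse_gradients(m, b, xs, ys)
    return m - learning_rate * grad_m, b - learning_rate * grad_b


# Hypothetical toy data generated from y = 2x + 1.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]

m, b = 0.0, 0.0
for _ in range(2000):
    m, b = update(m, b, xs, ys, learning_rate=0.05)

print(m, b)  # m and b end up close to 2 and 1
```

With automatic differentiation the two gradient formulas never have to be written by hand, which is why the TensorFlow version only needs `g.gradient(loss, ...)`.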