@BryanCutler
Last active August 5, 2019 17:37
TensorFlow Arrow Blog Part 5 - Model Definition
import tensorflow as tf


def model_fit(ds):
    """Create and fit a Keras logistic regression model."""
    # Build the Keras model: a single sigmoid unit over 2 input features
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=(2,),
                                    activation='sigmoid'))
    # Plain SGD with a mean squared error loss, reporting accuracy
    model.compile(optimizer='sgd', loss='mean_squared_error',
                  metrics=['accuracy'])
    # Fit the model on the given dataset for 5 passes, keeping its order
    model.fit(ds, epochs=5, shuffle=False)
    return model
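
To make it clearer what this one-unit model computes, here is a minimal pure-Python sketch of a single sigmoid unit trained by SGD on a squared-error loss. It is an illustration only, not how Keras implements `Dense` or `fit`. The helper names (`predict`, `sgd_step`, `mse`), the toy data, and the learning rate of 0.5 are all assumptions chosen for the example and do not come from the gist.

```python
import math


def sigmoid(z):
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def predict(weights, bias, x):
    """Forward pass of one sigmoid unit: sigmoid(w . x + b)."""
    return sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)


def sgd_step(weights, bias, x, y, lr):
    """One SGD update on the squared error (p - y)^2 for one example."""
    p = predict(weights, bias, x)
    # Chain rule: d/dz (p - y)^2 = 2 * (p - y) * p * (1 - p)
    grad_z = 2.0 * (p - y) * p * (1.0 - p)
    new_weights = [w - lr * grad_z * xi for w, xi in zip(weights, x)]
    new_bias = bias - lr * grad_z
    return new_weights, new_bias


def mse(weights, bias, data):
    """Mean squared error of the unit over (features, label) pairs."""
    return sum((predict(weights, bias, x) - y) ** 2 for x, y in data) / len(data)


# Hypothetical toy data: 2 features per row, label equals the first feature.
data = [((0.0, 0.0), 0.0), ((0.0, 1.0), 0.0),
        ((1.0, 0.0), 1.0), ((1.0, 1.0), 1.0)]

weights, bias = [0.0, 0.0], 0.0
loss_before = mse(weights, bias, data)

# 5 passes over the data in a fixed order, one example at a time.
for epoch in range(5):
    for x, y in data:
        weights, bias = sgd_step(weights, bias, x, y, lr=0.5)

loss_after = mse(weights, bias, data)
print(loss_before, loss_after)
```

With zero-initialised weights every prediction starts at 0.5, so the initial loss is 0.25. The updates then push the first weight up, since the label follows the first feature.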