Skip to content

Instantly share code, notes, and snippets.

@MakGulati
Created February 17, 2021 19:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MakGulati/58532dc878e0966c7c3d31fa684be6d7 to your computer and use it in GitHub Desktop.
Save MakGulati/58532dc878e0966c7c3d31fa684be6d7 to your computer and use it in GitHub Desktop.
Ray Tune hyperparameter search on MNIST using a plain (function-based) trainable
import argparse
from tensorflow.keras.datasets import mnist
from ray.tune.integration.keras import TuneReportCallback
# Command-line interface: a single flag that shortens the run for CI/testing.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--smoke-test",
    action="store_true",
    help="Finish quickly for testing",
)
# parse_known_args so unrelated launcher flags don't cause an error.
args, _ = parser.parse_known_args()
def train_mnist(config):
    """Train a small MLP on MNIST and report accuracy to Ray Tune.

    Args:
        config: hyperparameter dict with keys "hidden" (units in the
            hidden Dense layer), "lr" (SGD learning rate) and
            "momentum" (SGD momentum).
    """
    # Import TF inside the trial process to avoid threading issues:
    # https://github.com/tensorflow/tensorflow/issues/32159
    import tensorflow as tf

    batch_size = 128
    num_classes = 10
    epochs = 4

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] to [0, 1].
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(config["hidden"], activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(
        loss="sparse_categorical_crossentropy",
        # BUG FIX: `lr` is a deprecated alias in tf.keras optimizers
        # (removed in newer releases); use `learning_rate`.
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=config["lr"], momentum=config["momentum"]),
        metrics=["accuracy"])
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
        # BUG FIX: with metrics=["accuracy"], TF2 logs the metric under
        # the key "accuracy", not the TF1-era "acc"; the old key makes
        # TuneReportCallback fail with a missing-metric error.
        callbacks=[TuneReportCallback({
            "mean_accuracy": "accuracy"
        })])
if __name__ == "__main__":
    import ray
    from ray import tune
    from ray.tune.schedulers import AsyncHyperBandScheduler

    # Download MNIST on the driver once up front: the Keras downloader is
    # not thread-safe, so concurrent trials must find it already cached.
    mnist.load_data()

    ray.init(num_cpus=4 if args.smoke_test else None)

    # Early-stop unpromising trials based on reported training iterations.
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=400, grace_period=20)

    analysis = tune.run(
        train_mnist,
        name="exp",
        scheduler=sched,
        metric="mean_accuracy",
        mode="max",
        stop={
            "mean_accuracy": 0.90,
            "training_iteration": 5 if args.smoke_test else 300,
        },
        num_samples=10,
        resources_per_trial={
            "cpu": 2,
            "gpu": 0,
        },
        config={
            "threads": 2,
            "lr": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
            "hidden": tune.randint(32, 512),
        })

    print("Best hyperparameters found were: ", analysis.best_config)
    # BUG FIX: the experiment reports only "mean_accuracy" (maximized);
    # the original queried get_best_config(metric="mean_loss", mode="min"),
    # a metric that was never reported, which raises / returns nothing.
    print("Best config: ",
          analysis.get_best_config(metric="mean_accuracy", mode="max"))

    # Get a dataframe for analyzing trial results.
    df = analysis.results_df
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment