Created
February 17, 2021 19:58
-
-
Save MakGulati/58532dc878e0966c7c3d31fa684be6d7 to your computer and use it in GitHub Desktop.
with normal function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from tensorflow.keras.datasets import mnist | |
from ray.tune.integration.keras import TuneReportCallback | |
# Command-line flags for the script. parse_known_args() returns unrecognised
# arguments separately instead of raising on them.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--smoke-test",
    help="Finish quickly for testing",
    action="store_true",
)
args, _unknown = parser.parse_known_args()
def train_mnist(config):
    """Train a small dense MNIST classifier and report accuracy to Ray Tune.

    Passed to ``tune.run`` as a function trainable.

    Args:
        config: Hyperparameter dict supplied by Tune. Required keys:
            ``"hidden"`` (units in the hidden Dense layer), ``"lr"`` (SGD
            learning rate) and ``"momentum"`` (SGD momentum). Optional keys:
            ``"epochs"`` (default 4) and ``"batch_size"`` (default 128).
            Any other keys are ignored.
    """
    # TensorFlow is imported inside the trainable; see
    # https://github.com/tensorflow/tensorflow/issues/32159
    import tensorflow as tf

    batch_size = config.get("batch_size", 128)
    num_classes = 10
    epochs = config.get("epochs", 4)

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Divide pixel intensities by 255 (assumes 0-255 input) to get [0, 1].
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(config["hidden"], activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(
        # "sparse" loss: labels are integer class ids, not one-hot vectors.
        loss="sparse_categorical_crossentropy",
        # `learning_rate` is the current keyword; `lr` is a deprecated alias.
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=config["lr"], momentum=config["momentum"]),
        metrics=["accuracy"])
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
        # With metrics=["accuracy"], TF 2.x logs the training metric under
        # "accuracy" (TF 1.x used "acc"), so map that key to mean_accuracy.
        callbacks=[TuneReportCallback({
            "mean_accuracy": "accuracy"
        })])
if __name__ == "__main__":
    import ray
    from ray import tune
    from ray.tune.schedulers import AsyncHyperBandScheduler

    mnist.load_data()  # we do this on the driver because it's not threadsafe
    # Cap Ray at 4 CPUs for smoke tests; otherwise let Ray detect the count.
    ray.init(num_cpus=4 if args.smoke_test else None)

    # NOTE(review): with the config below, train_mnist runs only a few epochs.
    # If Tune receives one report per epoch, no trial reaches grace_period=20,
    # so ASHA never stops a trial early. Confirm this is intended.
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=400, grace_period=20)
    analysis = tune.run(
        train_mnist,
        name="exp",
        scheduler=sched,
        metric="mean_accuracy",
        mode="max",
        stop={
            # A trial stops when it hits 90% accuracy or the iteration cap.
            "mean_accuracy": 0.90,
            "training_iteration": 5 if args.smoke_test else 300
        },
        num_samples=10,
        resources_per_trial={
            "cpu": 2,
            "gpu": 0
        },
        config={
            # NOTE(review): train_mnist does not read "threads".
            "threads": 2,
            "lr": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
            "hidden": tune.randint(32, 512),
        })
    print("Best hyperparameters found were: ", analysis.best_config)
    # Alternative search space (grid over lr), kept for reference:
    # analysis = tune.run(
    #     train_mnist,
    #     metric="mean_accuracy",
    #     mode="max",
    #     config={
    #         "lr": tune.grid_search([0.001, 0.01, 0.1]),
    #         "momentum": tune.uniform(0.1, 0.2),
    #         "hidden": tune.randint(32, 33),
    #     },
    # )
    # train_mnist only reports "mean_accuracy". The previous query for
    # "mean_loss" (min) could not match any trial's reported results.
    print("Best config: ",
          analysis.get_best_config(metric="mean_accuracy", mode="max"))
    # Get a dataframe for analyzing trial results.
    df = analysis.results_df
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment