Skip to content

Instantly share code, notes, and snippets.

@MakGulati
Created February 14, 2021 10:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save MakGulati/952b3eb5a1d97775c88212fdaaf6206f to your computer and use it in GitHub Desktop.
ray_class_prob
import argparse
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from ray.tune.integration.keras import TuneReportCallback
import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune import Trainable
# Command-line options; --smoke-test makes the run finish quickly for testing.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
    "--smoke-test",
    action="store_true",
    help="Finish quickly for testing",
)
# parse_known_args tolerates extra argv entries (e.g. when run under a harness).
args, _ = arg_parser.parse_known_args()
class MNIST:
    """Holds the MNIST data set and builds/trains a small dense classifier.

    Instantiating the class downloads/loads MNIST and normalizes pixel
    values; ``train_mnist`` is the Ray Tune trainable entry point.
    """

    def __init__(self):
        self.batch_size = 128
        self.num_classes = 10
        self.epochs = 5
        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
        # Scale pixel intensities from [0, 255] to [0, 1].
        self.x_train, self.x_test = self.x_train / 255.0, self.x_test / 255.0
        self.model = None
        print("class MNIST initialised")

    def train_mnist(self, cfg):
        """Build and fit the model with hyperparameters from ``cfg``.

        Args:
            cfg: dict with keys ``"lr"`` and ``"momentum"`` (required) and
                optionally ``"hidden"`` — the hidden-layer width. Defaults
                to 8 to preserve the original behavior when absent.
        """
        # Bug fix: the search space tunes "hidden" but the original code
        # hard-coded the layer width to 8, so the tuned value was ignored.
        hidden_units = cfg.get("hidden", 8)
        self.model = tf.keras.models.Sequential(
            [
                tf.keras.layers.Flatten(input_shape=(28, 28)),
                tf.keras.layers.Dense(hidden_units, activation="relu"),
                tf.keras.layers.Dropout(0.2),
                tf.keras.layers.Dense(self.num_classes, activation="softmax"),
            ]
        )
        self.model.compile(
            loss="sparse_categorical_crossentropy",
            # `learning_rate` replaces the deprecated `lr` keyword argument.
            optimizer=tf.keras.optimizers.SGD(
                learning_rate=cfg["lr"], momentum=cfg["momentum"]
            ),
            metrics=["accuracy"],
        )
        self.model.fit(
            self.x_train,
            self.y_train,
            batch_size=self.batch_size,
            epochs=self.epochs,
            verbose=2,
            validation_data=(self.x_test, self.y_test),
            # Bug fix: TF2 Keras names this metric "accuracy", not "acc";
            # with "acc" TuneReportCallback cannot find the value to report.
            callbacks=[TuneReportCallback({"mean_accuracy": "accuracy"})],
        )
if __name__ == "__main__":
    # Load MNIST once on the driver because the downloader is not threadsafe;
    # worker processes will then hit the local cache.
    mnist.load_data()
    ray.init(num_cpus=4 if args.smoke_test else None)
    # Early-terminate under-performing trials based on training_iteration.
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=400, grace_period=20
    )
    # NOTE(review): passing a bound method as the trainable means the whole
    # `mnist_obj` (including both data splits) is serialized to every trial
    # worker — a plain function plus tune.with_parameters would avoid that;
    # confirm this is intended.
    mnist_obj = MNIST()
    analysis = tune.run(
        mnist_obj.train_mnist,
        name="exp",
        scheduler=sched,
        metric="mean_accuracy",
        mode="max",
        # Stop a trial when either condition is met.
        stop={
            "mean_accuracy": 0.99,
            "training_iteration": 5 if args.smoke_test else 300,
        },
        num_samples=10,
        resources_per_trial={"cpu": 2, "gpu": 0},
        config={
            "threads": 2,
            "lr": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
            "hidden": tune.randint(32, 512),
        },
    )
    print("Best hyperparameters found were: ", analysis.best_config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment