Created
February 14, 2021 10:08
-
-
Save MakGulati/952b3eb5a1d97775c88212fdaaf6206f to your computer and use it in GitHub Desktop.
ray_class_prob
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import tensorflow as tf | |
from tensorflow.keras.datasets import mnist | |
from ray.tune.integration.keras import TuneReportCallback | |
import ray | |
from ray import tune | |
from ray.tune.schedulers import AsyncHyperBandScheduler | |
from ray.tune import Trainable | |
# Command-line flags: --smoke-test shortens the run for quick validation.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
    "--smoke-test",
    action="store_true",
    help="Finish quickly for testing",
)
# parse_known_args so unrelated launcher flags don't crash the script.
args, _ = arg_parser.parse_known_args()
class MNIST:
    """Holds the MNIST dataset and trains a small Keras classifier.

    The dataset is downloaded and normalised once in ``__init__``;
    ``train_mnist`` is the Ray Tune trainable and builds/compiles/fits a
    fresh model for each trial config it receives.
    """

    def __init__(self):
        # Fixed training settings (not part of the Tune search space).
        self.batch_size = 128
        self.num_classes = 10
        self.epochs = 5
        # Load once and scale pixel values from [0, 255] into [0, 1].
        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
        self.x_train, self.x_test = self.x_train / 255.0, self.x_test / 255.0
        self.model = None
        # FIX: was an f-string with no placeholders.
        print("class MNIST initialised")

    def train_mnist(self, cfg):
        """Tune trainable: build and fit a model from the trial config.

        Args:
            cfg: Tune config dict with keys ``lr`` and ``momentum``, and
                optionally ``hidden`` (hidden-layer width; defaults to 8
                so configs without the key keep the old behaviour).
        """
        self.model = tf.keras.models.Sequential(
            [
                tf.keras.layers.Flatten(input_shape=(28, 28)),
                # BUG FIX: the search space defines "hidden" but the layer
                # width was hard-coded to 8, so the tuned value was ignored.
                tf.keras.layers.Dense(cfg.get("hidden", 8), activation="relu"),
                tf.keras.layers.Dropout(0.2),
                tf.keras.layers.Dense(self.num_classes, activation="softmax"),
            ]
        )
        self.model.compile(
            loss="sparse_categorical_crossentropy",
            # FIX: "learning_rate" replaces the deprecated "lr" keyword in TF2.
            optimizer=tf.keras.optimizers.SGD(
                learning_rate=cfg["lr"], momentum=cfg["momentum"]
            ),
            metrics=["accuracy"],
        )
        self.model.fit(
            self.x_train,
            self.y_train,
            batch_size=self.batch_size,
            epochs=self.epochs,
            verbose=2,
            validation_data=(self.x_test, self.y_test),
            # BUG FIX: TF2 Keras logs the metric as "accuracy", not "acc";
            # with "acc" the callback fails to find the metric and the
            # trial errors out.
            callbacks=[TuneReportCallback({"mean_accuracy": "accuracy"})],
        )
if __name__ == "__main__":
    # FIX: removed the redundant re-imports of ray / tune /
    # AsyncHyperBandScheduler — they are already imported at module level.

    # Pre-download the dataset on the driver: mnist.load_data is not
    # threadsafe, so workers should only hit the local cache.
    mnist.load_data()

    ray.init(num_cpus=4 if args.smoke_test else None)

    # Early-stopping scheduler: kills underperforming trials after the
    # grace period, caps any trial at max_t iterations.
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=400, grace_period=20
    )

    mnist_obj = MNIST()
    analysis = tune.run(
        mnist_obj.train_mnist,
        name="exp",
        scheduler=sched,
        metric="mean_accuracy",
        mode="max",
        stop={
            "mean_accuracy": 0.99,
            "training_iteration": 5 if args.smoke_test else 300,
        },
        num_samples=10,
        resources_per_trial={"cpu": 2, "gpu": 0},
        config={
            "threads": 2,
            "lr": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
            "hidden": tune.randint(32, 512),
        },
    )
    print("Best hyperparameters found were: ", analysis.best_config)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment