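"""Distributed TensorFlow training on Ray, run against a Ray cluster on
Kubernetes via the Ray Client.

A minimal linear-regression model is trained with ray.util.sgd's TFTrainer;
pass --tune to search trainer configurations with Ray Tune instead.
"""
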
import argparse
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

import ray
from ray import tune
from ray.util.sgd.tf.tf_trainer import TFTrainer, TFTrainable

NUM_TRAIN_SAMPLES = 1000
NUM_TEST_SAMPLES = 400
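
# Kubernetes injects <SERVICE>_SERVICE_HOST and <SERVICE>_SERVICE_PORT_<PORT>
# environment variables for each Service; the names below assume a Ray head
# Service called "example-cluster-ray-head" exposing a port named "client".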
HEAD_SERVICE_IP_ENV = "EXAMPLE_CLUSTER_RAY_HEAD_SERVICE_HOST"
HEAD_SERVICE_CLIENT_PORT_ENV = "EXAMPLE_CLUSTER_RAY_HEAD_SERVICE_PORT_CLIENT"
# Use .get() so --smoke-test can run outside the cluster without these set.
head_service_ip = os.environ.get(HEAD_SERVICE_IP_ENV)
client_port = os.environ.get(HEAD_SERVICE_CLIENT_PORT_ENV)

def create_config(batch_size):
    return {
        # TODO: batch size needs to scale with # of workers
        "batch_size": batch_size,
        "fit_config": {
            "steps_per_epoch": NUM_TRAIN_SAMPLES // batch_size
        },
        "evaluate_config": {
            "steps": NUM_TEST_SAMPLES // batch_size,
        },
    }

def linear_dataset(a=2, size=1000):
    """Generate a noiseless linear dataset: y = x / a."""
    x = np.random.rand(size)
    y = x / a
    x = x.reshape((-1, 1))
    y = y.reshape((-1, 1))
    return x, y

def simple_dataset(config):
    batch_size = config["batch_size"]
    x_train, y_train = linear_dataset(size=NUM_TRAIN_SAMPLES)
    x_test, y_test = linear_dataset(size=NUM_TEST_SAMPLES)

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(
        batch_size)
    test_dataset = test_dataset.repeat().batch(batch_size)
    return train_dataset, test_dataset
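
# Note: the datasets above repeat indefinitely, so epoch length is bounded
# solely by the steps_per_epoch / steps values from create_config().
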
def simple_model(config):
    model = Sequential([Dense(10, input_shape=(1,)), Dense(1)])
    model.compile(
        optimizer="sgd",
        loss="mean_squared_error",
        metrics=["mean_squared_error"])
    return model

def train_example(num_replicas=1, batch_size=128, use_gpu=False):
    trainer = TFTrainer(
        model_creator=simple_model,
        data_creator=simple_dataset,
        num_replicas=num_replicas,
        use_gpu=use_gpu,
        verbose=True,
        config=create_config(batch_size))

    # model baseline performance
    start_stats = trainer.validate()
    print(start_stats)

    # train for 2 epochs
    trainer.train()
    trainer.train()

    # model performance after training (should improve)
    end_stats = trainer.validate()
    print(end_stats)

    # sanity check that training worked
    dloss = end_stats["validation_loss"] - start_stats["validation_loss"]
    dmse = (end_stats["validation_mean_squared_error"] -
            start_stats["validation_mean_squared_error"])
    print(f"dLoss: {dloss}, dMSE: {dmse}")

    if dloss > 0 or dmse > 0:
        print("training sanity check failed. loss increased!")
    else:
        print("success!")

def tune_example(num_replicas=1, use_gpu=False):
    config = {
        "model_creator": simple_model,
        "data_creator": simple_dataset,
        "num_replicas": num_replicas,
        "use_gpu": use_gpu,
        "trainer_config": create_config(batch_size=128)
    }

    analysis = tune.run(
        TFTrainable,
        num_samples=2,
        config=config,
        stop={"training_iteration": 2},
        verbose=1)

    return analysis.get_best_config(metric="validation_loss", mode="min")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    parser.add_argument(
        "--num-replicas",
        "-n",
        type=int,
        default=1,
        help="Sets number of replicas for training.")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Enables GPU training")
    parser.add_argument(
        "--tune", action="store_true", default=False, help="Tune training")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        ray.init(num_cpus=2)
    else:
        # Connect to the remote cluster via the Ray Client instead of
        # attaching to a local Ray instance.
        # ray.init(address=args.address)
        ray.util.connect(f"{head_service_ip}:{client_port}")

    if args.tune:
        tune_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
    else:
        train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
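
# Example invocations (script name is illustrative):
#   python tf_trainer_example.py --smoke-test        # local 2-CPU sanity run
#   python tf_trainer_example.py --num-replicas=4    # distributed training
#   python tf_trainer_example.py --tune              # Ray Tune sweep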