Skip to content

Instantly share code, notes, and snippets.

@ShopifyEng
Last active April 5, 2022 20:40
Merlin
import ray
from ray.train import Trainer
def train_func(config):
    """Per-worker training entry point that is handed to Ray Train.

    The intended job of this function is to load the dataset, then build,
    fit and save the model. In this snippet the body is still a placeholder:
    it ignores ``config`` and returns ``None``.

    Args:
        config: Configuration forwarded by ``Trainer.run(..., config=...)``.

    Returns:
        None (placeholder body).
    """
    # TODO: implement dataset loading, model construction, fit and save.
    return None
def main(*, address="<RAY_CLUSTER_ADDRESS>", num_workers=1, config=None, use_gpu=True):
    """Connect to a Ray cluster, run ``train_func`` on Ray Train workers, and log the outcome.

    All parameters are keyword-only and optional, so ``main()`` can still be
    called with no arguments, as before.

    Args:
        address: Cluster address passed to ``ray.init``. The default is the
            placeholder string from the original snippet and must be replaced
            with a real address before use.
        num_workers: Number of Ray Train workers. The original body referenced
            an undefined ``num_workers`` name. The default of 1 is a
            placeholder (assumed to match the ``Trainer`` default — confirm
            against the Ray version in use).
        config: Dict forwarded to ``train_func`` on every worker. ``None`` means
            an empty dict. The original used a ``{...}`` placeholder, which is a
            set containing ``Ellipsis``, not a dict.
        use_gpu: Forwarded to ``Trainer``. Defaults to ``True``, as originally
            hard-coded.

    Raises:
        Whatever ``ray.init``, ``Trainer.start`` or ``Trainer.run`` raise. The
        trainer is still shut down if ``run`` fails.
    """
    # Local imports: the visible file header imports only ``ray`` and ``Trainer``.
    import logging
    import time

    log = logging.getLogger(__name__)
    config = {} if config is None else config

    ray.init(address)
    trainer = Trainer(backend="tensorflow", num_workers=num_workers, use_gpu=use_gpu)
    trainer.start()
    try:
        log.info("Starting to train.")
        # perf_counter is monotonic, so the duration is immune to clock adjustments.
        start = time.perf_counter()
        results = trainer.run(train_func, config=config)
        log.info("Training finished in %s seconds.", time.perf_counter() - start)
    finally:
        # Always release the workers, even when training raises.
        trainer.shutdown()
    # Presumably one entry per worker (per the Ray Train docs); log the first.
    log.info("Training results: %s", results[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.