@edoakes
Last active January 27, 2021 16:14
Ray for training + Ray Serve for inference
import argparse

import ray
from ray import serve

parser = argparse.ArgumentParser()
parser.add_argument("train_s3_urls")
parser.add_argument("inference_s3_urls")
parser.add_argument("output_path")

ray.init(address="auto")

@ray.remote(num_gpus=4)
class Trainer:
    def __init__(self, initial_weights):
        # Configure this to use GPUs.
        # init_model is a placeholder that this sketch leaves undefined.
        self.model = init_model(initial_weights)

    def train(self, s3_urls):
        # Fetch the S3 data. If the same images are shared across multiple
        # processes, it's likely better to fetch them once on the driver and
        # put them in the object store (ray.put) so every process can reuse them.
        # fetch_s3_images is a placeholder method that is not defined here.
        data = self.fetch_s3_images(s3_urls)
        self.model.train(data)
        return self.model.weights

class Servable:
    def __init__(self, trained_weights):
        self.model = init_model(trained_weights)

    @serve.accept_batch
    def __call__(self, s3_urls):
        data = self.fetch_s3_images(s3_urls)
        return self.model.inference(data)

def main(args):
    # initial_weights and write_to_s3 are placeholders left undefined here.
    # Assumption: each URL argument is a comma-separated list of S3 URLs.
    train_urls = args.train_s3_urls.split(",")
    inference_urls = args.inference_s3_urls.split(",")

    # Can do this with multiple model types or duplicate the Trainer.
    trainer = Trainer.remote(initial_weights)
    result_weights = ray.get(trainer.train.remote(train_urls))

    client = serve.start(http_host=None)
    serve.create_backend(
        "model1", Servable, result_weights,
        config=serve.BackendConfig(num_replicas=20))
    serve.create_endpoint("model1", backend="model1")
    handle = serve.get_handle("model1")

    refs = [handle.remote(s3_url) for s3_url in inference_urls]
    for result in ray.get(refs):
        write_to_s3(args.output_path, result)


if __name__ == "__main__":
    main(parser.parse_args())

@mathetes87

Hey Edward, we're having trouble when importing the Trainer class. We make the import using importlib like this:

model = "models.mlp.AdaptiveMLP"
module_name, _, class_name = model.rpartition(".")
Trainer = importlib.import_module(module_name)

The error we get is
'ActorClass(AdaptiveMLP)' object has no attribute '__mro__'

What would you recommend here? Maybe import the module in some other way?
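
A side note on the snippet above: `importlib.import_module` returns the module, so the class still has to be looked up on it by name. Here is a minimal sketch of resolving a dotted path to a class. The `load_class` helper is hypothetical (it is not from this thread), and a standard-library path stands in for `models.mlp.AdaptiveMLP`.

```python
import importlib


def load_class(dotted_path):
    """Resolve a string like 'package.module.ClassName' to the class object."""
    module_name, _, class_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {dotted_path!r}")
    module = importlib.import_module(module_name)  # this is the module object
    return getattr(module, class_name)  # look the class up on the module


# Stand-in path for illustration; the thread used "models.mlp.AdaptiveMLP".
OrderedDict = load_class("collections.OrderedDict")
```

This step alone probably would not cause the `__mro__` error reported above, but it ensures that `Trainer` refers to the class and not to the module.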

@mathetes87

OK, we found the problem. The Trainer was defined with the dataclass decorator, and ray.remote had to be the first (outermost) decorator, listed above @dataclass. After that change it worked!
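
For anyone hitting the same error, here is a rough sketch of why decorator order can matter. It does not use Ray: `fake_remote` and `FakeActorClass` are hypothetical stand-ins, assuming only that the decorator returns a wrapper object instead of a plain class. Decorators apply bottom-up, so whichever is listed first runs last.

```python
from dataclasses import dataclass


class FakeActorClass:
    """Hypothetical stand-in for the wrapper object a decorator might return."""

    def __init__(self, cls):
        self.wrapped = cls

    def remote(self, *args, **kwargs):
        # Illustration only: build the wrapped class locally.
        return self.wrapped(*args, **kwargs)


def fake_remote(cls):
    # Hypothetical stand-in for ray.remote: returns a wrapper, not a class.
    return FakeActorClass(cls)


# Works: @dataclass runs first on the real class, then the wrapper is applied.
@fake_remote
@dataclass
class GoodTrainer:
    learning_rate: float = 0.1


# Fails: @dataclass receives the wrapper instance, which has no __mro__.
try:
    @dataclass
    @fake_remote
    class BadTrainer:
        learning_rate: float = 0.1
except AttributeError as exc:
    error = exc
else:
    error = None
```

In this sketch, `GoodTrainer.remote(learning_rate=0.5)` builds a normal dataclass instance, while the reversed order fails at class-definition time.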
