Skip to content

Instantly share code, notes, and snippets.

@richardliaw
Created July 4, 2024 00:55
Show Gist options
Save richardliaw/197ffb3dc2ff50b095e0da5543d8bad0 to your computer and use it in GitHub Desktop.
import ray.data
import torch
class DataGenerator:
    """Callable class that runs ``Model`` over permuted variants of each batch.

    Intended for ``Dataset.map_batches``: Ray instantiates a callable class once
    per actor, so the model is built and moved to the GPU once in ``__init__``
    and reused for every batch passed to ``__call__``.
    """

    def __init__(self, permute_config):
        """Build the model on the GPU and remember the permutation config.

        Args:
            permute_config: Configuration forwarded to ``permute`` for every
                batch. Its structure is not visible in this file --
                NOTE(review): confirm against the ``permute`` implementation.
        """
        # NOTE(review): ``Model`` is not defined or imported in this file; it
        # must be brought into scope before this class is instantiated.
        self.device = torch.device("cuda")
        self.model = Model().to(self.device)
        self.config = permute_config

    def __call__(self, input):
        """Yield one model output per permuted variant of ``input``.

        Args:
            input: A batch handed over by Ray Data (presumably the default
                numpy-dict batch format -- confirm against what ``Model``
                accepts). The name shadows the ``input`` builtin but is kept
                for interface compatibility.

        Yields:
            The result of ``self.model`` applied to each variant produced by
            ``self.permute(self.config, input)``.
        """
        for test_input in self.permute(self.config, input):
            # No backward pass happens in this class, so skip autograd
            # bookkeeping. The context is scoped to the model call so grad
            # mode is not left disabled while the consumer runs between yields.
            with torch.inference_mode():
                output = self.model(test_input)
            # NOTE(review): the batch is not moved to ``self.device`` here;
            # confirm ``Model`` handles device placement of its inputs.
            yield output

    def permute(self, config, batch):
        """Yield variants of ``batch`` according to ``config``.

        ``__call__`` depends on this hook, but no implementation is provided
        here; override it (e.g. in a subclass) with the real permutation logic.

        Raises:
            NotImplementedError: Always, until overridden.
        """
        raise NotImplementedError(
            "DataGenerator.permute must be implemented to produce batch variants"
        )
# NOTE(review): placeholder -- DataGenerator.__init__ requires a permute_config,
# which the original map_batches call never supplied. Fill in the real value.
PERMUTE_CONFIG = None

if __name__ == "__main__":
    ds = ray.data.read_json("s3://path_to_bucket/user_prompts.json")
    # map_batches returns a new Dataset rather than mutating ``ds``; keep the
    # result, otherwise write_json would write the unprocessed input.
    ds = ds.map_batches(
        DataGenerator,
        # Forwarded to DataGenerator.__init__ when Ray constructs each actor.
        fn_constructor_kwargs={"permute_config": PERMUTE_CONFIG},
        num_gpus=1,
        concurrency=100,  # run on 100 GPUs
    )
    # write_json treats this path as a destination directory and writes one or
    # more JSON files into it.
    ds.write_json("s3://path_to_bucket/generated_prompts.json")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.