Skip to content

Instantly share code, notes, and snippets.

@woshiyyya
Last active December 14, 2023 22:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save woshiyyya/e48252b85b0a1cee40f3346d05849dd6 to your computer and use it in GitHub Desktop.
import ray
import ray.train
import numpy as np
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from dataclasses import dataclass
@dataclass
class DummyDataclass:
    """Minimal dataclass payload.

    Used below (in ``train_func``) to check that a dataclass instance
    survives a round trip through the KVCache actor / object store.
    """

    a: str  # arbitrary string field
    b: int  # arbitrary int field
@ray.remote
class KVCache:
    """Key/value store actor backed by the Ray object store.

    ``put`` places each value into the object store via ``ray.put`` and the
    actor keeps only the resulting ObjectRef, so large values do not live in
    the actor's own heap.
    """

    def __init__(self):
        # key -> ObjectRef of the stored value
        self.map = dict()

    def put(self, key, val):
        """Store ``val`` in the object store and remember its ref under ``key``."""
        self.map[key] = ray.put(val)

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` when absent.

        Bug fix: the original passed ``self.map.get(key, None)`` straight to
        ``ray.get``; ``ray.get(None)`` raises (it requires an ObjectRef), so
        a missing key crashed instead of returning ``None``.
        """
        ref = self.map.get(key)
        return ray.get(ref) if ref is not None else None

    def clear(self):
        """Drop all stored refs, letting the object store reclaim the values."""
        self.map.clear()
def launch_kvcache():
    """Create — or attach to, if already running — the detached KVCache actor."""
    actor_options = dict(
        name="KVCache",
        namespace="debug",
        lifetime="detached",
        get_if_exists=True,
    )
    return KVCache.options(**actor_options).remote()
def get_kvcache():
    """Look up the already-launched KVCache actor by its name and namespace."""
    actor_name = "KVCache"
    return ray.get_actor(actor_name, namespace="debug")
def train_func():
    """Per-worker Ray Train entry point.

    Each worker writes a payload dict (string, numpy array, dataclass
    instance) into the KVCache actor under its world-rank key.
    """
    context = ray.train.get_context()
    trial_name = context.get_trial_name()
    world_rank = context.get_world_rank()
    dataclass_obj = DummyDataclass(a="1", b=2)
    kvcache = get_kvcache()
    # Bug fix: block until the actor has applied the write. The original
    # fired put.remote() and returned immediately; Ray only guarantees task
    # ordering per caller, so the driver (a different caller) could issue
    # its get before this worker's put landed.
    ray.get(
        kvcache.put.remote(
            f"rank_{world_rank}",
            {"string": trial_name, "nparray": np.random.randn(2, 2), "dataclass": dataclass_obj},
        )
    )
if __name__ == "__main__":
    # Start (or attach to) the detached KVCache actor, run a 4-worker
    # TorchTrainer whose workers each write one entry, then read every
    # entry back from the driver and tear the actor down.
    kvcache = launch_kvcache()
    trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=4))
    trainer.fit()
    for rank in range(4):
        key = f"rank_{rank}"
        print(ray.get(kvcache.get.remote(key)))
    ray.kill(kvcache)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment