Skip to content

Instantly share code, notes, and snippets.

@stephanie-wang
Created July 9, 2020 23:30
Show Gist options
  • Save stephanie-wang/4e4165c5af284b31470ea429bcf1f8ae to your computer and use it in GitHub Desktop.
Save stephanie-wang/4e4165c5af284b31470ea429bcf1f8ae to your computer and use it in GitHub Desktop.
import ray
import time
@ray.remote
def foo(arg):
    """No-op remote task: accepts ``arg``, ignores it, and returns ``None``.

    The body does no work, so timing calls to it measures the cost of
    submitting and completing a task rather than any computation.
    """
    return None
@ray.remote
class Cache:
    """Actor that remembers, per argument, what ``foo.remote(arg)`` returned.

    The first request for an argument launches ``foo`` remotely and stores
    the returned handle; later requests for the same argument hand back the
    stored handle without launching another task.
    """

    def __init__(self):
        # Maps each argument seen so far to the handle from foo.remote(arg).
        self.cache = {}

    def foo(self, arg):
        """Return the stored handle for ``arg``, launching ``foo`` on a miss."""
        handle = self.cache.get(arg)
        if handle is None:
            # Miss: submit the task once and keep its handle for reuse.
            handle = foo.remote(arg)
            self.cache[arg] = handle
        return handle
# Number of calls issued in every timed section; also printed in each label.
NUM_TASKS = 10000


def _report_throughput(description, submit):
    """Run ``submit()`` once and print how many tasks per second it achieved.

    Args:
        description: Text printed after the task count, e.g. ``"tasks"``.
        submit: Zero-argument callable that issues ``NUM_TASKS`` calls and
            blocks until their results are available.

    The printed line consists of ``NUM_TASKS``, ``description`` and
    ``NUM_TASKS / elapsed_seconds``, separated by single spaces.
    """
    # perf_counter() is monotonic and high resolution, so the interval is not
    # distorted by wall-clock adjustments the way time.time() can be.
    started = time.perf_counter()
    submit()
    elapsed = time.perf_counter() - started
    print(NUM_TASKS, description, NUM_TASKS / elapsed)


def _fetch_direct():
    """Submit ``foo`` directly for every argument and wait for all results."""
    ray.get([foo.remote(arg) for arg in range(NUM_TASKS)])


def _fetch_via_cache():
    """Request every argument from the single ``cache`` actor.

    The actor call yields the stored handle, so the inner ``ray.get`` collects
    the handles and the outer ``ray.get`` waits for the underlying results.
    """
    ray.get(ray.get([cache.foo.remote(arg) for arg in range(NUM_TASKS)]))


def _fetch_via_shards():
    """Request every argument from the ``caches`` actors, chosen by modulo."""
    handles = [
        caches[arg % len(caches)].foo.remote(arg) for arg in range(NUM_TASKS)
    ]
    ray.get(ray.get(handles))


ray.init()

# Untimed warm-up: run 100 tasks to completion before any measurement
# (presumably so worker start-up cost is not included -- confirm).
ray.get([foo.remote(0) for _ in range(100)])
_report_throughput("tasks", _fetch_direct)

cache = Cache.remote()
# Untimed blocking call so the actor has served a request before timing.
# Note that this stores argument 0, so the "miss" pass below finds exactly
# one argument already cached.
ray.get(cache.foo.remote(0))
_report_throughput("cached tasks, miss", _fetch_via_cache)
# Second pass over the same arguments: every one is already stored.
_report_throughput("cached tasks, hit", _fetch_via_cache)

# Drop the handle to the single cache actor before creating the shards.
del cache

caches = [Cache.remote() for _ in range(4)]
# Untimed blocking call on every shard, mirroring the single-actor case.
ray.get([shard.foo.remote(0) for shard in caches])
_report_throughput("sharded and cached tasks, miss", _fetch_via_shards)
_report_throughput("sharded and cached tasks, hit", _fetch_via_shards)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment