Skip to content

Instantly share code, notes, and snippets.

@kvchen
Created December 2, 2016 06:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kvchen/ea1212d0af6fee3d7dda797914f0cdac to your computer and use it in GitHub Desktop.
Save kvchen/ea1212d0af6fee3d7dda797914f0cdac to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import click
import json
import mkl
import numpy as np
import ray
from ray.array.distributed import core as rdc
from timeit import default_timer as timer
def benchmark_matmul(block_size, matrix_size, seed=0):
    """Time one distributed square-matrix product using Ray.

    Two ``matrix_size`` x ``matrix_size`` matrices are drawn from
    ``np.random.rand`` after seeding NumPy's global RNG with ``seed``.
    Generation happens before the timer starts. The timed region covers:
    converting both matrices to distributed arrays, the distributed dot
    product, and assembling the product back into one array.

    Args:
        block_size: value assigned to ``rdc.BLOCK_SIZE`` before distributing.
        matrix_size: width and height of each generated square matrix.
        seed: seed passed to ``np.random.seed`` (mutates NumPy's global RNG).

    Returns:
        Elapsed time as measured by ``timeit.default_timer``.
    """
    np.random.seed(seed)
    # NOTE(review): this sets the attribute in the driver process; the
    # conversions/products below run via ``.remote`` -- confirm that remote
    # workers actually observe this block size.
    rdc.BLOCK_SIZE = block_size

    shape = (matrix_size, matrix_size)
    lhs, rhs = np.random.rand(*shape), np.random.rand(*shape)

    t0 = timer()
    # ray.get blocks until each remote call has finished, so every stage
    # below is complete before the clock is read again.
    dist_lhs, dist_rhs = ray.get(
        [rdc.numpy_to_dist.remote(arr) for arr in (lhs, rhs)]
    )
    product = ray.get(rdc.dot.remote(dist_lhs, dist_rhs))
    ray.get(rdc.assemble.remote(product))
    return timer() - t0
@click.command()
@click.argument('output', nargs=1, type=str, default="output.json")
@click.option('--block-size', type=int, default=100,
              help='the block size used when distributing matrices')
@click.option('--matrix-size', type=int, default=1000,
              help='the width and height of each generated square matrix')
@click.option('--num-workers', type=int, default=1,
              help='how many Ray workers to assign to the task')
@click.option('--num-threads', type=int, default=1,
              help='how many threads to use in MKL BLAS')
def main(output, block_size, matrix_size, num_workers, num_threads):
    """Run one distributed matmul benchmark and write its timing as JSON.

    Starts a local Ray instance with ``num_workers`` workers, registers
    ``rdc.DistArray`` with Ray, limits MKL to ``num_threads`` threads, times
    ``benchmark_matmul`` and writes the configuration plus the elapsed time
    to the file ``output``.

    Args:
        output: path of the JSON file to write (overwritten if it exists).
        block_size: forwarded to ``benchmark_matmul``.
        matrix_size: forwarded to ``benchmark_matmul``.
        num_workers: number of workers passed to ``ray.init``.
        num_threads: MKL BLAS thread count.
    """
    import os

    # Export the thread count before ray.init() so worker processes started
    # by Ray can pick it up; mkl.set_num_threads below only configures the
    # MKL runtime of this (driver) process, while the BLAS work is done by
    # the remote calls inside benchmark_matmul.
    # NOTE(review): assumes Ray launches workers with the driver's
    # environment -- confirm for the Ray version in use.
    os.environ['MKL_NUM_THREADS'] = str(num_threads)

    ray.init(start_ray_local=True, num_workers=num_workers)
    ray.register_class(rdc.DistArray)
    # Limit MKL BLAS in this process to the requested number of threads.
    mkl.set_num_threads(num_threads)

    elapsed = benchmark_matmul(block_size, matrix_size)
    # Record every knob that influences the measurement, including the
    # thread count, so results from different runs can be compared.
    results = {
        'block_size': block_size,
        'matrix_size': matrix_size,
        'num_workers': num_workers,
        'num_threads': num_threads,
        'time': elapsed,
    }
    with open(output, 'w') as outfile:
        json.dump(results, outfile)


if __name__ == "__main__":
    # click parses the command line and invokes main with the options.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment