yaroslavvb / sharded_ps_benchmark.py
Example of a local cluster with multiple workers/training loops and a sharded parameter server
#!/usr/bin/env python
# Benchmark transferring data, part of troubleshooting https://github.com/tensorflow/tensorflow/issues/6116
#
# Runs `a` independent workers communicating with `b` parameter shards
# (see the cluster-spec sketch after the imports below). Each worker adds
# to variables stored on the parameter server as fast as possible.
#
# Measured on a MacBook:
#   ps=1: 1.6 GB/s
#   ps=2: 2.6 GB/s
import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np
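
# For illustration only, a minimal sketch (hypothetical; not necessarily the
# spec this script constructs later): with a=2 workers and b=2 parameter
# shards, the in-process cluster runs every task on localhost, each on its
# own port.
example_cluster_spec = tf.train.ClusterSpec({  # hypothetical example name
    'worker': ['localhost:12222', 'localhost:12223'],
    'ps': ['localhost:12224', 'localhost:12225'],
})
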
# Define a custom py_func that also takes a gradient function as an argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
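    # A minimal sketch of the remaining body, assuming the usual TF1
    # gradient_override_map trick: register `grad` under the unique random
    # name, then build the PyFunc op with its gradient remapped to it.
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)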