Skip to content

Instantly share code, notes, and snippets.

@lkevinzc
Created November 25, 2021 03:45
Show Gist options
  • Save lkevinzc/9fd77152e6c79d3511133c36d8f72391 to your computer and use it in GitHub Desktop.
Save lkevinzc/9fd77152e6c79d3511133c36d8f72391 to your computer and use it in GitHub Desktop.
import functools
import time
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from absl import app, flags, logging
# Command-line flags for the benchmark (parsed by absl's app.run).
flags.DEFINE_boolean(
    name='use_profiler',
    default=False,
    help='Use profiler for tracing.',
)
flags.DEFINE_integer(
    name='total_batch',
    default=64,
    help='Total batch size.',
)
# Global flag registry; values are read as FLAGS.<name> after parsing.
FLAGS = flags.FLAGS
def _forward(
    images,
    is_training: bool,
) -> jnp.ndarray:
    """Apply a ResNet-50 (v2 variant) to ``images`` and return its output.

    Args:
      images: input batch handed straight to ``hk.nets.ResNet50``.
        NOTE(review): Haiku convolutions default to NHWC, so this presumably
        expects (batch, height, width, channels) — confirm against callers.
      is_training: forwarded unchanged to the network's ``is_training``
        argument.

    Returns:
      Whatever the ResNet-50 module returns (its first constructor argument,
      1000, is the number of output classes).
    """
    model = hk.nets.ResNet50(1000, resnet_v2=True)
    outputs = model(images, is_training=is_training)
    return outputs


# Wrap the impure Haiku function into a pure (init, apply) pair that threads
# parameters and mutable network state explicitly.
forward = hk.transform_with_state(_forward)
def initial_state(rng: jnp.ndarray, batch):
    """Create the network's initial parameters and state.

    Args:
      rng: PRNG key passed to ``forward.init``.
      batch: example input passed to ``forward.init`` to build the network.

    Returns:
      A ``(params, state)`` tuple, as produced by ``forward.init``.
    """
    # Initialisation must run with is_training=True.
    variables, net_state = forward.init(rng, batch, is_training=True)
    return variables, net_state
@jax.pmap
def step(train_state, data):
    """Run one forward (inference) pass of the network on every local device.

    Mapped with ``jax.pmap``: each argument carries a leading axis that is
    split across the local devices (one slice per device).

    Args:
      train_state: ``(params, state)`` pair with a leading per-device axis.
      data: input batch with a leading per-device axis.

    Returns:
      ``(logits, state)`` as returned by ``forward.apply``.
    """
    params, state = train_state
    # Inference only (is_training=False); `None` is passed as the apply rng
    # because this pass is not expected to need randomness.
    logits, state = forward.apply(params, state, None, data, is_training=False)
    return logits, state
def _block_until_ready(tree):
    """Wait until every array in the pytree ``tree`` is computed; return it.

    JAX dispatches work asynchronously, so wall-clock timings taken without
    this only measure dispatch, not execution.
    """
    return jax.tree_util.tree_map(lambda x: x.block_until_ready(), tree)


def main(argv):
    """Benchmark pmapped ResNet-50 inference on all local devices.

    Initialises the network on every local device, runs one warm-up step
    (which also compiles ``step``), then times ``num_iters`` further steps.

    Args:
      argv: positional command-line arguments left after absl flag parsing.

    Raises:
      app.UsageError: on extra positional arguments, or when
        ``--total_batch`` is smaller than the number of local devices.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    num_iters = 500

    local_device_count = jax.local_device_count()
    logging.info(f"local device count: {local_device_count}")

    per_device_batch = FLAGS.total_batch // local_device_count
    if per_device_batch < 1:
        # Otherwise every device would get an empty batch and the benchmark
        # would time nothing.
        raise app.UsageError(
            f'--total_batch ({FLAGS.total_batch}) must be at least the '
            f'number of local devices ({local_device_count}).')
    if FLAGS.total_batch % local_device_count:
        logging.warning(
            f"total_batch={FLAGS.total_batch} is not divisible by "
            f"{local_device_count} devices; using "
            f"{per_device_batch * local_device_count} examples per step")

    # The same key on every device: pmap splits the leading axis, so each
    # device initialises from an identical key.
    rng = jax.random.PRNGKey(42)
    rng = jnp.broadcast_to(rng, (local_device_count, ) + rng.shape)
    # Haiku convolutions default to NHWC, so images are (..., H, W, C).
    batch = np.zeros(
        [local_device_count, per_device_batch, 224, 224, 3],
        dtype=np.float32)
    logging.info(f"testing with data shape={batch.shape}")

    st = time.perf_counter()
    # Block so the reported time includes execution, not just dispatch.
    train_state = _block_until_ready(jax.pmap(initial_state)(rng, batch))
    logging.info(
        f"PMAP Parameters initialized in {time.perf_counter() - st} seconds")

    # Warm-up: compiles `step` and finishes before the timed loop starts.
    _block_until_ready(step(train_state, batch))
    logging.info("Step function initialized")

    if FLAGS.use_profiler:
        # NOTE(review): trace name says res101 although the model is ResNet-50.
        jax.profiler.start_trace("res101_pmap")
    start = time.perf_counter()
    out = None
    for i in range(num_iters):
        out = step(train_state, batch)
        if FLAGS.use_profiler and i == 200:
            # Let in-flight work finish so it appears in the trace.
            _block_until_ready(out)
            jax.profiler.stop_trace()
    # Wait for the last dispatched step so the measurement covers execution.
    _block_until_ready(out)
    logging.info(f"iter {num_iters} takes: {time.perf_counter() - start} s")
if __name__ == '__main__':
    # absl's app.run parses the command-line flags, then calls main(argv).
    app.run(main)
# Example invocation:
#   python3 res50_pmap_on_host.py --total_batch 256
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment