Created
November 25, 2021 03:45
-
-
Save lkevinzc/9fd77152e6c79d3511133c36d8f72391 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import time | |
import haiku as hk | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from absl import app, flags, logging | |
# Command-line configuration for the benchmark.
FLAGS = flags.FLAGS
flags.DEFINE_boolean(
    name='use_profiler',
    default=False,
    help='Use profiler for tracing.',
)
flags.DEFINE_integer(
    name='total_batch',
    default=64,
    help='Total batch size.',
)
def _forward(
    images,
    is_training: bool,
) -> jnp.ndarray:
    """Build a ResNet-50 (v2 variant, 1000 outputs) and apply it to `images`.

    Must be called inside a Haiku transform, since it constructs a Haiku
    module.

    Args:
        images: Input batch handed to the network unchanged.
        is_training: Forwarded to the network's `is_training` argument.

    Returns:
        The network output for `images`.
    """
    model = hk.nets.ResNet50(1000, resnet_v2=True)
    logits = model(images, is_training=is_training)
    return logits
# Wrap `_forward` with Haiku so it becomes a pure `init` / `apply` pair that
# threads parameters and mutable state explicitly.
forward = hk.transform_with_state(_forward)
def initial_state(rng: jnp.ndarray, batch):
    """Create the network's initial parameters and state.

    Args:
        rng: PRNG key passed to `forward.init`.
        batch: Example input used to trace the network during initialisation.

    Returns:
        The `(params, state)` tuple produced by `forward.init`.
    """
    # Initialisation has to run with is_training=True.
    return forward.init(rng, batch, is_training=True)
@jax.pmap
def step(train_state, data):
    """Run one pmapped forward pass in inference mode.

    Args:
        train_state: `(params, state)` pair as returned by `initial_state`.
        data: Input batch for the network.

    Returns:
        `(logits, state)` as returned by `forward.apply`.
    """
    net_params, net_state = train_state
    # No RNG key is supplied to apply (None), and is_training is False.
    logits, new_state = forward.apply(
        net_params, net_state, None, data, is_training=False)
    return logits, new_state
def _block_until_ready(tree):
    """Wait until every array in the pytree `tree` is computed; return it."""
    return jax.tree_util.tree_map(
        lambda leaf: leaf.block_until_ready(), tree)


def main(argv):
    """Benchmark 500 pmapped forward passes on an all-zeros batch.

    Initialises the network on every local device, runs one warm-up step so
    compilation is excluded from the measurement, then times 500 calls to
    `step`. When --use_profiler is set, a profiler trace is recorded for
    iterations 0 through 200.

    Args:
        argv: Positional command-line arguments left over after flag parsing;
            only the program name is accepted.

    Raises:
        app.UsageError: If extra positional arguments are given, or if
            --total_batch is smaller than the number of local devices.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    local_device_count = jax.local_device_count()
    logging.info(f"local device count: {local_device_count}")

    # The batch is split evenly across devices; integer division would
    # otherwise silently shrink (or empty) the benchmarked batch.
    per_device_batch = FLAGS.total_batch // local_device_count
    if per_device_batch < 1:
        raise app.UsageError(
            f"--total_batch={FLAGS.total_batch} is smaller than the number "
            f"of local devices ({local_device_count}).")
    if FLAGS.total_batch % local_device_count:
        logging.warning(
            "total_batch=%d is not divisible by %d devices; "
            "benchmarking with an effective batch of %d.",
            FLAGS.total_batch, local_device_count,
            per_device_batch * local_device_count)

    # Every device receives the same PRNG key.
    rng = jax.random.PRNGKey(42)
    rng = jnp.broadcast_to(rng, (local_device_count,) + rng.shape)

    # NOTE(review): the input is laid out as (..., 3, 224, 224); confirm this
    # matches the data format the Haiku ResNet expects.
    batch = np.zeros(
        [local_device_count, per_device_batch, 3, 224, 224],
        dtype=np.float32)
    logging.info(f"testing with data shape={batch.shape}")

    # JAX dispatches work asynchronously, so wait for the results before
    # reading the clock; otherwise only the dispatch time is measured.
    st = time.perf_counter()
    train_state = _block_until_ready(jax.pmap(initial_state)(rng, batch))
    logging.info(
        f"PMAP Parameters initialized in {time.perf_counter() - st} seconds")

    # Warm-up call: triggers compilation and must finish before timing starts.
    _block_until_ready(step(train_state, batch))
    logging.info("Step function initialized")

    if FLAGS.use_profiler:
        # NOTE(review): the trace directory says "res101" although the model
        # built in `_forward` is ResNet50; kept as-is to preserve the path.
        jax.profiler.start_trace("res101_pmap")

    num_iters = 500
    last_output = None
    start = time.perf_counter()
    for i in range(num_iters):
        last_output = step(train_state, batch)
        if FLAGS.use_profiler and i == 200:
            jax.profiler.stop_trace()
    # Wait for the final step to complete before stopping the timer.
    _block_until_ready(last_output)
    logging.info(f"iter {num_iters} takes: {time.perf_counter() - start} s")
# Script entry point: absl parses the flags, then invokes `main`.
if __name__ == '__main__':
    app.run(main)
# Example usage: python3 res50_pmap_on_host.py --total_batch 256
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment