Skip to content

Instantly share code, notes, and snippets.

@sharadmv
Last active March 23, 2023 13:59
Show Gist options
  • Save sharadmv/be24b3107cf9b8bf027ea8e2f177882e to your computer and use it in GitHub Desktop.
Save sharadmv/be24b3107cf9b8bf027ea8e2f177882e to your computer and use it in GitHub Desktop.
JAX progress bar
"""Module for the JAX progress bar."""
from __future__ import annotations
import abc
import threading
import types
from typing import Any, Callable, List, Optional, Set, Tuple, Type
import jax
import rich.console
import rich.live
import rich.progress
# Signature of a progress-bar factory: `progress` calls the selected factory
# as factory(total, description, ordered); the factory registered in this
# module (`_rich_progress`) returns a `RichProgress`.
_ProgressCallable = Callable[[float, str, bool], Any]
class Progress(metaclass=abc.ABCMeta):
    """Abstract progress bar whose hooks run on the host via JAX callbacks.

    Concrete subclasses supply the three host-side hooks `_start`, `_update`
    and `_stop`. The public `start`, `update` and `stop` methods never call a
    hook directly: each one stages it with `jax.debug.callback`, passing the
    `ordered` flag given at construction, so the bar can be driven from code
    traced by `jax.jit` (including loop bodies). Used as a context manager,
    the bar is started on entry and stopped on exit; exceptions propagate.
    """

    def __init__(self, total: float, description: str, ordered: bool):
        """Store the bar's settings.

        Args:
          total: total value of the bar; stored for subclasses.
          description: text describing the bar; stored for subclasses.
          ordered: forwarded as `ordered=` to every `jax.debug.callback`.
        """
        self._description = description
        self._total = total
        self._ordered = ordered

    # ---- host-side hooks, implemented by subclasses ----

    @abc.abstractmethod
    def _start(self):
        """Run on the host when the bar starts."""

    @abc.abstractmethod
    def _update(self, value: float):
        """Run on the host with the bar's new `value`."""

    @abc.abstractmethod
    def _stop(self):
        """Run on the host when the bar stops."""

    # ---- staging wrappers, usable inside jitted code ----

    def _stage(self, hook, *args):
        """Stage `hook(*args)` as a `jax.debug.callback` with our ordering."""
        jax.debug.callback(hook, *args, ordered=self._ordered)

    def start(self):
        self._stage(self._start)

    def update(self, value):
        self._stage(self._update, value)

    def stop(self):
        self._stage(self._stop)

    def __enter__(self):
        self.start()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[types.TracebackType],
    ):
        # Returns None (falsy), so any in-flight exception keeps propagating.
        self.stop()
class RichProgress(Progress):
    """Progress bar implemented using the `rich` library.

    All instances share one `rich.live.Live` display (a class attribute).
    Starting a bar appends its own `rich.progress.Progress` to the shared
    `_active_progress` list, and the live display renders that list as a
    `rich.console.Group`, so nested bars appear stacked. The live display is
    started when the first bar starts and stopped when the last one stops.
    """

    # Class-level state shared by every RichProgress instance.
    _live = rich.live.Live()
    _active_progress: List[rich.progress.Progress] = []
    # Class-level lock guarding the shared state above. A per-instance lock
    # cannot protect `_live` / `_active_progress`, which all instances mutate.
    _lock = threading.Lock()

    def __init__(self, total: float, description: str, ordered: bool):
        super().__init__(total=total, description=description, ordered=ordered)
        # auto_refresh=False: this bar is drawn through the shared `_live`
        # group (see `_start`) and refreshed explicitly in `_update`.
        self._progress = rich.progress.Progress(auto_refresh=False)
        self._task = self._progress.add_task(self._description, total=total)

    def _start(self):
        """Add this bar to the shared display, starting it if it was idle."""
        with self._lock:
            if not self._active_progress:
                # First active bar: bring up the shared live display.
                self._live.start()
            self._active_progress.append(self._progress)
            self._live.update(rich.console.Group(*self._active_progress))

    def _update(self, value):
        """Set this bar's completed amount to `value` and redraw."""
        with self._lock:
            self._progress.update(self._task, completed=value)
            self._progress.refresh()

    def _stop(self):
        """Remove this bar from the shared display, stopping it if now empty."""
        with self._lock:
            # Remove this bar by identity rather than popping the last entry:
            # with ordered=False, JAX does not guarantee callback order, so
            # bars may stop in a different order than they started.
            if self._progress not in self._active_progress:
                return  # Already stopped, or never started.
            self._active_progress.remove(self._progress)
            self._live.update(rich.console.Group(*self._active_progress))
            if not self._active_progress:
                self._live.stop()
# Registry of progress-bar factories, stored as (priority, factory) pairs.
# Entries are added by `register_progress_bar`; `get_progress_bar` returns the
# factory with the highest priority.
_progress_bars: Set[Tuple[int, _ProgressCallable]] = set()
def get_progress_bar() -> _ProgressCallable:
    """Return the registered progress-bar factory with the highest priority.

    Among entries sharing the highest priority, the first one in the
    registry's (set) iteration order is returned.

    Raises:
      RuntimeError: if no factory has been registered.
    """
    if not _progress_bars:
        # An explicit raise rather than `assert False`: asserts are stripped
        # under `python -O`, which would make this silently return None.
        raise RuntimeError(
            'No progress bar registered; expected at least the default one.')
    _, progress_bar = max(_progress_bars, key=lambda entry: entry[0])
    return progress_bar
def register_progress_bar(progress_bar, priority):
    """Record `progress_bar` in the registry under `priority`.

    `get_progress_bar` selects the entry with the largest priority.
    Returns `progress_bar` unchanged, so the call can be chained.
    """
    entry = (priority, progress_bar)
    _progress_bars.add(entry)
    return entry[1]
def _rich_progress(total: float, description: str, ordered: bool):
    """Factory building a `RichProgress`; registered below as the default."""
    bar = RichProgress(total, description, ordered)
    return bar


# Priority -1 makes this the fallback: any factory registered with a higher
# priority is preferred by `get_progress_bar`.
register_progress_bar(_rich_progress, priority=-1)
def progress(total: float,
             description: str = 'Progress:',
             ordered: bool = True):
    """Create a progress bar with the highest-priority registered factory.

    Args:
      total: total value of the bar.
      description: text describing the bar.
      ordered: forwarded to the factory; the `Progress` base class passes it
        as `ordered=` to `jax.debug.callback`.

    Returns:
      Whatever the selected factory returns. With the default registration
      this is a `RichProgress`, usable as a context manager.
    """
    factory = get_progress_bar()
    return factory(total, description, ordered)
if __name__ == "__main__":
    import time

    from jax import lax

    # Loop bounds double as the bar totals so the two cannot drift apart.
    outer_steps = 20
    inner_steps = 10

    def _sleep(x):
        # Host-side delay so the bars visibly advance.
        time.sleep(0.1)
        return x

    @jax.jit
    def f(x):
        with progress(total=outer_steps, description="Outer",
                      ordered=True) as outer_pbar:
            def outer_body(i, x):
                with progress(total=inner_steps, description="Inner",
                              ordered=True) as inner_pbar:
                    def inner_body(j, x):
                        x = jax.pure_callback(_sleep, x, x)
                        # Report completed steps (j + 1), not the index, so
                        # the bar reaches inner_steps/inner_steps at the end.
                        inner_pbar.update(j + 1)
                        return x
                    x = lax.fori_loop(0, inner_steps, inner_body, x)
                # Likewise count this outer step as done only after its work.
                outer_pbar.update(i + 1)
                return x + 1
            return lax.fori_loop(0, outer_steps, outer_body, x)

    f(1.)
    # Block until all pending callbacks (including the final stops) have run.
    jax.effects_barrier()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment