Last active
March 23, 2023 13:59
-
-
Save sharadmv/be24b3107cf9b8bf027ea8e2f177882e to your computer and use it in GitHub Desktop.
JAX progress bar
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Module for the JAX progress bar.""" | |
from __future__ import annotations | |
import abc | |
import threading | |
import types | |
from typing import Any, Callable, List, Optional, Set, Tuple, Type | |
import jax | |
import rich.console | |
import rich.live | |
import rich.progress | |
# Signature of a progress-bar factory: it is called as
# factory(total, description, ordered) and returns a progress bar object.
_ProgressCallable = Callable[[float, str, bool], Any]


class Progress(metaclass=abc.ABCMeta):
    """Abstract base for progress bars usable inside JAX-traced code.

    Subclasses implement the three host-side hooks `_start`, `_update` and
    `_stop`. The public `start` / `update` / `stop` methods do not invoke
    those hooks directly: they stage them through `jax.debug.callback`,
    forwarding this bar's `ordered` flag, so they can be called from
    `jax.jit`-compiled functions and loop bodies.

    Instances are context managers: entering calls `start()` and returns the
    bar; exiting calls `stop()` and does not suppress any exception.
    """

    def __init__(self, total: float, description: str, ordered: bool):
        self._ordered = ordered
        self._description = description
        self._total = total

    # Hooks supplied by concrete subclasses.

    @abc.abstractmethod
    def _start(self):
        """Host-side hook run when the bar is started."""

    @abc.abstractmethod
    def _update(self, value: float):
        """Host-side hook run with the current progress `value`."""

    @abc.abstractmethod
    def _stop(self):
        """Host-side hook run when the bar is stopped."""

    # Public API: every call is staged as a JAX debug callback.

    def _emit(self, hook, *args):
        """Stage `hook(*args)` via `jax.debug.callback` with this bar's ordering."""
        jax.debug.callback(hook, *args, ordered=self._ordered)

    def start(self):
        self._emit(self._start)

    def update(self, value):
        self._emit(self._update, value)

    def stop(self):
        self._emit(self._stop)

    def __enter__(self):
        self.start()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[types.TracebackType],
    ):
        self.stop()
class RichProgress(Progress):
    """Progress bar implemented using the `rich` library.

    All instances share one `rich.live.Live` display (`_live`) and one list of
    currently active `rich.progress.Progress` objects (`_active_progress`).
    The live display renders every active bar as a `rich.console.Group`.
    It is started when the first bar starts and stopped once no bar
    remains active.
    """

    # Class-level state shared by every RichProgress instance.
    _live = rich.live.Live()
    _active_progress: List[rich.progress.Progress] = []
    # Guards `_live` / `_active_progress`. The per-instance `self._lock` only
    # serialises updates of a single bar, not the shared list.
    _shared_lock = threading.Lock()

    def __init__(self, total: float, description: str, ordered: bool):
        super().__init__(total=total, description=description, ordered=ordered)
        # auto_refresh=False: this Progress is displayed through the shared
        # `_live` display (see `_start`) rather than refreshing on its own.
        self._progress = rich.progress.Progress(auto_refresh=False)
        self._task = self._progress.add_task(self._description, total=total)
        self._lock = threading.Lock()

    def _render_active(self):
        """Point the shared live display at the currently active bars."""
        self._live.update(rich.console.Group(*self._active_progress))

    def _start(self):
        with self._shared_lock:
            if not self._active_progress:
                self._live.start()
            self._active_progress.append(self._progress)
            self._render_active()

    def _update(self, value):
        with self._lock:
            # `value` may arrive from jax.debug.callback as a 0-d array;
            # hand rich a plain float.
            self._progress.update(self._task, completed=float(value))
            self._progress.refresh()

    def _stop(self):
        with self._shared_lock:
            # Remove *this* bar, not merely the most recently started one:
            # bars need not stop in LIFO order (e.g. unordered callbacks or
            # independent bars), and list.pop() would drop the wrong bar.
            if self._progress in self._active_progress:
                self._active_progress.remove(self._progress)
            self._render_active()
            if not self._active_progress:
                self._live.stop()
# Registry of progress-bar factories as (priority, factory) pairs; filled by
# `register_progress_bar` and queried by `get_progress_bar`.
_progress_bars: Set[Tuple[int, _ProgressCallable]] = set()


def get_progress_bar() -> _ProgressCallable:
    """Return the registered progress-bar factory with the highest priority.

    When several factories share the highest priority, the first one met
    while iterating the registry set is returned (the same choice a stable
    sort by descending priority would make).

    Raises:
        RuntimeError: if no factory is registered. The module registers a
            rich-based default, so this signals a broken setup.
    """
    if not _progress_bars:
        # An explicit raise, unlike `assert`, is not stripped under
        # `python -O` (which would make this silently return None).
        raise RuntimeError('Should have hit default progress_bar')
    # One O(n) pass instead of sorting the registry to read its first item.
    _, progress_bar = max(_progress_bars, key=lambda entry: entry[0])
    return progress_bar
def register_progress_bar(progress_bar, priority):
    """Add `progress_bar` to the factory registry under `priority`.

    `get_progress_bar` selects the registered factory with the largest
    priority. The factory itself is returned unchanged, so the caller keeps
    a direct reference to what was registered.
    """
    entry = (priority, progress_bar)
    _progress_bars.add(entry)
    return progress_bar
def _rich_progress(total: float, description: str, ordered: bool):
    """Default factory: build a `RichProgress` with the given settings."""
    bar = RichProgress(total=total, description=description, ordered=ordered)
    return bar


# Install the rich-based bar as the fallback factory. It is registered with
# priority -1, so `get_progress_bar` prefers any factory registered with a
# higher priority.
register_progress_bar(progress_bar=_rich_progress, priority=-1)
def progress(total: float,
             description: str = 'Progress:',
             ordered: bool = True):
    """Create a progress bar from the highest-priority registered factory.

    Args:
      total: the value at which the bar counts as complete.
      description: label passed to the factory for the bar.
      ordered: forwarded to the factory; `Progress` subclasses hand it to
        `jax.debug.callback(..., ordered=...)`.

    Returns:
      Whatever the selected factory returns (a `RichProgress` when only the
      module's default factory is registered).
    """
    factory = get_progress_bar()
    return factory(total, description, ordered)
if __name__ == "__main__":
    # Demo: a jitted function drives two nested progress bars. Each outer
    # iteration runs an inner loop whose steps sleep briefly on the host.
    import time

    from jax import lax

    outer_steps = 20
    inner_steps = 10

    @jax.jit
    def f(x):
        with progress(
            total=outer_steps, description="Outer", ordered=True
        ) as outer_pbar:

            def outer_body(i, x):
                outer_pbar.update(i)
                with progress(
                    total=inner_steps, description="Inner", ordered=True
                ) as inner_pbar:

                    def inner_body(j, x):
                        inner_pbar.update(j)

                        def _sleep(x):
                            time.sleep(0.1)
                            return x

                        return jax.pure_callback(_sleep, x, x)

                    x = lax.fori_loop(0, inner_steps, inner_body, x)
                    # update(j) is issued before step j's work, so the loop
                    # alone tops out at inner_steps - 1; report completion.
                    inner_pbar.update(inner_steps)
                return x + 1

            x = lax.fori_loop(0, outer_steps, outer_body, x)
            # Same off-by-one as above for the outer bar.
            outer_pbar.update(outer_steps)
            return x

    f(1.)
    # Wait for all staged callbacks (including the final stop()) to finish
    # before the interpreter exits.
    jax.effects_barrier()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment