Skip to content

Instantly share code, notes, and snippets.

@g-simmons
Created February 15, 2023 21:59
Show Gist options
  • Save g-simmons/d01db39bbb6aa69452db8e4c66efc1b0 to your computer and use it in GitHub Desktop.
Save g-simmons/d01db39bbb6aa69452db8e4c66efc1b0 to your computer and use it in GitHub Desktop.
Rich progress bars for Joblib parallel tasks
import numpy as np
from joblib import Parallel, delayed
from threading import Thread
from rich.progress import Progress, BarColumn, TimeRemainingColumn, TextColumn
from rich.console import Console
from rich.live import Live
import time
# How many demo tasks run in parallel; each one owns one slot of
# progress_array and one progress bar.
num_tasks = 4

# File-backed float32 array holding one completion fraction per task.
# mode="w+" creates (or overwrites) "progress.mmap2" in the current directory,
# so every slot starts at 0. The intent is that worker processes write into the
# same file the display thread reads — this relies on joblib's memmap handling.
progress_array = np.memmap(
    filename="progress.mmap2",
    mode="w+",
    shape=(num_tasks,),
    dtype=np.float32,
)
def perform_task(task_idx, progress_array, *, n_steps=100, step_seconds=0.1):
    """Simulate a task and publish its progress into ``progress_array[task_idx]``.

    The work is split into ``n_steps`` steps. Each step sleeps for
    ``step_seconds`` as a stand-in for real work. It then stores the fraction
    of steps completed so far. The slot always ends at exactly 1, which is the
    value the display uses to detect completion.

    Args:
        task_idx: Index of the slot this task reports into.
        progress_array: Indexable, writable container of per-task fractions.
        n_steps: Number of simulated work steps. 0 marks the task done at once.
        step_seconds: Simulated duration of each step, in seconds.
    """
    for step in range(n_steps):
        # Do some work here (simulated).
        time.sleep(step_seconds)
        # Report the steps completed so far (step + 1), not the 0-based index.
        # Otherwise the bar would lag one step behind the work actually done.
        progress_array[task_idx] = (step + 1) / n_steps
    # Mark completion explicitly. This also covers n_steps == 0, where the
    # loop never runs.
    progress_array[task_idx] = 1
# Rich console that update_progress_bar receives by default through its
# `console` parameter.
console: Console = Console()
def _bar_status(fraction):
    """Return the status label shown next to a bar at completion ``fraction``."""
    if fraction >= 1:
        return "done"
    if fraction > 0:
        return "running"
    return "pending"


def _sync_bars(progress, tasks, progress_array):
    """Copy each task's fraction from ``progress_array`` onto its Rich bar."""
    for i, task in enumerate(tasks):
        fraction = float(progress_array[i])
        progress.update(
            task,
            # round() rather than int(): float32 values such as 0.29 give
            # 28.99... when scaled, and int() would truncate that to 28.
            completed=round(fraction * 100),
            status=_bar_status(fraction),
        )


def update_progress_bar(
    progress_array=progress_array,
    console=console,
    num_tasks=num_tasks,
    stop_event=None,
    poll_interval=0.1,
):
    """Draw one Rich progress bar per task until every task reports completion.

    Polls ``progress_array``, which the tasks fill with completion fractions
    (1 means done). Each poll mirrors those fractions onto the bars. The
    script calls this on a background thread.

    Args:
        progress_array: Per-task completion fractions; entry ``i`` drives bar
            ``i``. Only the first ``num_tasks`` entries are watched.
        console: Rich console the bars are drawn on.
        num_tasks: Number of bars to show.
        stop_event: Optional object with ``is_set()``, e.g. a
            ``threading.Event``. Once set, polling stops early. This lets a
            caller end the display if a task dies before reporting 1.
        poll_interval: Seconds to wait between polls.
    """
    # Progress runs its own Live display, so no extra Live wrapper is needed.
    # (A second one would compete for the same terminal.) Passing console here
    # makes the bars actually render on the requested console.
    with Progress(
        BarColumn(),
        TextColumn("[bold green]{task.fields[status]}"),
        TextColumn("[bold blue]{task.fields[name]}"),
        TimeRemainingColumn(),
        console=console,
        transient=True,
        refresh_per_second=4,
    ) as progress:
        tasks = [
            progress.add_task(
                description=f"Task {i}",
                name=f"Task {i}",
                status="pending",
                total=100,
            )
            for i in range(num_tasks)
        ]
        while not np.all(progress_array[:num_tasks] >= 1):
            if stop_event is not None and stop_event.is_set():
                break
            _sync_bars(progress, tasks, progress_array)
            time.sleep(poll_interval)
        # Tasks that finished during the last sleep were never synced inside
        # the loop. Sync once more so the final state is what Rich renders
        # when the display stops.
        _sync_bars(progress, tasks, progress_array)
# Launch the progress-bar renderer in a background thread. daemon=True lets the
# process exit if a worker raises. In that case Parallel re-raises here, and a
# non-daemon renderer would keep polling an array that never reaches 1.
progress_thread = Thread(
    target=update_progress_bar,
    args=[progress_array, console, num_tasks],
    daemon=True,
)
progress_thread.start()
# Run one worker per task so every bar advances concurrently. joblib treats a
# negative n_jobs as cpu_count() + 1 + n_jobs, so a value like -8 collapses to
# a single worker on machines with 8 or fewer cores.
Parallel(n_jobs=num_tasks, backend="loky")(
    delayed(perform_task)(i, progress_array) for i in range(num_tasks)
)
# Every slot is now 1, so the renderer's loop ends. Wait for it to finish its
# last render and release the terminal before the interpreter exits.
progress_thread.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment