Created: November 27, 2023 19:14
Save asehmi/6fc84f9a992fd272a88f6ff2f0d6705b to your computer and use it in GitHub Desktop.
Streamlit multi-threaded task execution with queues
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import random | |
from queue import Queue | |
import threading | |
import streamlit as st | |
from streamlit.runtime.scriptrunner import add_script_run_ctx | |
# Message history for each trace column; reporting_task appends to these
# lists and re-renders them into the matching placeholder below.
pre_msgs, result_msgs, post_msgs = [], [], []

# Three side-by-side columns, each holding a single placeholder that is
# overwritten with the full (newest-first) trace on every update.
c1, c2, c3 = st.columns(3)
pre_report, post_report, result_report = c1.empty(), c2.empty(), c3.empty()
def do_job(job, duration=3):
    """Simulate work for *job* by sleeping *duration* seconds.

    Returns a one-line summary naming the job and how long it ran.
    """
    time.sleep(duration)
    return 'Job: {} | Run time: {} secs'.format(job, duration)
def reporting_task(status_q: Queue):
    """Drain *status_q* and render each message into its trace column.

    Loops until the ``'STOP!'`` sentinel is received. Every other item is a
    dict. If it has a ``'pre'`` key it goes to the pre-task column; otherwise,
    if it has ``'post'``, it goes to the post-task column; otherwise its
    ``'result'`` value goes to the result column. The text is appended to the
    matching module-level list, and the whole list is re-rendered newest-first
    into the matching placeholder.
    """
    while True:
        status = status_q.get()
        if status == 'STOP!':
            status_q.task_done()
            return

        # Pick the (key, history list, placeholder, label) for this message.
        if 'pre' in status:
            key, trace, slot, label = 'pre', pre_msgs, pre_report, 'Pre task trace'
        elif 'post' in status:
            key, trace, slot, label = 'post', post_msgs, post_report, 'Post task trace'
        else:
            key, trace, slot, label = 'result', result_msgs, result_report, 'Task result trace'

        trace.append(status[key])
        slot.text_area(label, '\n'.join(reversed(trace)), height=300)
        status_q.task_done()
        # Short pause between renders; presumably throttles UI refreshes.
        time.sleep(0.1)
def job_task(job_q: Queue, status_q: Queue, worker=None):
    '''
    Worker loop; this runs in a thread context.

    Takes job specs of the form {'job': ..., 'data': ...} off job_q until the
    queue is exhausted. For each job it posts a {'pre': ...} message, then
    {'result': worker(job, data)}, then a {'post': ...} message to status_q.

    worker: callable(job, data) -> result. Defaults to do_job, where data is
    the sleep duration.
    '''
    from queue import Empty

    if worker is None:
        worker = do_job
    while True:
        # get_nowait() instead of `while not job_q.empty(): job_q.get()`.
        # Several workers share job_q, so another thread can take the last
        # job between the empty() check and get(). That would leave this
        # thread blocked in get() forever, and the caller's join() with it.
        try:
            job_spec = job_q.get_nowait()
        except Empty:
            return
        try:
            job = job_spec['job']
            data = job_spec['data']
            status_q.put({'pre': f'Starting job: {job}'})
            result = worker(job, data)
            status_q.put({'result': result})
            status_q.put({'post': f'Finished job: {job}'})
        finally:
            # Keep the queue's unfinished-task count accurate even if the job raised.
            job_q.task_done()
def run_in_parallel(num_jobs: range, task, data, n_threads=10):
    '''
    Fan the jobs in num_jobs out across n_threads worker threads and block until done.

    Not running in a thread; this runs in the main script context and spawns
    the threads. Each job id is queued with a random duration in [1, data].
    Every worker thread runs task(job_q, status_q), and a daemon thread runs
    reporting_task to render status messages. All spawned threads get the
    Streamlit script-run context attached before they start, so they can
    update the UI.
    '''
    def _spawn(target, args, **thread_kwargs):
        # Create the thread, attach the Streamlit script context, then start it.
        spawned = threading.Thread(target=target, args=args, **thread_kwargs)
        add_script_run_ctx(spawned)
        spawned.start()
        return spawned

    job_q = Queue()
    for job_id in num_jobs:
        job_q.put({'job': job_id, 'data': random.randint(1, data)})
    status_q = Queue()

    workers = [_spawn(task, (job_q, status_q)) for _ in range(n_threads)]
    reporter = _spawn(reporting_task, (status_q,), daemon=True)

    for worker in workers:
        worker.join()

    # Summary banner, then the sentinel that tells the reporter to exit.
    rule = '--------------------------------'
    for line in (rule, f'Completed {len(num_jobs)} jobs!', rule):
        status_q.put({'post': line})
    status_q.put('STOP!')
    reporter.join()
# Input form: nothing runs until the user presses Submit.
with st.form('form'):
    num_jobs = st.number_input('Enter number of jobs to run', value=50, min_value=1, max_value=1000)
    job_data = st.number_input('Enter job max allowed duration', value=3, min_value=1, max_value=10)
    num_threads = st.number_input('Enter number of threads to run', value=5, min_value=1, max_value=100)
    submitted = st.form_submit_button('Submit', type='primary')
    if submitted:
        # Job ids are 1-based.
        job_ids = range(1, num_jobs + 1)
        run_in_parallel(job_ids, job_task, job_data, n_threads=num_threads)
Author: asehmi (commented Nov 27, 2023)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.