@asehmi
Created November 27, 2023 19:14
Streamlit multi-threaded task execution, with queues
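A small Streamlit app that fans a batch of jobs out to a pool of worker threads through a job queue, while a single reporting thread drains a status queue and renders three rolling trace panels (pre-task, task result, post-task) in Streamlit columns. Each thread is registered with the current script run context via add_script_run_ctx so it can update Streamlit placeholders from outside the main script thread.
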
import time
import random
from queue import Queue
import threading

import streamlit as st
from streamlit.runtime.scriptrunner import add_script_run_ctx

pre_msgs = []
result_msgs = []
post_msgs = []

c1, c2, c3 = st.columns(3)
pre_report = c1.empty()
post_report = c2.empty()
result_report = c3.empty()

def do_job(job, duration=3):
    time.sleep(duration)
    return f'Job: {job} | Run time: {duration} secs'

def reporting_task(status_q: Queue):
    # Single consumer of the status queue: appends each message to the right
    # buffer and re-renders the corresponding trace panel (newest first).
    while True:
        status = status_q.get()
        if status == 'STOP!':
            status_q.task_done()
            break
        if 'pre' in status:
            pre_msgs.append(status['pre'])
            pre_report.text_area('Pre task trace', '\n'.join(pre_msgs[::-1]), height=300)
        elif 'post' in status:
            post_msgs.append(status['post'])
            post_report.text_area('Post task trace', '\n'.join(post_msgs[::-1]), height=300)
        else:
            result_msgs.append(status['result'])
            result_report.text_area('Task result trace', '\n'.join(result_msgs[::-1]), height=300)
        status_q.task_done()
        time.sleep(0.1)

def job_task(job_q: Queue, status_q: Queue):
    '''
    Runs in a worker-thread context: pulls job specs off the job queue and
    reports progress and results on the status queue.
    '''
    while not job_q.empty():
        job_spec = job_q.get()
        job = job_spec['job']
        data = job_spec['data']
        status_q.put({'pre': f'Starting job: {job}'})
        result = do_job(job, data)
        status_q.put({'result': result})
        status_q.put({'post': f'Finished job: {job}'})
        job_q.task_done()

def run_in_parallel(num_jobs: range, task, data, n_threads=10):
    '''
    Runs in the main script context (not in a thread); spawns the worker and
    reporting threads and waits for them all to finish.
    '''
    job_q = Queue()
    for n in num_jobs:
        duration = random.randint(1, data)
        job_q.put({'job': n, 'data': duration})
    status_q = Queue()
    # Create worker threads and add them to the Streamlit script run context before starting
    job_threads = [threading.Thread(target=task, args=(job_q, status_q)) for _ in range(n_threads)]
    for thread in job_threads:
        add_script_run_ctx(thread)
        thread.start()
    reporting_thread = threading.Thread(target=reporting_task, args=(status_q,), daemon=True)
    add_script_run_ctx(reporting_thread)
    reporting_thread.start()
    for thread in job_threads:
        thread.join()
    status_q.put({'post': '--------------------------------'})
    status_q.put({'post': f'Completed {len(num_jobs)} jobs!'})
    status_q.put({'post': '--------------------------------'})
    status_q.put('STOP!')
    reporting_thread.join()

with st.form('form'):
    num_jobs = st.number_input('Enter number of jobs to run', value=50, min_value=1, max_value=1000)
    job_data = st.number_input('Enter job max allowed duration', value=3, min_value=1, max_value=10)
    num_threads = st.number_input('Enter number of threads to run', value=5, min_value=1, max_value=100)
    if st.form_submit_button('Submit', type='primary'):
        run_in_parallel(range(1, num_jobs + 1), job_task, job_data, n_threads=num_threads)
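
To try this out, save the gist as a Python file (the name is up to you) and launch it with the Streamlit CLI, e.g. streamlit run app.py. Submitting the form fills the job queue and starts the worker pool; because the reporting thread is the only writer to the three text areas, UI updates stay serialized even while many job threads run concurrently.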
asehmi commented Nov 27, 2023

[Screenshot of the running app attached in the original comment.]