@allenwang28
Last active January 24, 2024 21:10
from typing import Any, Callable, List
import threading
import queue
import time
import uuid

class BatcherTask(threading.Thread):
    def __init__(self, request_queue: queue.Queue, batch_queue: queue.Queue,
                 target_batch_size: int, timeout_in_s: int):
        """
        Initializes the BatcherTask thread.

        Args:
            request_queue (queue.Queue): Queue for incoming requests.
            batch_queue (queue.Queue): Queue for outgoing batches.
            target_batch_size (int): Target number of requests per batch.
            timeout_in_s (int): Timeout in seconds for batch assembly.
        """
        super().__init__()
        self.request_queue = request_queue
        self.batch_queue = batch_queue
        self.target_batch_size = target_batch_size
        self.timeout_in_s = timeout_in_s
        self.stop_event = threading.Event()

    def run(self) -> None:
        """
        Run method for the thread. Continuously assembles batches from incoming requests.
        """
        while not self.stop_event.is_set():
            start_time = time.time()
            batch = []
            id_batch = []
            while len(batch) < self.target_batch_size:
                # Only wait for the time remaining in this batch window.
                # Computing the remainder up front also avoids passing a
                # negative timeout to Queue.get(), which raises ValueError.
                remaining = self.timeout_in_s - (time.time() - start_time)
                if remaining <= 0:
                    break
                try:
                    request_data, request_id = self.request_queue.get(timeout=remaining)
                    batch.append(request_data)
                    id_batch.append(request_id)
                except queue.Empty:
                    break
            if batch:
                self.batch_queue.put((batch, id_batch))

    def stop(self) -> None:
        """
        Stops the thread.
        """
        self.stop_event.set()

class BatchProcessorTask(threading.Thread):
    def __init__(self, batch_queue: queue.Queue, result_map: 'BatcherResults',
                 batch_processor: Callable[[List[Any]], List[Any]]):
        """
        Initializes the BatchProcessorTask thread.

        Args:
            batch_queue (queue.Queue): Queue containing batches to process.
            result_map (BatcherResults): Shared object for storing results.
            batch_processor (Callable): Function to process each batch.
        """
        super().__init__()
        self.batch_queue = batch_queue
        self.result_map = result_map
        self.batch_processor = batch_processor
        self.stop_event = threading.Event()

    def run(self) -> None:
        """
        Run method for the thread. Processes each batch and stores the results.
        """
        while not self.stop_event.is_set():
            try:
                # Use a timeout so the loop periodically re-checks stop_event;
                # a fully blocking get() would never observe stop() and the
                # thread could hang forever on an empty queue.
                batch, id_batch = self.batch_queue.get(timeout=1)
            except queue.Empty:
                continue
            results = self.batch_processor(batch)
            for result, request_id in zip(results, id_batch):
                self.result_map.set(request_id, result)

    def stop(self) -> None:
        """
        Stops the thread.
        """
        self.stop_event.set()

class BatcherResults:
    def __init__(self):
        """
        Initializes a thread-safe storage for batch results.
        """
        self._d = {}
        self._lock = threading.Lock()

    def set(self, k: Any, v: Any) -> None:
        """
        Sets a result for a specific request ID.

        Args:
            k (Any): Request ID.
            v (Any): Result to be stored.
        """
        with self._lock:
            self._d[k] = v

    def get(self, k: Any) -> Any:
        """
        Retrieves a result for a given request ID.

        Args:
            k (Any): Request ID.

        Returns:
            Any: Result associated with the request ID.
        """
        with self._lock:
            return self._d.get(k)

    def print(self) -> None:
        """
        Prints the current state of the results storage.
        """
        with self._lock:
            print(self._d)

class RequestBatcher:
    def __init__(self, batch_handler_fn: Callable[[List[Any]], List[Any]],
                 batch_size: int, batch_timeout_s: int):
        """
        Initializes the RequestBatcher, which is responsible for managing the
        lifecycle of batch processing, from receiving individual requests to
        returning processed results.

        Args:
            batch_handler_fn (Callable): The function that will process each batch of requests.
            batch_size (int): The target number of requests to include in each batch.
            batch_timeout_s (int): The maximum time to wait before processing a batch,
                even if it hasn't reached the target batch size.
        """
        self.batch_handler_fn = batch_handler_fn
        self.batch_size = batch_size
        self.batch_timeout_s = batch_timeout_s
        self._request_queue = queue.Queue()
        self._batch_queue = queue.Queue()
        self._result_map = BatcherResults()
        self._batch_task = BatcherTask(self._request_queue, self._batch_queue,
                                       self.batch_size, self.batch_timeout_s)
        self._batch_processor_task = BatchProcessorTask(self._batch_queue,
                                                        self._result_map,
                                                        self.batch_handler_fn)
        self._batch_task.start()
        self._batch_processor_task.start()

    def put(self, input_data: Any) -> str:
        """
        Adds a new request to the batcher.

        Args:
            input_data (Any): The data for the request to be processed.

        Returns:
            str: A unique identifier for the request, which can be used to
                retrieve the result later.
        """
        request_id = str(uuid.uuid4())
        self._request_queue.put((input_data, request_id))
        return request_id

    def get(self, request_id: str) -> Any:
        """
        Retrieves the result for a given request ID, blocking until it is
        available.

        Args:
            request_id (str): The unique identifier for the request.

        Returns:
            Any: The result of the processing. Note that None is used as the
                "not ready" sentinel, so a handler that legitimately returns
                None will cause this call to block forever.
        """
        while True:
            result = self._result_map.get(request_id)
            if result is not None:
                return result
            time.sleep(0.1)  # To prevent busy-waiting

    def exit(self) -> None:
        """
        Cleans up and stops the threads associated with the batch processing.
        """
        print("Exiting")
        self._batch_task.stop()
        self._batch_processor_task.stop()
        self._batch_task.join()
        self._batch_processor_task.join()
        print("Done Exiting")