Last active
January 24, 2024 21:10
-
-
Save allenwang28/16506e88a007bedc156f8e24383fd042 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Tuple, Callable, Any, Dict | |
import threading | |
import queue | |
import time | |
import uuid | |
import logging | |
class BatcherTask(threading.Thread):
    """Background thread that groups incoming requests into batches.

    Continuously pulls ``(request_data, request_id)`` tuples from
    ``request_queue`` and emits ``(batch, id_batch)`` tuples onto
    ``batch_queue`` whenever the target batch size is reached or the
    timeout window elapses (whichever comes first).
    """

    def __init__(self, request_queue: queue.Queue, batch_queue: queue.Queue, target_batch_size: int, timeout_in_s: int):
        """
        Initializes the BatcherTask thread.
        Args:
            request_queue (queue.Queue): Queue for incoming requests.
            batch_queue (queue.Queue): Queue for outgoing batches.
            target_batch_size (int): Target number of requests per batch.
            timeout_in_s (int): Timeout in seconds for batch compilation.
        """
        super().__init__()
        self.request_queue = request_queue
        self.batch_queue = batch_queue
        self.target_batch_size = target_batch_size
        self.timeout_in_s = timeout_in_s
        # Cooperative shutdown flag checked once per batch window.
        self.stop_event = threading.Event()

    def run(self) -> None:
        """
        Run method for the thread. Continuously compiles batches from incoming requests.
        """
        while not self.stop_event.is_set():
            deadline = time.time() + self.timeout_in_s
            batch = []
            id_batch = []
            while len(batch) < self.target_batch_size:
                # BUGFIX: the original passed `timeout_in_s - elapsed` straight to
                # Queue.get; between the loop-condition check and the get() call the
                # remaining time could go negative, and Queue.get raises ValueError
                # on a negative timeout, killing this thread. Clamp via a deadline.
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                try:
                    request_data, request_id = self.request_queue.get(timeout=remaining)
                except queue.Empty:
                    # Timeout window expired without reaching target size.
                    break
                batch.append(request_data)
                id_batch.append(request_id)
            # Only emit non-empty batches; an idle window produces nothing.
            if batch:
                self.batch_queue.put((batch, id_batch))

    def stop(self) -> None:
        """
        Stops the thread.
        """
        self.stop_event.set()
class BatchProcessorTask(threading.Thread):
    """Background thread that runs the user-supplied processor on each batch.

    Pulls ``(batch, id_batch)`` tuples off ``batch_queue``, applies
    ``batch_processor`` to the batch, and stores each result in
    ``result_map`` keyed by its request ID.
    """

    def __init__(self, batch_queue: queue.Queue, result_map: 'BatcherResults', batch_processor: Callable[[List[Any]], List[Any]]):
        """
        Initializes the BatchProcessorTask thread.
        Args:
            batch_queue (queue.Queue): Queue containing batches to process.
            result_map (BatcherResults): Shared object for storing results.
            batch_processor (Callable): Function to process each batch.
        """
        super().__init__()
        self.batch_queue = batch_queue
        self.result_map = result_map
        self.batch_processor = batch_processor
        # Cooperative shutdown flag checked between queue polls.
        self.stop_event = threading.Event()

    def run(self) -> None:
        """
        Run method for the thread. Processes each batch and stores the results.
        """
        while not self.stop_event.is_set():
            try:
                # BUGFIX: the original used a blocking get() with no timeout, so
                # stop() had no effect (and RequestBatcher.exit()'s join() hung)
                # until another batch happened to arrive; its `except queue.Empty`
                # was dead code. A bounded wait keeps shutdown responsive.
                batch, id_batch = self.batch_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            results = self.batch_processor(batch)
            # Processor output is assumed to be positionally aligned with the
            # input batch; zip pairs each result with its request ID.
            for result, request_id in zip(results, id_batch):
                self.result_map.set(request_id, result)

    def stop(self) -> None:
        """
        Stops the thread.
        """
        self.stop_event.set()
class BatcherResults:
    """A minimal thread-safe dictionary keyed by request ID."""

    def __init__(self):
        """
        Initializes a thread-safe storage for batch results.
        """
        self._d: Dict[Any, Any] = {}
        self._lock = threading.Lock()

    def set(self, k: Any, v: Any) -> None:
        """Store result *v* under request ID *k*.

        Args:
            k (Any): Request ID.
            v (Any): Result to be stored.
        """
        with self._lock:
            self._d[k] = v

    def get(self, k: Any) -> Any:
        """Return the result stored under *k*, or None when absent.

        Args:
            k (Any): Request ID.
        Returns:
            Any: Result associated with the request ID.
        """
        with self._lock:
            return self._d.get(k)

    def print(self) -> None:
        """Dump the current ID -> result mapping to stdout."""
        with self._lock:
            print(self._d)
class RequestBatcher:
    """Facade that wires the batching and processing threads together.

    Callers submit work via put(), receive an opaque request ID, and later
    block on get() until the batch handler has produced a result for it.
    """

    def __init__(self, batch_handler_fn: Callable[[List[Any]], List[Any]], batch_size: int, batch_timeout_s: int):
        """
        Initializes the RequestBatcher which is responsible for managing the lifecycle
        of batch processing, from receiving individual requests to returning processed results.
        Args:
            batch_handler_fn (Callable): The function that will process each batch of requests.
            batch_size (int): The target number of requests to include in each batch.
            batch_timeout_s (int): The maximum time to wait before processing a batch,
                                   even if it hasn't reached the target batch size.
        """
        self.batch_handler_fn = batch_handler_fn
        self.batch_size = batch_size
        self.batch_timeout_s = batch_timeout_s

        # Plumbing shared between the two worker threads.
        self._request_queue = queue.Queue()
        self._batch_queue = queue.Queue()
        self._result_map = BatcherResults()

        # Pipeline: the batcher groups requests, the processor consumes batches.
        self._batch_task = BatcherTask(
            self._request_queue, self._batch_queue, self.batch_size, self.batch_timeout_s
        )
        self._batch_processor_task = BatchProcessorTask(
            self._batch_queue, self._result_map, self.batch_handler_fn
        )
        for worker in (self._batch_task, self._batch_processor_task):
            worker.start()

    def put(self, input_data: Any) -> str:
        """Enqueue *input_data* for batched processing.

        Args:
            input_data (Any): The data for the request to be processed.
        Returns:
            str: A unique identifier usable with get() to fetch the result.
        """
        request_id = str(uuid.uuid4())
        self._request_queue.put((input_data, request_id))
        return request_id

    def get(self, request_id: str) -> Any:
        """Block until the result for *request_id* is available, then return it.

        NOTE(review): a handler that legitimately returns None for an item
        would make this loop forever; same limitation as the original.

        Args:
            request_id (str): The unique identifier for the request.
        Returns:
            Any: The processed result.
        """
        result = self._result_map.get(request_id)
        while result is None:
            time.sleep(0.1)  # To prevent busy-waiting
            result = self._result_map.get(request_id)
        return result

    def exit(self) -> None:
        """Signal both worker threads to stop and wait for them to finish."""
        print("Exiting")
        for worker in (self._batch_task, self._batch_processor_task):
            worker.stop()
        for worker in (self._batch_task, self._batch_processor_task):
            worker.join()
        print("Done Exiting")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment