from typing import List, Tuple, Callable, Any, Dict, Optional
import threading
import queue
import time
import uuid
import event_logger
import logging
import ray


class BatcherTask(threading.Thread):
    def __init__(
        self,
        request_queue: queue.Queue,
        batch_queue: queue.Queue,
        target_batch_size: int,
        timeout_in_s: float,
        batching_type: Optional[str] = None):
        """
        Initializes the BatcherTask thread.

        Args:
            request_queue (queue.Queue): Queue for incoming requests.
            batch_queue (queue.Queue): Queue for outgoing batches.
            target_batch_size (int): Target number of requests per batch.
            timeout_in_s (float): Timeout in seconds for assembling a batch.
            batching_type (Optional[str]): Label used to tag logged events.
        """
        super().__init__()
        self.request_queue = request_queue
        self.batch_queue = batch_queue
        self.target_batch_size = target_batch_size
        self.timeout_in_s = timeout_in_s
        self.stop_event = threading.Event()
        self.batching_type = batching_type

    def run(self) -> None:
        """
        Run method for the thread. Continuously assembles batches from incoming requests.
        """
        batch_index = 0
        while not self.stop_event.is_set():
            start_time = time.time()
            batch = []
            id_batch = []
            event_logger.event_start(
                owner_name=f"Batcher_{self.batching_type}",
                event_category="build_batch",
                event_id=batch_index,
            )
            while len(batch) < self.target_batch_size and time.time() - start_time < self.timeout_in_s:
                try:
                    # Wait only for the time remaining in this batching window.
                    timeout = max(0.0, self.timeout_in_s - (time.time() - start_time))
                    request_data, request_id = self.request_queue.get(timeout=timeout)
                    batch.append(request_data)
                    id_batch.append(request_id)
                except queue.Empty:
                    break
            if batch:
                event_logger.event_stop(
                    owner_name=f"Batcher_{self.batching_type}",
                    event_category="build_batch",
                    event_id=batch_index,
                )
                self.batch_queue.put((batch, id_batch))
                event_logger.event_start(
                    owner_name=f"Batcher_{self.batching_type}",
                    event_category="batch_transfer",
                    event_id=batch_index,
                )
                batch_index += 1

    def stop(self) -> None:
        """
        Stops the thread.
        """
        self.stop_event.set()


class BatchProcessorTask(threading.Thread):
    def __init__(
        self,
        batch_queue: queue.Queue,
        result_map: 'BatcherResults',
        batch_processor: Callable[[List[Any]], List[Any]],
        batching_type: Optional[str] = None):
        """
        Initializes the BatchProcessorTask thread.

        Args:
            batch_queue (queue.Queue): Queue containing batches to process.
            result_map (BatcherResults): Shared object for storing results.
            batch_processor (Callable): Function to process each batch.
            batching_type (Optional[str]): Label used to tag logged events.
        """
        super().__init__()
        self.batch_queue = batch_queue
        self.result_map = result_map
        self.batch_processor = batch_processor
        self.stop_event = threading.Event()
        self.batching_type = batching_type

    def run(self) -> None:
        """
        Run method for the thread. Processes each batch and stores the results.
        """
        batch_index = 0
        while not self.stop_event.is_set():
            try:
                # Use a timeout so the loop can observe stop_event instead of
                # blocking forever on an empty queue.
                batch, id_batch = self.batch_queue.get(timeout=1.0)
                event_logger.event_stop(
                    owner_name=f"Batcher_{self.batching_type}",
                    event_category="batch_transfer",
                    event_id=batch_index,
                )
                event_logger.event_start(
                    owner_name=f"Batcher_{self.batching_type}",
                    event_category="batch_process",
                    event_id=batch_index,
                )
                results = self.batch_processor(batch)
                event_logger.event_stop(
                    owner_name=f"Batcher_{self.batching_type}",
                    event_category="batch_process",
                    event_id=batch_index,
                )
                event_logger.event_start(
                    owner_name=f"Batcher_{self.batching_type}",
                    event_category="batch_postprocess",
                    event_id=batch_index,
                )
                for result, request_id in zip(results, id_batch):
                    self.result_map.set(request_id, result)
                event_logger.event_stop(
                    owner_name=f"Batcher_{self.batching_type}",
                    event_category="batch_postprocess",
                    event_id=batch_index,
                )
                batch_index += 1
            except queue.Empty:
                continue

    def stop(self) -> None:
        """
        Stops the thread.
        """
        self.stop_event.set()


class BatcherResults:
    def __init__(self):
        """
        Initializes a thread-safe storage for batch results.
        """
        self._d = {}
        self._lock = threading.Lock()

    def set(self, k: Any, v: Any) -> None:
        """
        Sets a result for a specific request ID.

        Args:
            k (Any): Request ID.
            v (Any): Result to be stored.
        """
        with self._lock:
            self._d[k] = v

    def get(self, k: Any) -> Any:
        """
        Retrieves a result for a given request ID.

        Args:
            k (Any): Request ID.

        Returns:
            Any: Result associated with the request ID.
        """
        with self._lock:
            return self._d.get(k)

    def print(self) -> None:
        """
        Prints the current state of the results storage.
        """
        with self._lock:
            print(self._d)
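

# Wiring sketch (not part of the gist): a minimal, threads-only example of how
# BatcherTask, BatchProcessorTask, and BatcherResults fit together. The
# RequestBatcher actor below packages exactly this wiring; `square_batch` is a
# stand-in batch processor invented for illustration.
def _threads_only_demo() -> None:
    request_queue: queue.Queue = queue.Queue()
    batch_queue: queue.Queue = queue.Queue()
    results = BatcherResults()

    def square_batch(batch: List[Any]) -> List[Any]:
        # Process the whole batch at once; real handlers would run model inference here.
        return [x * x for x in batch]

    batcher = BatcherTask(request_queue, batch_queue, target_batch_size=4, timeout_in_s=0.2)
    processor = BatchProcessorTask(batch_queue, results, square_batch)
    batcher.start()
    processor.start()

    request_id = str(uuid.uuid4())
    request_queue.put((3, request_id))
    time.sleep(0.5)  # let the batch window close and the processor publish the result
    print(results.get(request_id))  # expected: 9

    batcher.stop()
    processor.stop()
    batcher.join()
    processor.join()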


@ray.remote
class RequestBatcher:
    def __init__(
        self,
        batch_handler_fn: Callable[[List[Any]], List[Any]],
        batch_size: int,
        batch_timeout_s: float,
        batching_type: Optional[str] = None,
    ):
        """
        Initializes the RequestBatcher, which is responsible for managing the lifecycle
        of batch processing, from receiving individual requests to returning processed results.

        Args:
            batch_handler_fn (Callable): The function that will process each batch of requests.
            batch_size (int): The target number of requests to include in each batch.
            batch_timeout_s (float): The maximum time to wait before processing a batch,
                even if it hasn't reached the target batch size.
            batching_type (Optional[str]): Label used to tag logged events.
        """
        self._batching_type = batching_type
        self.batch_handler_fn = batch_handler_fn
        self.batch_size = batch_size
        self.batch_timeout_s = batch_timeout_s
        self._request_queue = queue.Queue()
        self._batch_queue = queue.Queue()
        self._result_map = BatcherResults()
        self._batch_task = BatcherTask(
            request_queue=self._request_queue,
            batch_queue=self._batch_queue,
            target_batch_size=self.batch_size,
            timeout_in_s=self.batch_timeout_s,
            batching_type=batching_type)
        self._batch_processor_task = BatchProcessorTask(
            batch_queue=self._batch_queue,
            result_map=self._result_map,
            batch_processor=self.batch_handler_fn,
            batching_type=batching_type)
        self._batch_task.start()
        self._batch_processor_task.start()

    def put(self, input_data: Any) -> str:
        """
        Adds a new request to the batcher.

        Args:
            input_data (Any): The data for the request to be processed.

        Returns:
            str: A unique identifier for the request, which can be used to retrieve the result later.
        """
        request_id = str(uuid.uuid4())
        self._request_queue.put((input_data, request_id))
        return request_id

    def get(self, request_id: str) -> Any:
        """
        Retrieves the result for a given request ID, blocking until it is available.

        Args:
            request_id (str): The unique identifier for the request.

        Returns:
            Any: The result of the processing.
        """
        while True:
            result = self._result_map.get(request_id)
            if result is not None:
                return result
            # Avoid a hot spin while waiting for the batch processor to publish the result.
            time.sleep(0.001)

    def exit(self) -> None:
        """
        Cleans up and stops the threads associated with the batch processing.
        """
        print("Exiting")
        self._batch_task.stop()
        self._batch_processor_task.stop()
        self._batch_task.join()
        self._batch_processor_task.join()
        print("Done Exiting")
from typing import List, Any, Iterable
import time
import numpy as np
import jax
import jax.experimental.multihost_utils  # needed for process_allgather used below
import jax.numpy as jnp
from jax.sharding import Mesh, PositionalSharding
import logging
import socket
import ray
import networks
import config
import checkpoint_utils
import array_encode_decode
import event_logger
from batching import RequestBatcher


@ray.remote(resources={"inference_cpu_handler": 1})
class RayInferenceActor:
    def __init__(
        self,
        ckpt_dir: str,
        batch_size: int = 256,
        model: str = "repr",
        batch_timeout_s: float = 1.0,
        tpu_id: str = "inference",
        weight_update_interval: int = 100,
    ):
        self.batch_size = batch_size
        self.batch_timeout_s = batch_timeout_s
        self.ckpt_dir = ckpt_dir
        self.tpu_id = tpu_id
        self.step_counter = 0
        self.weight_update_interval = weight_update_interval
        self._model = model
        if model == "repr":
            self.actor_def = RayReprInferenceShard
        elif model == "dyna":
            self.actor_def = RayDynaInferenceShard
        else:
            # Keep the attribute name consistent so initialize() can detect the
            # invalid model instead of raising AttributeError.
            self.actor_def = None
            print("Invalid model provided...")

    def initialize(self):
        if not self.actor_def:
            raise ValueError("Actor def is not defined. Check the provided model.")
        if "v4_16" in self.tpu_id:
            num_hosts = 2
        else:
            num_hosts = 1
        logging.info("Number of hosts: %d", num_hosts)
        print("Number of hosts: ", num_hosts)
        self._shards = [
            self.actor_def.options(resources={"TPU": 4}).remote(ckpt_dir=self.ckpt_dir)
            for _ in range(num_hosts)
        ]
        init_handles = [shard.initialize.remote() for shard in self._shards]
        self._batcher = RequestBatcher.remote(
            batch_handler_fn=self.process_batch,
            batch_size=self.batch_size,
            batch_timeout_s=self.batch_timeout_s,
            batching_type=self._model)
        ray.get(init_handles)

    def process_batch(self, inputs: Iterable[Any]):
        event_logger.event_start(
            owner_name=f"{self._model}-inference",
            event_category="process_batch",
        )
        try:
            if (
                self.step_counter > 0
                and self.step_counter % self.weight_update_interval == 0
            ):
                print("updating weights")
                start_time = time.time()
                ray.get([shard.update_weights.remote() for shard in self._shards])
                print(f"update weight time {time.time() - start_time}s")
            results = ray.get(
                [shard.handle_batch.remote(inputs) for shard in self._shards]
            )
            final_result = results[0]
            self.step_counter += 1
            # Results will be a list of lists (one per shard); this will need to be
            # flattened when running more than one host. For now, return the first
            # entry, which covers the single-host case.
            # TODO: revisit for multihost.
            event_logger.event_stop(
                owner_name=f"{self._model}-inference",
                event_category="process_batch",
            )
            return final_result
        except Exception as e:
            print("process_batch failed due to: ", e)
            raise e

    def put(self, input: Any) -> str:
        return ray.get(self._batcher.put.remote(input))

    def get(self, request_id: str) -> Any:
        event_logger.event_start(
            owner_name=f"{self._model}-inference",
            event_category="get_request",
        )
        result = ray.get(self._batcher.get.remote(request_id))
        event_logger.event_stop(
            owner_name=f"{self._model}-inference",
            event_category="get_request",
        )
        return result


class RayInferenceShardBase:
    """Base class for Ray inference shards."""

    def __init__(self, ckpt_dir: str, skip_checkpoint: bool = False):
        self.ckpt_dir = ckpt_dir
        self._skip_checkpoint = skip_checkpoint

    def update_weights(self) -> bool:
        print("beginning weight update")
        all_steps = self._ckpt_manager.all_steps(read=True)
        latest_step = max(all_steps) if all_steps else None
        logging.info(f"actor latest_ckpt_step={latest_step}")
        print(f"actor latest_ckpt_step={latest_step}")
        self.step = latest_step
        if latest_step:
            count_try = 0
            while True:
                if count_try > 3:
                    return False
                try:
                    print("Trying to restore checkpoint")
                    restored = self._ckpt_manager.restore(latest_step)
                    print("Initial restore complete")
                    restored_params = restored["save_state"]["state"]
                    restored_step = restored_params["step"]
                    print("Restored step ", restored_step)
                    logging.info(f"actor restored_ckpt_step={restored_step}")
                    self._latest_step = restored_step
                    self._model_params_states = restored_params
                    if self._latest_step - 1 >= self._total_training_steps:
                        self._finished = True
                    return True
                except Exception:
                    count_try += 1
                    logging.info("waiting for 30s and retrying actor update.")
                    time.sleep(30)
        else:
            return False


@ray.remote
class RayDynaInferenceShard(RayInferenceShardBase):
    def __repr__(self):
        return f"[RayDynaInferenceActorShard:{socket.gethostname()}]"

    def initialize(self):
        model_config = config.ModelConfig()
        network = networks.get_model(model_config)

        def dyna_and_pred(params, embedding, action):
            return network.apply(
                params, embedding, action, method=network.dyna_and_pred
            )

        self._jitted_dyna_and_pred = jax.jit(dyna_and_pred)
        self._num_devices = jax.device_count()
        self._emb_sharding = PositionalSharding(jax.devices()).reshape(
            self._num_devices, 1, 1, 1
        )
        self._action_sharding = PositionalSharding(jax.devices()).reshape(
            self._num_devices
        )
        self._mesh = Mesh(np.asarray(jax.devices(), dtype=object), ["data"])
        if self._skip_checkpoint:
            print("Skipping checkpoint loading...")
            _, key2 = jax.random.split(jax.random.PRNGKey(42))
            dummy_obs = jnp.zeros((1, 96, 96, 3, 4), dtype=jnp.float32)
            dummy_action = jnp.zeros((1, 1), dtype=jnp.int32)
            params = network.init(key2, dummy_obs, dummy_action)
            self._model_params_states = params
        else:
            dummy_ckpt_save_interval_steps = 10
            while True:
                try:
                    self._ckpt_manager = checkpoint_utils.get_ckpt_manager(
                        self.ckpt_dir,
                        dummy_ckpt_save_interval_steps,
                        create=False,
                        use_async=False,
                    )
                    print("got ckpt manager")
                    break
                except Exception:
                    print("waiting for 30s and retry.")
                    time.sleep(30)
            all_steps = self._ckpt_manager.all_steps(read=True)
            latest_step = max(all_steps) if all_steps else None
            if latest_step is None:
                latest_step = 0
            print(f"need to load actor latest_ckpt_step={latest_step}")
            while True:
                try:
                    restored = self._ckpt_manager.restore(latest_step)
                    restored_params = restored["save_state"]["state"]
                    self._model_params_states = restored_params
                    print("done restoring")
                    break
                except Exception:
                    print(f"trying to load {latest_step} again")
                    time.sleep(30)

    def batch_and_shard(self, inputs: Iterable[Any]):
        embeddings = []
        actions = []
        for embedding, action in inputs:
            embedding = array_encode_decode.ndarray_from_bytes(embedding)
            action = array_encode_decode.ndarray_from_bytes(action)
            embeddings.append(embedding)
            actions.append(action)
        # Pad the batch so it divides evenly across devices, reusing the first request.
        num_to_pad = (
            self._num_devices - (len(inputs) % self._num_devices)
        ) % self._num_devices
        for _ in range(num_to_pad):
            embeddings.append(array_encode_decode.ndarray_from_bytes(inputs[0][0]))
            actions.append(array_encode_decode.ndarray_from_bytes(inputs[0][1]))
        global_embedding = np.concatenate(embeddings, axis=0)
        global_embedding_shape = global_embedding.shape
        embedding_arrays = [
            jax.device_put(global_embedding[index], d)
            for d, index in self._emb_sharding.addressable_devices_indices_map(
                global_embedding_shape
            ).items()
        ]
        sharded_embedding = jax.make_array_from_single_device_arrays(
            global_embedding_shape, self._emb_sharding, embedding_arrays
        )
        global_actions = np.concatenate(actions, axis=0)
        global_actions_shape = global_actions.shape
        action_arrays = [
            jax.device_put(global_actions[index], d)
            for d, index in self._action_sharding.addressable_devices_indices_map(
                global_actions_shape
            ).items()
        ]
        sharded_actions = jax.make_array_from_single_device_arrays(
            global_actions_shape, self._action_sharding, action_arrays
        )
        return (sharded_embedding, sharded_actions)

    def handle_batch(self, inputs: Iterable[Any]) -> List[Any]:
        def print_shape_and_type(x):
            return x.shape, x.dtype

        embedding, action = self.batch_and_shard(inputs)
        dp_net_out = self._jitted_dyna_and_pred(
            {"params": self._model_params_states["params"]}, embedding, action
        )
        jax.block_until_ready(dp_net_out)
        dp_net_out = jax.experimental.multihost_utils.process_allgather(
            dp_net_out, tiled=True
        )
        result = dp_net_out[1]
        result["embedding"] = dp_net_out[0]
        result_list = jax.tree_util.tree_map(list, result)
        final_result = []
        for i in range(len(inputs)):
            result = {
                key: array_encode_decode.ndarray_to_bytes(
                    np.asarray(result_list[key][i])
                )
                for key in result_list
            }
            final_result.append(result)
        return final_result


@ray.remote
class RayReprInferenceShard(RayInferenceShardBase):
    def __repr__(self):
        return f"[RayReprInferenceActorShard:{socket.gethostname()}]"

    def initialize(self):
        model_config = config.ModelConfig()
        network = networks.get_model(model_config)

        def repr_and_pred(params, obs, dtype=jnp.float32):
            return network.apply(params, obs, dtype, method=network.repr_and_pred)

        self._jitted_repr_and_pred = jax.jit(repr_and_pred)
        self._num_devices = jax.device_count()
        self._obs_sharding = PositionalSharding(jax.devices()).reshape(
            self._num_devices, 1, 1, 1, 1
        )
        self._mesh = Mesh(np.asarray(jax.devices(), dtype=object), ["data"])
        dummy_ckpt_save_interval_steps = 10
        while True:
            try:
                print("try to get ckpt manager")
                self._ckpt_manager = checkpoint_utils.get_ckpt_manager(
                    self.ckpt_dir,
                    dummy_ckpt_save_interval_steps,
                    create=False,
                    use_async=False,
                )
                print("got ckpt manager")
                break
            except Exception:
                print("waiting for 30s and retry.")
                time.sleep(30)
        all_steps = self._ckpt_manager.all_steps(read=True)
        latest_step = max(all_steps) if all_steps else None
        self.step = latest_step
        if latest_step is None:
            latest_step = 0
        print(f"need to load actor latest_ckpt_step={latest_step}")
        while True:
            try:
                restored = self._ckpt_manager.restore(latest_step)
                restored_params = restored["save_state"]["state"]
                self._model_params_states = restored_params
                print("done restoring")
                break
            except Exception:
                print(f"trying to load {latest_step} again")
                time.sleep(30)

    def batch_and_shard(self, inputs: Iterable[Any]):
        observations = []
        for observation in inputs:
            observation = array_encode_decode.ndarray_from_bytes(observation)
            observations.append(observation)
        # Pad the batch so it divides evenly across devices, reusing the first request.
        num_to_pad = (
            self._num_devices - (len(inputs) % self._num_devices)
        ) % self._num_devices
        for _ in range(num_to_pad):
            observations.append(array_encode_decode.ndarray_from_bytes(inputs[0]))
        global_observation = np.concatenate(observations, axis=0)
        global_observation_shape = global_observation.shape
        observation_arrays = [
            jax.device_put(global_observation[index], d)
            for d, index in self._obs_sharding.addressable_devices_indices_map(
                global_observation_shape
            ).items()
        ]
        sharded_observation = jax.make_array_from_single_device_arrays(
            global_observation_shape, self._obs_sharding, observation_arrays
        )
        return sharded_observation

    def handle_batch(self, inputs: Any) -> List[Any]:
        def print_shape_and_type(x):
            return x.shape, x.dtype

        observation = self.batch_and_shard(inputs)
        repr_net_out = self._jitted_repr_and_pred(
            {"params": self._model_params_states["params"]}, observation
        )
        jax.block_until_ready(repr_net_out)
        repr_net_out = jax.experimental.multihost_utils.process_allgather(
            repr_net_out, tiled=True
        )
        result = repr_net_out[1]
        result["embedding"] = repr_net_out[0]
        result_list = jax.tree_util.tree_map(list, result)
        final_result = []
        for i in range(len(inputs)):
            result = {
                key: array_encode_decode.ndarray_to_bytes(result_list[key][i])
                for key in result_list
            }
            result["step"] = self.step
            final_result.append(result)
        return final_result
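

# Usage sketch (not part of the gist): a hedged example of how a driver might wire
# up the repr inference path. The checkpoint directory, observation shape, and the
# "inference_cpu_handler"/"TPU" custom resources are assumptions about the cluster
# setup; adjust them to your environment.
if __name__ == "__main__":
    ray.init()
    inference_actor = RayInferenceActor.remote(
        ckpt_dir="/tmp/ckpts",          # hypothetical checkpoint directory
        batch_size=32,
        model="repr",
        batch_timeout_s=0.5,
        tpu_id="inference_v4_8",        # any id without "v4_16" -> single host
    )
    ray.get(inference_actor.initialize.remote())

    # Requests are serialized ndarrays; the shape here mirrors the dummy_obs used above.
    observation = np.zeros((1, 96, 96, 3, 4), dtype=np.float32)
    request_id = ray.get(inference_actor.put.remote(
        array_encode_decode.ndarray_to_bytes(observation)
    ))
    result = ray.get(inference_actor.get.remote(request_id))
    print(result.keys())  # per-request dict of encoded outputs plus "step"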