@cadedaniel
Created May 6, 2024 16:59
draft model on TP=1 worker
from typing import List, Optional, Set, Tuple
import logging

import torch

from vllm.sequence import SamplerOutput, ExecuteModelData
from vllm.model_executor.parallel_utils.parallel_state import (
    patch_tensor_parallel_group)
from vllm.config import CacheConfig, ParallelConfig
from vllm.worker.base_worker import BaseWorker, Speculator
from vllm.worker.spec_util import SpeculativeProposals

logger = logging.getLogger(__name__)


class SingleTpWorker(Speculator, BaseWorker):
    """Class which allows a speculative draft model to run with tensor
    parallel degree of 1, while the target model runs with a larger tensor
    parallel degree. This reduces the overhead of small draft models.

    This is implemented by changing vLLM's tensor parallel group to a group
    of size 1 during forward passes.
    """

    @classmethod
    def maybe_wrap_worker(cls, worker, draft_parallel_config: ParallelConfig,
                          target_parallel_config: ParallelConfig):
        """Wrap the worker in a SingleTpWorker if necessary.
        """
        draft_tp = draft_parallel_config.tensor_parallel_size
        if draft_tp == target_parallel_config.tensor_parallel_size:
            return worker

        if draft_tp != 1:
            raise ValueError(f"{cls} only supports tp=1, found "
                             f"{draft_tp=}")

        logger.info(f"Wrapping {type(worker)} in {cls}")
        return cls(worker)

    def __init__(
        self,
        worker: BaseWorker,
    ):
        self._worker = worker
        self._single_tp_group = None

    def init_model(self):
        """Initialize the model on all ranks.

        This also creates a single-rank process group containing only the
        self process.
        """
        world_rank = torch.distributed.get_rank()
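        # Each rank builds a process group containing only itself; TP
        # collectives patched onto this group reduce over a single rank.
        # Note that torch.distributed.new_group is a collective call, so
        # every rank in the job reaches this line during init.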
        self._single_tp_group = torch.distributed.new_group([world_rank])

        with patch_tensor_parallel_group(self._single_tp_group):
            self._worker.init_model(should_init_distributed_env=False)

    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        """Profile the model on all ranks.
        """
        with patch_tensor_parallel_group(self._single_tp_group):
            return self._worker.profile_num_available_blocks(
                block_size, gpu_memory_utilization, cpu_swap_space)

    def init_cache_engine(self, cache_config: CacheConfig):
        """Initialize the cache engine on all ranks.
        """
        with patch_tensor_parallel_group(self._single_tp_group):
            self._worker.init_cache_engine(cache_config)

    @property
    def model_config(self):
        return self._worker.model_config

    @property
    def parallel_config(self):
        return self._worker.parallel_config

    @property
    def model(self):
        return self._worker.model

    @property
    def rank(self):
        return self._worker.rank

    @property
    def max_model_len(self) -> int:
        return self._worker.max_model_len

    @property
    def vocab_size(self) -> int:
        return self._worker.vocab_size

    @property
    def device(self) -> torch.device:
        return self._worker.device

    def include_gpu_probs_tensor(self) -> None:
        """Include GPU probs tensor in sampler output.
        """
        self._worker.include_gpu_probs_tensor()

    def get_kv_size_bytes(self, block_size: int) -> int:
        """Get the size of the KV cache in bytes.
        """
        return self._worker.get_kv_size_bytes(block_size)

    def set_extra_seeds_to_generate(self, n: int) -> None:
        """Set the number of extra seeds the sampler generates for each
        sequence.
        """
        self._worker.set_extra_seeds_to_generate(n)

    def set_sampler_entropy(self, sampler_entropy: List[int]) -> None:
        """Set the sampler entropy.
        """
        self._worker.set_sampler_entropy(sampler_entropy)

    def get_metadata_cache_len(self) -> int:
        """Metadata cache is not currently supported; always returns 0.
        """
        return 0

    def get_runtime_context(self) -> Optional[dict]:
        return self._worker.get_runtime_context()

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_data: ExecuteModelData,
        *,
        return_python_output: bool = True) -> List[SamplerOutput]:
        """Execute the model separately on each rank.
        """
        with patch_tensor_parallel_group(self._single_tp_group):
            return self._worker.execute_model(
                execute_model_data, return_python_output=return_python_output)

    def get_spec_proposals(self, execute_model_data: ExecuteModelData, k: int,
                           max_model_len: int,
                           non_spec_seq_ids: Set[str]) -> SpeculativeProposals:
        """Get the valid and non-valid sequences for speculative decoding.

        For the valid sequences, get the proposal token ids and probs from
        the draft model.
        """
        with patch_tensor_parallel_group(self._single_tp_group):
            return self._worker.get_spec_proposals(execute_model_data, k,
                                                   max_model_len,
                                                   non_spec_seq_ids)
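

# A minimal usage sketch. This is illustrative only: in vLLM the wrapping is
# done by the speculative decoding setup code, and the worker/config objects
# below are assumed to exist at the call site (they are not defined in this
# file).
#
#   draft_worker = SingleTpWorker.maybe_wrap_worker(
#       draft_worker,             # BaseWorker running the draft model
#       draft_parallel_config,    # ParallelConfig with tensor_parallel_size=1
#       target_parallel_config)   # ParallelConfig of the larger target model
#   draft_worker.init_model()     # builds the size-1 TP group on each rank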