@zhouyuan
Created April 9, 2024 02:34
diff --git a/Dockerfile.cpu b/Dockerfile.cpu
index 4251fdd..d9344a0 100644
--- a/Dockerfile.cpu
+++ b/Dockerfile.cpu
@@ -6,6 +6,8 @@ RUN apt-get update -y \
    && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
    && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
+ENV https_proxy="http://proxy-shz.intel.com:911"
+ENV http_proxy="http://proxy-shz.intel.com:911"
RUN pip install --upgrade pip \
    && pip install wheel packaging ninja setuptools>=49.4.0 numpy
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index d6bf18c..9e5f1b3 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -133,7 +133,8 @@ def run_hf(
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
-    llm = llm.cuda()
+    llm = llm.cpu()
+    #llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
@@ -158,7 +159,7 @@ def run_hf(
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
-            input_ids=input_ids.cuda(),
+            input_ids=input_ids.cpu(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
@@ -329,7 +330,7 @@ if __name__ == "__main__":
"--device",
type=str,
default="cuda",
- choices=["cuda"],
+ choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument(
"--enable-prefix-caching",
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index f610495..6b88c49 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -333,6 +333,13 @@ class AsyncLLMEngine:
        if engine_config.device_config.device_type == "neuron":
            raise NotImplementedError("Neuron is not supported for "
                                      "async engine yet.")
+        if engine_config.device_config.device_type == "cpu":
+            if (engine_config.parallel_config.worker_use_ray
+                    or engine_args.engine_use_ray):
+                logger.warning("Ray is not supported with the CPU backend yet.")
+            else:
+                from vllm.executor.cpu_executor import CPUExecutorAsync
+                executor_class = CPUExecutorAsync
        elif (engine_config.parallel_config.worker_use_ray
              or engine_args.engine_use_ray):
            initialize_ray_cluster(engine_config.parallel_config)
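
With this branch in place, an async engine can be built on CPU without Ray. A hedged usage sketch (argument names follow vLLM 0.4.x EngineArgs; the model is a placeholder, and float16 is cast to bfloat16 on CPU as shown in the cpu_executor.py hunk below):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",   # placeholder model
    device="cpu",                # routes to CPUExecutorAsync in the patched branch
    dtype="bfloat16",            # float16 is not supported on CPU
    worker_use_ray=False,        # Ray is not wired up for the CPU backend yet
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
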
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 7b3cc78..9a5cc44 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -5,7 +5,7 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
-from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.executor_base import (ExecutorBase, ExecutorAsyncBase)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -118,6 +118,27 @@ class CPUExecutor(ExecutorBase):
        return


+class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        output = await make_async(self.driver_worker.execute_model)(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy)
+        return output
+
+    async def check_health_async(self) -> None:
+        # CPUExecutorAsync will always be healthy as long as
+        # it's running.
+        return
+
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    if config.dtype == torch.float16:
        logger.warning("float16 is not supported on CPU, casting to bfloat16.")
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 48facb5..11d9069 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -93,6 +93,9 @@ class Worker:
            _check_if_gpu_supports_dtype(self.model_config.dtype)
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+        elif self.device_config.device == "cpu":
+            self.rank = 0
+            self.device = torch.device("cpu")
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")