diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index f925a6fc93dcd..35249cd7302cb 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -45,6 +45,7 @@ def _init_worker(self): rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, + vision_language_config=self.vision_language_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7377c8931cefa..a82373d3d1626 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -5,7 +5,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -29,6 +29,7 @@ def __init__( device_config: DeviceConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, *args, @@ -38,6 +39,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.vision_language_config = vision_language_config self.load_config = load_config self.is_driver_worker = is_driver_worker @@ -59,13 +61,14 @@ def __init__( self.block_size: int # Set after initial profiling. def load_model(self) -> None: - self.model = get_model(model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - vision_language_config=None, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + vision_language_config=self.vision_language_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) def _prepare_prompt( self, @@ -76,6 +79,7 @@ def _prepare_prompt( input_positions: List[int] = [] slot_mapping: List[int] = [] prompt_lens: List[int] = [] + multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -96,6 +100,10 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.extend(list(range(computed_len, prompt_len))) + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, @@ -118,6 +126,15 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, @@ -144,12 +161,8 @@ def _prepare_prompt( slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, ) - return ( - input_tokens, - input_positions, - attn_metadata, - prompt_lens, - ) + return (input_tokens, input_positions, attn_metadata, prompt_lens, + multi_modal_input) def _prepare_decode( self, @@ -336,14 +349,16 @@ def prepare_input_tensors( seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata]: + multi_modal_input = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, attn_metadata, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, prompt_lens, + multi_modal_input + ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) @@ -376,12 +391,8 @@ def prepare_input_tensors( perform_sampling=False, ) - return ( - input_tokens, - input_positions, - attn_metadata, - sampling_metadata, - ) + return (input_tokens, input_positions, attn_metadata, + sampling_metadata, multi_modal_input) @torch.inference_mode() def execute_model( @@ -389,7 +400,8 @@ def execute_model( seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata + (input_tokens, input_positions, attn_metadata, sampling_metadata, + multi_modal_input ) = self.prepare_input_tensors(seq_group_metadata_list) model_executable = self.model @@ -399,6 +411,8 @@ def execute_model( "kv_caches": kv_caches, "attn_metadata": attn_metadata, } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3652830b7d519..83ededd742533 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,7 +6,8 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig) + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -122,6 +123,7 @@ def __init__( rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, + vision_language_config: Optional[VisionLanguageConfig] = None, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: @@ -135,21 +137,25 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.vision_language_config = vision_language_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." + if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = CPUModelRunner(model_config, - parallel_config, - scheduler_config, - device_config, - load_config=self.load_config, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) + self.model_runner = CPUModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: CPUCacheEngine