Skip to content

Commit

Permalink
[Misc] Add vision language model support to CPU backend (vllm-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored Apr 22, 2024
1 parent 747b1a7 commit 296cdf8
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 32 deletions.
1 change: 1 addition & 0 deletions vllm/executor/cpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _init_worker(self):
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
Expand Down
60 changes: 37 additions & 23 deletions vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
Expand All @@ -29,6 +29,7 @@ def __init__(
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
Expand All @@ -38,6 +39,7 @@ def __init__(
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker

Expand All @@ -59,13 +61,14 @@ def __init__(
self.block_size: int # Set after initial profiling.

def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=None,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)

def _prepare_prompt(
self,
Expand All @@ -76,6 +79,7 @@ def _prepare_prompt(
input_positions: List[int] = []
slot_mapping: List[int] = []
prompt_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = []

for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
Expand All @@ -96,6 +100,10 @@ def _prepare_prompt(
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len)))

if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)

# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
Expand All @@ -118,6 +126,15 @@ def _prepare_prompt(
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None

num_prompt_tokens = len(input_tokens)

input_tokens = torch.tensor(input_tokens,
Expand All @@ -144,12 +161,8 @@ def _prepare_prompt(
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
return (
input_tokens,
input_positions,
attn_metadata,
prompt_lens,
)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
multi_modal_input)

def _prepare_decode(
self,
Expand Down Expand Up @@ -336,14 +349,16 @@ def prepare_input_tensors(
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
SamplingMetadata]:
multi_modal_input = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, prompt_lens,
multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
Expand Down Expand Up @@ -376,20 +391,17 @@ def prepare_input_tensors(
perform_sampling=False,
)

return (
input_tokens,
input_positions,
attn_metadata,
sampling_metadata,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, multi_modal_input)

@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata
(input_tokens, input_positions, attn_metadata, sampling_metadata,
multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)

model_executable = self.model
Expand All @@ -399,6 +411,8 @@ def execute_model(
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})

hidden_states = model_executable(**execute_model_kwargs)

Expand Down
24 changes: 15 additions & 9 deletions vllm/worker/cpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
Expand All @@ -135,21 +137,25 @@ def __init__(
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = CPUModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
load_config=self.load_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
self.model_runner = CPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: CPUCacheEngine
Expand Down

0 comments on commit 296cdf8

Please sign in to comment.