diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py
index 8cd938fc85fb2..bce91f1d7fd5e 100644
--- a/vllm/model_executor/layers/resampler.py
+++ b/vllm/model_executor/layers/resampler.py
@@ -41,6 +41,7 @@
 from torch.nn.init import trunc_normal_
 
 from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
@@ -154,15 +155,15 @@ class BaseResampler(nn.Module):
         A tensor with the shape of (grid_size**2, embed_dim)
     """
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
 
         self.num_queries = num_queries
@@ -172,7 +173,11 @@ def __init__(
         self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
         trunc_normal_(self.query, std=0.02)
         if kv_dim is not None and kv_dim != embed_dim:
-            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
+            self.kv_proj = ReplicatedLinear(kv_dim,
+                                            embed_dim,
+                                            bias=False,
+                                            quant_config=quant_config,
+                                            prefix=prefix)
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
@@ -209,22 +214,24 @@ class Resampler2(BaseResampler):
     present in minicpmv2.0, but not qwen-vl.
     """
 
-    def __init__(
-        self,
-        grid_size: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        adaptive: bool = False,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 grid_size: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 adaptive: bool = False,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__(grid_size**2,
                          embed_dim,
                          num_heads,
                          kv_dim,
                          norm_layer,
-                         do_post_projection=do_post_projection)
+                         do_post_projection=do_post_projection,
+                         quant_config=quant_config,
+                         prefix=prefix)
         self.adaptive = adaptive
 
         pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 2cf4e92908353..07adf7c01eaaf 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -28,6 +28,7 @@
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -771,6 +772,8 @@ def __init__(self, load_config: LoadConfig):
         with open(config_file_path, "r") as f:
             config = json.load(f)
             self.target_modules = config["target_modules"]
+        # Save the module names without sharding.
+        self.unsharded_weights_modules: List[str] = []
 
     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
@@ -990,16 +993,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
             if any(target_module in weight_name
                    for target_module in self.target_modules
                    ) and weight_name.endswith(".weight"):
                 weight_name = weight_name.replace(".weight", ".qweight")
-
-                if any(module in weight_name
-                       for module in self.column_parallel_weights_modules):
+                # Without sharding
+                if any(
+                        weight_name.startswith(module)
+                        for module in self.unsharded_weights_modules):
+                    weight_sub_tensor = weight_tensor
+                # Shard by column
+                elif any(module in weight_name
+                         for module in self.column_parallel_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
                     end_index = total_size // tp_size * (tp_rank + 1)
                     weight_sub_tensor = weight_tensor[...,
                                                       start_index:end_index]
-
+                # Shard by row
                 else:
                     total_size = weight_tensor.size(0)
                     start_index = total_size // tp_size * tp_rank
@@ -1053,7 +1061,15 @@ def _load_weights(self, model_config: ModelConfig,
                 model.column_parallel_weights_modules
         else:
             self.column_parallel_weights_modules = []
-
+        # Some modules like `ReplicatedLinear` should not have their weights
+        # sharded. The reason for implementing it this way is to avoid a new
+        # static variable in the model implementation.
+        # TODO: Can we reduce the static variables needed for BNB based on
+        # model information?
+        self.unsharded_weights_modules = [
+            name for name, module in model.named_modules()
+            if isinstance(module, (ReplicatedLinear, ))
+        ]
         self.model_type = type(model).__name__
 
         logger.info("Loading weights with BitsAndBytes quantization. "
@@ -1100,7 +1116,13 @@ def _load_weights(self, model_config: ModelConfig,
             for shard_name, (
                     weight_name, index
             ) in model.bitsandbytes_stacked_params_mapping.items():
-                if shard_name in quant_param_name:
+
+                shard_pos = quant_param_name.find(shard_name)
+                # Some models, such as MiniCPM V2.5/2.6, contain both the
+                # module names 'kv_proj' and 'qkv_proj', so require the match
+                # to be preceded by a '.' to prevent 'kv_proj' from matching
+                # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight'.
+                if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                     shard_index = index
                     quant_param_name = quant_param_name.replace(
                         shard_name, weight_name)
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index e7088edb97b2b..c1f714bb25680 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -131,16 +131,22 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
 
 class Resampler2_5(BaseResampler):
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        max_size: Tuple[int, int] = (70, 70),
-    ) -> None:
-        super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 max_size: Tuple[int, int] = (70, 70),
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        super().__init__(num_queries,
+                         embed_dim,
+                         num_heads,
+                         kv_dim,
+                         norm_layer,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.max_size = max_size
         self._set_2d_pos_cache(self.max_size)
@@ -404,7 +410,10 @@ def __init__(
         self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                            self.vpm.embeddings.embed_dim)
         self.embed_dim = self.config.hidden_size
-        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+        self.resampler = self.init_resampler(self.embed_dim,
+                                             self.vision_dim,
+                                             quant_config=quant_config,
+                                             prefix="resampler")
         self.resampler.to(device="cuda", dtype=param_dtype)
         # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -666,7 +675,11 @@ def init_vision_module(
     ) -> nn.Module:
         raise NotImplementedError
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         raise NotImplementedError
 
     def get_vision_embedding(
@@ -743,16 +756,21 @@ def init_vision_module(
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_tokens(input_ids)
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2(
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                grid_size=int(math.sqrt(self.config.query_num)),
-                kv_dim=vision_dim,
-                adaptive=False,
-                do_post_projection=True,
-            )
+            resampler = Resampler2(embed_dim=embed_dim,
+                                   num_heads=embed_dim // 128,
+                                   grid_size=int(
+                                       math.sqrt(self.config.query_num)),
+                                   kv_dim=vision_dim,
+                                   adaptive=False,
+                                   do_post_projection=True,
+                                   quant_config=quant_config,
+                                   prefix=prefix)
         return resampler
@@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        # Currently, vllm does not support BNB quantization for the `out_proj`
+        # of the resampler, so it's necessary to distinguish between the
+        # vision encoder and the resampler's out_proj. The same applies to
+        # MiniCPMV2_6.
+        ".self_attn.out_proj.",  # vision encoder out_proj
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -877,14 +907,18 @@ def init_vision_module(
         model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
@@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        ".self_attn.out_proj.",
+        # resampler
+        ".kv_proj.",
     ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1019,15 +1061,19 @@ def init_vision_module(
         model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             # The resampler in 2.6 remains consistent with the one in 2.5.
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 19c3827e43703..a03155ac32a61 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1056,9 +1056,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".fc1.",
+        ".fc2.",
+        # The `multi_modal_projector` is at the top level of the model,
+        # so we can't add a dot in front of it.
+        "multi_modal_projector."
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
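
A minimal standalone sketch (not part of the patch, and not vLLM API) of the shard-name check introduced in `_load_weights` above; the helper name `matches_shard` and the example parameter names are illustrative, while the condition itself mirrors the patched `shard_pos > 0 and quant_param_name[shard_pos - 1] == "."` test that keeps a shard name such as 'kv_proj' from matching inside a 'qkv_proj' layer name.

# Hypothetical helper isolating the dot-boundary check used when mapping
# stacked BNB parameter names back to their original shard names.
def matches_shard(quant_param_name: str, shard_name: str) -> bool:
    shard_pos = quant_param_name.find(shard_name)
    # The shard name must occur and be immediately preceded by a ".",
    # so "kv_proj" does not match inside "qkv_proj".
    return shard_pos > 0 and quant_param_name[shard_pos - 1] == "."

# The resampler's kv_proj matches; the vision encoder's qkv_proj does not.
assert matches_shard("resampler.kv_proj.qweight", "kv_proj")
assert not matches_shard(
    "vpm.encoder.layers.0.self_attn.qkv_proj.qweight", "kv_proj")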