diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 46ef11e7d02c6..cf1999ea5fe13 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1,6 +1,6 @@
 import itertools
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Type

 import torch
 import torch.nn.functional as F
@@ -137,6 +137,36 @@ def apply(self,
         return F.linear(x, layer.weight, bias)


+class TiedWeightLinearMethod(UnquantizedLinearMethod):
+    """Linear method base with noop create_weights
+
+    Can be used to prevent the initialization of weights
+    during the initialization of modules with weight tying.
+    """
+
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        ...
+
+
+class QuantizationConfigOverride(QuantizationConfig):
+    """Config class to inject a specific LinearMethod.
+    """
+
+    def __init__(self, cls: Type[LinearMethodBase]):
+        self.cls = cls
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional[LinearMethodBase]:
+        return self.cls()
+
+
+QuantizationConfigOverride.__abstractmethods__ = frozenset()
+
+
 class LinearBase(torch.nn.Module):
     """Base linear layer.

diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index c64bc70688806..a5ae373c6d88d 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -28,14 +28,19 @@
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               RowParallelLinear)
+                                               QuantizationConfigOverride,
+                                               RowParallelLinear,
+                                               TiedWeightLinearMethod)
+# yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
@@ -210,9 +215,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
-        self.wte = VocabParallelEmbedding(self.vocab_size,
-                                          self.embed_dim,
-                                          org_num_embeddings=config.vocab_size)
+        self.wte = VocabParallelEmbedding(
+            self.vocab_size,
+            self.embed_dim,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -259,7 +270,7 @@ def forward(
 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}

-    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
+    supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]

     embedding_modules = {
         "wte": "input_embeddings",
@@ -280,16 +291,38 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
                                            prefix=prefix)
-        if self.config.tie_word_embeddings:
-            self.lm_head = self.transformer.wte
-        else:
-            self.lm_head = ParallelLMHead(
-                self.transformer.vocab_size,
-                self.transformer.embed_dim,
-                org_num_embeddings=self.config.vocab_size)
+
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+        if self.config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=QuantizationConfigOverride(
+                    TiedWeightLinearMethod),
+                params_dtype=self.transformer.wte.weight.dtype,
+            )
+            self.lm_head.register_parameter("weight",
+                                            self.transformer.wte.weight)
+        else:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = get_sampler()
@@ -335,7 +368,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
+            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                 continue
             if ".attn.bias" in name:
                 # Skip attention mask.
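
For readers unfamiliar with the pattern, the sketch below is a standalone, plain-PyTorch illustration of what the patch does for tied word embeddings: the output head skips allocating a weight of its own (the role of TiedWeightLinearMethod's noop create_weights, injected via QuantizationConfigOverride), and the embedding matrix is then registered on the head so both modules share one parameter. NoopWeightLinear is a hypothetical stand-in for ParallelLMHead, not part of the vLLM API.

import torch
import torch.nn as nn


class NoopWeightLinear(nn.Module):
    """Linear-like head that does not allocate a weight of its own.

    Mirrors TiedWeightLinearMethod's noop create_weights: the weight is
    expected to be registered on the module after construction.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Uses the tied weight registered after construction.
        return nn.functional.linear(x, self.weight)


embedding = nn.Embedding(num_embeddings=1024, embedding_dim=64)
lm_head = NoopWeightLinear(in_features=64, out_features=1024)

# Tie the weights: the head reuses the embedding matrix instead of a copy,
# analogous to lm_head.register_parameter("weight", transformer.wte.weight).
lm_head.register_parameter("weight", embedding.weight)

hidden = torch.randn(2, 64)
logits = lm_head(hidden)  # shape (2, 1024)
assert logits.shape == (2, 1024)
assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()

Because only one tensor exists, checkpoints contain no separate lm_head.weight when embeddings are tied, which is why load_weights above now skips "lm_head.weight" only when tie_word_embeddings is set.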