diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index c56047583b..29e33fd6fc 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -644,11 +644,18 @@ def mod_transform_before_build(
 
     if max_seq_len:
         num_key_value_heads = config.get_num_key_value_heads()
+        num_query_heads = config.num_attention_heads // args.num_shards
+        hidden_size = config.hidden_size // args.num_shards
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        else:
+            head_dim = hidden_size // num_query_heads
         # pylint: disable=no-value-for-parameter
         mod = fuse_split_rotary_embedding(
-            config.num_attention_heads // args.num_shards,
+            num_query_heads,
             num_key_value_heads // args.num_shards,
-            config.hidden_size // args.num_shards,
+            hidden_size,
+            head_dim,
             config.position_embedding_base,
             batched=args.enable_batching,
         )(mod)
@@ -892,6 +899,7 @@ def build_model_from_args(args: argparse.Namespace):
         model_generators["llama"] = llama_batched_vllm
         model_generators["mistral"] = llama_batched_vllm
         model_generators["mixtral"] = llama_batched_vllm
+        model_generators["gemma"] = llama_batched_vllm
 
     assert args.model_category in model_generators, f"Model {args.model} not supported"
 
diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 5a085c5e03..957ff192c3 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -17,6 +17,8 @@
 
 @dataclass
 class LlamaConfig:
+    rms_norm_weight_offset = 0.0
+
     def __init__(
         self,
         dtype="float32",
@@ -96,6 +98,19 @@ def __init__(
             self.quantization_scheme = kwargs["quantization_scheme"]
 
 
+class GemmaConfig(LlamaConfig):
+    rms_norm_weight_offset = 1.0
+
+    head_dim: int
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.head_dim = kwargs["head_dim"]
+
+
 class Linear(nn.Module):
     def __init__(self, in_features, out_features, dtype: str, bias=True):
         self.in_features = in_features
@@ -133,9 +148,10 @@ def forward(self, x: relax.Expr) -> relax.Var:
 
 
 class LlamaRMSNorm(nn.Module):
-    def __init__(self, hidden_size, dtype, eps=1e-6):
+    def __init__(self, hidden_size, dtype, eps=1e-6, weight_offset=0.0):
         self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight")
         self.variance_epsilon = tvm.tir.const(eps, dtype)
+        self.weight_offset = weight_offset
 
     def forward(self, hidden_states):
         from tvm import te, tir
@@ -173,9 +189,12 @@ def f_div_cast_3d(bsz, i, k):
                     name=x.op.name + "red_temp",
                 )
 
-                return te.compute(
+                output = te.compute(
                     x.shape,
-                    lambda i, k: f_mul_cast(weight(k), f_div_cast_2d(i, k)),
+                    lambda i, k: f_mul_cast(
+                        weight(k),
+                        f_div_cast_2d(i, k),
+                    ),
                     name="rms_norm",
                 )
             else:
@@ -185,13 +204,41 @@ def f_div_cast_3d(bsz, i, k):
                     name=x.op.name + "red_temp",
                 )
 
-                return te.compute(
+                output = te.compute(
                     x.shape,
-                    lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast_3d(bsz, i, k)),
+                    lambda bsz, i, k: f_mul_cast(
+                        weight(k),
+                        f_div_cast_3d(bsz, i, k),
+                    ),
                     name="rms_norm",
                 )
 
-        return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm")
+            return output
+
+        # Currently, the cutlass.rms_norm assumes that
+        # `cutlass::rmsnorm` can be used in place of any PrimFunc that
+        # is named `rms_norm`. As a result, non-zero `weight_offset`
+        # applied inside the TE kernel definition would produce
+        # incorrect results. Applying the `weight_offset` outside the
+        # `nn.emit_te` is required for correct results. (It's also
+        # preferable for performance, so that the `weight_offset` can
+        # be preprocessed.)
+        #
+        # TODO(Lunderberg): Change the "cutlass.rms_norm" pattern to
+        # verify the function that it calls.
+        if self.weight_offset == 0:
+            rms_weights = self.weight
+        else:
+            rms_weights = nn.emit(
+                self.weight + R.const(self.weight_offset, dtype=self.weight.struct_info.dtype),
+                name_hint="rms_weights",
+            )
+        return nn.emit_te(
+            f_rms_norm,
+            hidden_states,
+            rms_weights,
+            primfunc_name_hint="rms_norm",
+        )
 
 
 class LlamaMLP(nn.Module):
@@ -213,6 +260,8 @@ def __init__(self, config: LlamaConfig):
         self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False)
         self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False)
 
+        self.act = {"silu": relax.op.nn.silu, "gelu": relax.op.nn.gelu}[config.hidden_act]
+
     def forward(self, x):
         if self.combine_matmul:
             gate_up_results = nn.emit(
@@ -228,7 +277,7 @@ def forward(self, x):
             gate_result = self.gate_proj(x)
             up_result = self.up_proj(x)
 
-        result = self.down_proj(relax.op.nn.silu(gate_result) * up_result)
+        result = self.down_proj(self.act(gate_result) * up_result)
         return result
 
 
@@ -280,7 +329,12 @@ def __init__(self, config: LlamaConfig):
         self.hidden_size = config.hidden_size
         self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
         self.num_query_heads = config.num_attention_heads // self.num_shards
-        self.head_dim = self.hidden_size // config.num_attention_heads
+
+        if hasattr(config, "head_dim"):
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // config.num_attention_heads
+
         self.position_embedding_base = config.position_embedding_base
 
         self.combine_matmul = config.combine_matmul
@@ -322,7 +376,10 @@ def __init__(self, config: LlamaConfig):
             self.v_proj.weight.shard_dim = 0
 
         self.o_proj = Linear(
-            self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=config.attention_bias
+            self.head_dim * self.num_query_heads,
+            self.hidden_size,
+            dtype=dtype,
+            bias=config.attention_bias,
         )
         self.o_proj.weight.shard_dim = 1
         self.o_proj.weight.shard_strategy = "shard_o_proj_k"
@@ -598,10 +655,16 @@ def __init__(self, config: LlamaConfig, enable_batching: bool):
             self.use_moe = False
             self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(
-            config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps
+            config.hidden_size,
+            dtype=config.dtype,
+            eps=config.rms_norm_eps,
+            weight_offset=config.rms_norm_weight_offset,
         )
         self.post_attention_layernorm = LlamaRMSNorm(
-            config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps
+            config.hidden_size,
+            dtype=config.dtype,
+            eps=config.rms_norm_eps,
+            weight_offset=config.rms_norm_weight_offset,
         )
 
     def post_self_attn(self, hidden_states, residual):
@@ -1331,6 +1394,11 @@ def quantize(experts, relax_pname):
             assert relax_pname.endswith("scales")
             return qscale
 
+    if hasattr(config, "head_dim"):
+        head_dim = config.head_dim
+    else:
+        head_dim = config.hidden_size // config.num_attention_heads
+
     def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
         # Expected to enter this function only for the combined linear matmul weights.
         # Other weights are supposed to be loaded in `f_convert_param_bkwd` since
@@ -1365,8 +1433,8 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
             "Matmul combination is not turned on, and the function "
             "is not expected to be entered"
         )
+        hidden_size = config.hidden_size
 
-        head_dim = config.hidden_size // config.num_attention_heads
 
         if "query_key_value_proj" in relax_pname:
             q_heads = config.num_attention_heads
@@ -1401,7 +1469,6 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
 
     device = tvm.cpu()
     param_list = [None] * param_manager.nparam_to_load
-    head_dim = config.hidden_size / config.num_attention_heads
     inv_freq = 1.0 / (
         config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)
     )
diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py
index fb5cc79f18..da6217a70b 100644
--- a/mlc_llm/relax_model/llama_batched_vllm.py
+++ b/mlc_llm/relax_model/llama_batched_vllm.py
@@ -19,6 +19,7 @@
 from .llama import (
     LlamaConfig,
     MixtralConfig,
+    GemmaConfig,
     Linear,
     Embedding,
     LlamaRMSNorm,
@@ -492,6 +493,7 @@ def __init__(
         kv_type: KVCacheType,
         sep_embed: bool = False,
     ):
+        self.config = config
         self.padding_idx = config.pad_token_id
         self.embed_tokens = None
 
@@ -501,7 +503,12 @@ def __init__(
         self.layers = ModuleList(
             [LlamaDecoderLayerBatched(config, kv_type) for _ in range(config.num_hidden_layers)]
         )
-        self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps)
+        self.norm = LlamaRMSNorm(
+            config.hidden_size,
+            dtype=config.dtype,
+            eps=config.rms_norm_eps,
+            weight_offset=config.rms_norm_weight_offset,
+        )
 
     def forward(
         self,
@@ -519,6 +526,9 @@ def forward(
 
         hidden_states = inputs_embeds
 
+        if isinstance(self.config, GemmaConfig):
+            hidden_states = nn.emit(hidden_states * relax.const(self.config.hidden_size**0.5, dtype=self.config.dtype))
+
         new_kvs = ()
 
         for idx, decoder_layer in enumerate(self.layers):
@@ -551,11 +561,19 @@ def __init__(
         self.num_shards = config.num_shards
         self.cpu_device = cpu_device
         self.model = LlamaModel(config, vocab_size_var, kv_type, sep_embed)
-        self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)
+
+        if isinstance(config, GemmaConfig):
+            assert self.model.embed_tokens is not None
+            self.lm_head = lambda hidden: nn.emit(relax.op.linear(hidden, self.model.embed_tokens.weight))
+        else:
+            self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)
 
         ############ Rotary embedding constants ############
-        assert config.hidden_size % config.num_attention_heads == 0
-        head_dim = config.hidden_size // config.num_attention_heads
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        else:
+            assert config.hidden_size % config.num_attention_heads == 0
+            head_dim = config.hidden_size // config.num_attention_heads
 
         # Set the cached sin/cos to the maximum of 2048 and max seq len.
         # This will be eliminated further with online rotary embedding calculation.
@@ -703,7 +721,11 @@ def get_inputs(
     num_blocks = tvm.tir.Var("num_blocks", "int64")
 
     num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
-    head_size = hidden_size // config.num_attention_heads
+
+    if hasattr(config, "head_dim"):
+        head_size = config.head_dim
+    else:
+        head_size = config.hidden_size // config.num_attention_heads
 
     if kv_type == KVCacheType.VLLM:
         block_size = VllmAttention.block_size
@@ -1042,6 +1064,16 @@ def get_model(args, hf_config):
             build_model_only=args.build_model_only,
             quantization_scheme=args.quantization,
         )
+    elif "gemma" in args.model.lower():
+        config = GemmaConfig(
+            **hf_config,
+            dtype=dtype,
+            max_sequence_length=hf_config["max_position_embeddings"],
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+        )
     elif "max_sequence_length" in hf_config:
         config = LlamaConfig(
             **hf_config,
diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py
index 36fa7f5fa6..3f707689b4 100644
--- a/mlc_llm/transform/fuse_split_rotary_embedding.py
+++ b/mlc_llm/transform/fuse_split_rotary_embedding.py
@@ -461,10 +461,8 @@ def apply_rewrite(mod, split_rotary, get_pattern_func):
 
 
 def fuse_split_rotary_embedding(
-    num_query_heads, num_kv_heads, hidden_size, position_embedding_base, batched=False
+    num_query_heads, num_kv_heads, hidden_size, head_dim, position_embedding_base, batched=False
 ):
-    head_dim = hidden_size // num_query_heads
-
     @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding")
     def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule:
         split_rotary = get_dynamic_split_rotary()
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index aa705f3e62..6ebd930fba 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -27,6 +27,7 @@
         "mistral",
         "mixtral",
         "stablelm_epoch",
+        "gemma",
     ]
 )
 
diff --git a/serve/mlc_serve/model/base.py b/serve/mlc_serve/model/base.py
index 63daf7895c..06a189e38e 100644
--- a/serve/mlc_serve/model/base.py
+++ b/serve/mlc_serve/model/base.py
@@ -25,6 +25,7 @@ class ModelArtifactConfig:
     num_attention_heads: Optional[int] = None
     num_hidden_layers: Optional[int] = None
     hidden_size: Optional[int] = None
+    head_dim: Optional[int] = None
 
     @classmethod
     def _from_json(config_cls, json_obj: dict):
@@ -40,14 +41,16 @@ def _from_json(config_cls, json_obj: dict):
 class AssetNotFound(Exception):
     def __init__(self, asset_path):
         self.asset_path = asset_path
-        super().__init__(f"{self.asset_path} should exist. Did you build with `--enable-batching`?")
+        super().__init__(
+            f"{self.asset_path} should exist. Did you build with `--enable-batching`?"
+        )
 
 
 def get_model_artifact_config(model_artifact_path):
     json_object = {"model_artifact_path": model_artifact_path}
     for config_file_name in [
         "build_config.json",
-        "model/mlc-model-config.json"
+        "model/mlc-model-config.json",
     ]:
         config_file_path = os.path.join(model_artifact_path, config_file_name)
         if not os.path.exists(config_file_path):
@@ -59,7 +62,12 @@ def get_model_artifact_config(model_artifact_path):
     if not "paged_kv_cache_type" in json_object:
         json_object["paged_kv_cache_type"] = "vllm"
 
-    return ModelArtifactConfig._from_json(json_object)
+    config = ModelArtifactConfig._from_json(json_object)
+
+    if config.head_dim is None:
+        config.head_dim = config.hidden_size // config.num_attention_heads
+
+    return config
 
 
 def get_hf_config(model_path: Path) -> AutoConfig:
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index 67fa7eefbd..baababe368 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -162,7 +162,7 @@ def __init__(
         self.num_shards = config.num_shards
 
         # TODO(@sunggg): Find a better way
-        if config.model_type in ["llama", "mistral", "mixtral"]:
+        if config.model_type in ["llama", "mistral", "mixtral", "gemma"]:
             self.torch_dtype = torch.float32
         else:
             assert 0, f"{config.model_type} is NOT supported yet"
@@ -253,7 +253,9 @@ def profile_memory_usage(self, seq_lens):
 
         vm_alloc_after = self.get_used_memory()
 
-        LOG.info(f"peak memory during profling: {(vm_alloc_after - vm_alloc_before) / 1e9} GB")
+        LOG.info(
+            f"peak memory during profiling: {(vm_alloc_after - vm_alloc_before) / 1e9} GB"
+        )
 
         return self.get_param_nbytes() + (vm_alloc_after - vm_alloc_before)
 
@@ -561,9 +563,7 @@ def init_tvm_model(
     num_kv_heads = (
        model_artifact_config.num_key_value_heads // model_artifact_config.num_shards
    )
-    head_size = (
-        model_artifact_config.hidden_size // model_artifact_config.num_attention_heads
-    )
+    head_size = model_artifact_config.head_dim
 
     if model_artifact_config.paged_kv_cache_type == "flash-decoding":
         allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache"
diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py
index 0cb313da9c..a6cd5dff6e 100644
--- a/serve/tests/test_engine.py
+++ b/serve/tests/test_engine.py
@@ -21,10 +21,12 @@ def _test(args: argparse.Namespace):
     sampling_params_greedy = SamplingParams(
         temperature=0.0,
+        vocab_size=engine.model_artifact_config.vocab_size,
     )
 
     sampling_params_random = SamplingParams(
         temperature=1.0,
         top_p=1.0,
+        vocab_size=engine.model_artifact_config.vocab_size,
     )
 
     num_sequences = args.num_sequences_to_sample
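
Reviewer note (not part of the patch): the Gemma-specific numerics threaded through this diff are small but easy to get backwards, so below is a minimal NumPy sketch of what `rms_norm_weight_offset = 1.0` and the `hidden_size**0.5` embedding scaling are intended to compute. The shapes and the zero-initialized norm weight are illustrative only; the 3072/16/256 figures correspond to Gemma-7B, where `head_dim` (256) deliberately differs from `hidden_size // num_attention_heads` (192), which is why an explicit `head_dim` is now passed to `fuse_split_rotary_embedding` and used for KV-cache sizing instead of being derived.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6, weight_offset=0.0):
    # Llama checkpoints store the norm scale directly (weight_offset=0.0);
    # Gemma checkpoints store (scale - 1), so the effective multiplier is
    # (weight + 1.0), matching rms_norm_weight_offset = 1.0 in GemmaConfig.
    variance = np.mean(x.astype("float32") ** 2, axis=-1, keepdims=True)
    return (weight + weight_offset) * (x / np.sqrt(variance + eps))


hidden_size = 3072  # illustrative; matches Gemma-7B
x = np.random.randn(1, 4, hidden_size).astype("float32")
w = np.zeros(hidden_size, dtype="float32")  # freshly-initialized Gemma norm weight

# Gemma also scales the embedding output by sqrt(hidden_size) before the first
# decoder layer, mirroring the relax.const(hidden_size**0.5) multiply added in
# LlamaModel.forward.
hidden_states = x * hidden_size**0.5

out = rms_norm(hidden_states, w, weight_offset=1.0)
print(out.shape)  # (1, 4, 3072)
```

The `lm_head` change follows the same reasoning: Gemma ties its output projection to the embedding matrix, so logits are computed as a linear against `embed_tokens.weight` rather than loading a separate `lm_head` parameter.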