[Models] Implement support for gemma #219

Merged
merged 10 commits on Feb 22, 2024
12 changes: 10 additions & 2 deletions mlc_llm/core.py
@@ -644,11 +644,18 @@ def mod_transform_before_build(

if max_seq_len:
num_key_value_heads = config.get_num_key_value_heads()
num_query_heads = config.num_attention_heads // args.num_shards
hidden_size = config.hidden_size // args.num_shards
if hasattr(config, "head_dim"):
head_dim = config.head_dim
else:
head_dim = hidden_size // num_query_heads
# pylint: disable=no-value-for-parameter
mod = fuse_split_rotary_embedding(
config.num_attention_heads // args.num_shards,
num_query_heads,
num_key_value_heads // args.num_shards,
config.hidden_size // args.num_shards,
hidden_size,
head_dim,
config.position_embedding_base,
batched=args.enable_batching,
)(mod)
@@ -892,6 +899,7 @@ def build_model_from_args(args: argparse.Namespace):
model_generators["llama"] = llama_batched_vllm
model_generators["mistral"] = llama_batched_vllm
model_generators["mixtral"] = llama_batched_vllm
model_generators["gemma"] = llama_batched_vllm

assert args.model_category in model_generators, f"Model {args.model} not supported"

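The `hasattr(config, "head_dim")` branch added above (and repeated in the model code below) exists because Gemma configures its per-head dimension explicitly, and that value need not equal `hidden_size // num_attention_heads` (Gemma-7B, for instance, pairs a 3072-wide hidden state with 16 heads of dimension 256). A minimal sketch of the selection logic; `resolve_head_dim` is a hypothetical helper name, not something defined in this PR:

```python
def resolve_head_dim(config) -> int:
    """Pick the attention head dimension for a HuggingFace-style config."""
    # Gemma-style configs carry an explicit head_dim that may differ from
    # hidden_size // num_attention_heads.
    if hasattr(config, "head_dim"):
        return config.head_dim
    # Llama/Mistral-style configs derive it from the hidden size.
    return config.hidden_size // config.num_attention_heads
```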
93 changes: 80 additions & 13 deletions mlc_llm/relax_model/llama.py
@@ -17,6 +17,8 @@

@dataclass
class LlamaConfig:
rms_norm_weight_offset = 0.0

def __init__(
self,
dtype="float32",
@@ -96,6 +98,19 @@ def __init__(
self.quantization_scheme = kwargs["quantization_scheme"]


class GemmaConfig(LlamaConfig):
rms_norm_weight_offset = 1.0

head_dim: int

def __init__(
self,
**kwargs,
):
super().__init__(**kwargs)
self.head_dim = kwargs["head_dim"]


class Linear(nn.Module):
def __init__(self, in_features, out_features, dtype: str, bias=True):
self.in_features = in_features
@@ -133,9 +148,10 @@ def forward(self, x: relax.Expr) -> relax.Var:


class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, dtype, eps=1e-6):
def __init__(self, hidden_size, dtype, eps=1e-6, weight_offset=0.0):
self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight")
self.variance_epsilon = tvm.tir.const(eps, dtype)
self.weight_offset = weight_offset

def forward(self, hidden_states):
from tvm import te, tir
@@ -173,9 +189,12 @@ def f_div_cast_3d(bsz, i, k):
name=x.op.name + "red_temp",
)

return te.compute(
output = te.compute(
x.shape,
lambda i, k: f_mul_cast(weight(k), f_div_cast_2d(i, k)),
lambda i, k: f_mul_cast(
weight(k),
f_div_cast_2d(i, k),
),
name="rms_norm",
)
else:
@@ -185,13 +204,41 @@ def f_div_cast_3d(bsz, i, k):
name=x.op.name + "red_temp",
)

return te.compute(
output = te.compute(
x.shape,
lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast_3d(bsz, i, k)),
lambda bsz, i, k: f_mul_cast(
weight(k),
f_div_cast_3d(bsz, i, k),
),
name="rms_norm",
)

return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm")
return output

# Currently, the cutlass.rms_norm assumes that
# `cutlass::rmsnorm` can be used in place of any PrimFunc that
# is named `rms_norm`. As a result, non-zero `weight_offset`
# applied inside the TE kernel definition would produce
# incorrect results. Applying the `weight_offset` outside the
# `nn.emit_te` is required for correct results. (It's also
# preferable for performance, so that the `weight_offset` can
# be preprocessed.)
#
# TODO(Lunderberg): Change the "cutlass.rms_norm" pattern to
# verify the function that it calls.
if self.weight_offset == 0:
rms_weights = self.weight
else:
rms_weights = nn.emit(
self.weight + R.const(self.weight_offset, dtype=self.weight.struct_info.dtype),
name_hint="rms_weights",
)
return nn.emit_te(
f_rms_norm,
hidden_states,
rms_weights,
primfunc_name_hint="rms_norm",
)
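The comment above reduces to a small identity: Gemma's RMSNorm scales the normalized activations by `1 + weight` instead of `weight`, and adding the offset to the weight vector once, outside the normalization kernel, gives the same result as applying it inside. A rough NumPy sketch of that equivalence (a reference check only, not the TE/Relax code above):

```python
import numpy as np

def ref_rms_norm(x, weight, eps=1e-6, weight_offset=0.0):
    # Reference RMSNorm: x / rms(x) * (weight + weight_offset).
    variance = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * (weight + weight_offset)

x = np.random.randn(2, 8).astype(np.float32)
w = np.random.randn(8).astype(np.float32)

# Gemma semantics: offset applied inside the kernel ...
inside = ref_rms_norm(x, w, weight_offset=1.0)
# ... equals pre-offsetting the weights and running the stock kernel, which
# keeps the emitted "rms_norm" PrimFunc compatible with cutlass::rmsnorm.
outside = ref_rms_norm(x, w + 1.0)
assert np.allclose(inside, outside)
```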


class LlamaMLP(nn.Module):
@@ -213,6 +260,8 @@ def __init__(self, config: LlamaConfig):
self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False)
self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False)

self.act = {"silu": relax.op.nn.silu, "gelu": relax.op.nn.gelu}[config.hidden_act]

def forward(self, x):
if self.combine_matmul:
gate_up_results = nn.emit(
@@ -228,7 +277,7 @@ def forward(self, x):
gate_result = self.gate_proj(x)
up_result = self.up_proj(x)

result = self.down_proj(relax.op.nn.silu(gate_result) * up_result)
result = self.down_proj(self.act(gate_result) * up_result)
return result
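The `self.act` lookup added in `__init__` lets the same gated MLP serve both Llama-family models (SiLU, i.e. SwiGLU) and Gemma (GELU, i.e. GeGLU). A minimal NumPy sketch of what `forward` computes, ignoring sharding and the combined-matmul path:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def gelu(x):
    # tanh approximation of GELU, close enough for illustration
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def gated_mlp(x, w_gate, w_up, w_down, hidden_act):
    # down_proj(act(gate_proj(x)) * up_proj(x)), mirroring LlamaMLP.forward
    act = {"silu": silu, "gelu": gelu}[hidden_act]
    return (act(x @ w_gate) * (x @ w_up)) @ w_down
```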


@@ -280,7 +329,12 @@ def __init__(self, config: LlamaConfig):
self.hidden_size = config.hidden_size
self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
self.num_query_heads = config.num_attention_heads // self.num_shards
self.head_dim = self.hidden_size // config.num_attention_heads

if hasattr(config, "head_dim"):
self.head_dim = config.head_dim
else:
self.head_dim = config.hidden_size // config.num_attention_heads

self.position_embedding_base = config.position_embedding_base

self.combine_matmul = config.combine_matmul
@@ -322,7 +376,10 @@ def __init__(self, config: LlamaConfig):
self.v_proj.weight.shard_dim = 0

self.o_proj = Linear(
self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=config.attention_bias
self.head_dim * self.num_query_heads,
self.hidden_size,
dtype=dtype,
bias=config.attention_bias,
)
self.o_proj.weight.shard_dim = 1
self.o_proj.weight.shard_strategy = "shard_o_proj_k"
@@ -598,10 +655,16 @@ def __init__(self, config: LlamaConfig, enable_batching: bool):
self.use_moe = False
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(
config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps
config.hidden_size,
dtype=config.dtype,
eps=config.rms_norm_eps,
weight_offset=config.rms_norm_weight_offset,
)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps
config.hidden_size,
dtype=config.dtype,
eps=config.rms_norm_eps,
weight_offset=config.rms_norm_weight_offset,
)

def post_self_attn(self, hidden_states, residual):
@@ -1331,6 +1394,11 @@ def quantize(experts, relax_pname):
assert relax_pname.endswith("scales")
return qscale

if hasattr(config, "head_dim"):
head_dim = config.head_dim
else:
head_dim = config.hidden_size // config.num_attention_heads

def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
# Expected to enter this function only for the combined linear matmul weights.
# Other weights are supposed to be loaded in `f_convert_param_bkwd` since
@@ -1365,8 +1433,8 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
"Matmul combination is not turned on, and the function "
"is not expected to be entered"
)

hidden_size = config.hidden_size
head_dim = config.hidden_size // config.num_attention_heads

if "query_key_value_proj" in relax_pname:
q_heads = config.num_attention_heads
@@ -1401,7 +1469,6 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
device = tvm.cpu()
param_list = [None] * param_manager.nparam_to_load

head_dim = config.hidden_size / config.num_attention_heads
inv_freq = 1.0 / (
config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)
)
42 changes: 37 additions & 5 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -19,6 +19,7 @@
from .llama import (
LlamaConfig,
MixtralConfig,
GemmaConfig,
Linear,
Embedding,
LlamaRMSNorm,
@@ -492,6 +493,7 @@ def __init__(
kv_type: KVCacheType,
sep_embed: bool = False,
):
self.config = config
self.padding_idx = config.pad_token_id
self.embed_tokens = None

@@ -501,7 +503,12 @@
self.layers = ModuleList(
[LlamaDecoderLayerBatched(config, kv_type) for _ in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps)
self.norm = LlamaRMSNorm(
config.hidden_size,
dtype=config.dtype,
eps=config.rms_norm_eps,
weight_offset=config.rms_norm_weight_offset,
)

def forward(
self,
@@ -519,6 +526,9 @@ def forward(

hidden_states = inputs_embeds

if isinstance(self.config, GemmaConfig):
hidden_states = nn.emit(hidden_states * relax.const(self.config.hidden_size**0.5, dtype="float16"))

new_kvs = ()

for idx, decoder_layer in enumerate(self.layers):
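The `GemmaConfig` branch above implements Gemma's convention of scaling the token embeddings by `sqrt(hidden_size)` before the first decoder layer; Llama and Mistral skip this step. A rough NumPy sketch with illustrative shapes (the real model uses Relax ops and the configured dtype):

```python
import numpy as np

hidden_size, vocab_size = 8, 32  # illustrative only; real Gemma dimensions are far larger
embedding_table = np.random.randn(vocab_size, hidden_size).astype(np.float32)

token_ids = np.array([1, 5, 7])
hidden_states = embedding_table[token_ids]
# Gemma multiplies the embeddings by sqrt(hidden_size) before layer 0.
hidden_states = hidden_states * hidden_size ** 0.5
```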
@@ -551,11 +561,19 @@ def __init__(
self.num_shards = config.num_shards
self.cpu_device = cpu_device
self.model = LlamaModel(config, vocab_size_var, kv_type, sep_embed)
self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

if isinstance(config, GemmaConfig):
assert self.model.embed_tokens is not None
self.lm_head = lambda hidden: nn.emit(relax.op.linear(hidden, self.model.embed_tokens.weight))
else:
self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

############ Rotary embedding constants ############
assert config.hidden_size % config.num_attention_heads == 0
head_dim = config.hidden_size // config.num_attention_heads
if hasattr(config, "head_dim"):
head_dim = config.head_dim
else:
assert config.hidden_size % config.num_attention_heads == 0
head_dim = config.hidden_size // config.num_attention_heads

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
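For Gemma the output head is tied to the input embedding: instead of allocating a separate `lm_head` matrix, the lambda above computes logits against `self.model.embed_tokens.weight` via `relax.op.linear` (i.e. `hidden @ weight.T`). A NumPy sketch of the tied head with illustrative shapes:

```python
import numpy as np

vocab_size, hidden_size = 32, 8  # illustrative only
embed_weight = np.random.randn(vocab_size, hidden_size).astype(np.float32)

def tied_lm_head(hidden):
    # Reusing the embedding matrix as the output projection yields one
    # logit per vocabulary entry, with no extra lm_head parameters.
    return hidden @ embed_weight.T

last_hidden = np.random.randn(1, hidden_size).astype(np.float32)
logits = tied_lm_head(last_hidden)  # shape (1, vocab_size)
```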
@@ -703,7 +721,11 @@ def get_inputs(
num_blocks = tvm.tir.Var("num_blocks", "int64")

num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
head_size = hidden_size // config.num_attention_heads

if hasattr(config, "head_dim"):
head_size = config.head_dim
else:
head_size = config.hidden_size // config.num_attention_heads

if kv_type == KVCacheType.VLLM:
block_size = VllmAttention.block_size
@@ -1042,6 +1064,16 @@ def get_model(args, hf_config):
build_model_only=args.build_model_only,
quantization_scheme=args.quantization,
)
elif "gemma" in args.model.lower():
config = GemmaConfig(
**hf_config,
dtype=dtype,
max_sequence_length=hf_config["max_position_embeddings"],
position_embedding_base=position_embedding_base,
combine_matmul=True,
num_shards=args.num_shards,
build_model_only=args.build_model_only,
)
elif "max_sequence_length" in hf_config:
config = LlamaConfig(
**hf_config,
4 changes: 1 addition & 3 deletions mlc_llm/transform/fuse_split_rotary_embedding.py
@@ -461,10 +461,8 @@ def apply_rewrite(mod, split_rotary, get_pattern_func):


def fuse_split_rotary_embedding(
num_query_heads, num_kv_heads, hidden_size, position_embedding_base, batched=False
num_query_heads, num_kv_heads, hidden_size, head_dim, position_embedding_base, batched=False
):
head_dim = hidden_size // num_query_heads

@tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding")
def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule:
split_rotary = get_dynamic_split_rotary()
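Previously this pass derived `head_dim` as `hidden_size // num_query_heads`, which is wrong for Gemma, so callers now pass `head_dim` explicitly. The rotary tables depend only on `head_dim`, as in this rough sketch that mirrors the `inv_freq` computation in `llama.py`:

```python
import numpy as np

def rotary_tables(head_dim, max_seq_len, base=10000.0):
    # Inverse frequencies cover pairs of dimensions within one head, so the
    # table shape is governed by head_dim, not hidden_size // num_heads.
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim))
    t = np.arange(max_seq_len, dtype="float32")
    freqs = np.outer(t, inv_freq)            # (max_seq_len, head_dim // 2)
    emb = np.concatenate([freqs, freqs], axis=-1)
    return np.cos(emb), np.sin(emb)

# With Gemma-7B-like settings head_dim is 256, even though
# hidden_size // num_attention_heads would give 3072 // 16 = 192.
cos, sin = rotary_tables(head_dim=256, max_seq_len=2048)
```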
1 change: 1 addition & 0 deletions mlc_llm/utils.py
@@ -27,6 +27,7 @@
"mistral",
"mixtral",
"stablelm_epoch",
"gemma",
]
)

14 changes: 11 additions & 3 deletions serve/mlc_serve/model/base.py
@@ -25,6 +25,7 @@ class ModelArtifactConfig:
num_attention_heads: Optional[int] = None
num_hidden_layers: Optional[int] = None
hidden_size: Optional[int] = None
head_dim: Optional[int] = None

@classmethod
def _from_json(config_cls, json_obj: dict):
@@ -40,14 +41,16 @@ def _from_json(config_cls, json_obj: dict):
class AssetNotFound(Exception):
def __init__(self, asset_path):
self.asset_path = asset_path
super().__init__(f"{self.asset_path} should exist. Did you build with `--enable-batching`?")
super().__init__(
f"{self.asset_path} should exist. Did you build with `--enable-batching`?"
)


def get_model_artifact_config(model_artifact_path):
json_object = {"model_artifact_path": model_artifact_path}
for config_file_name in [
"build_config.json",
"model/mlc-model-config.json"
"model/mlc-model-config.json",
]:
config_file_path = os.path.join(model_artifact_path, config_file_name)
if not os.path.exists(config_file_path):
@@ -59,7 +62,12 @@ def get_model_artifact_config(model_artifact_path):
if not "paged_kv_cache_type" in json_object:
json_object["paged_kv_cache_type"] = "vllm"

return ModelArtifactConfig._from_json(json_object)
config = ModelArtifactConfig._from_json(json_object)

if config.head_dim is None:
config.head_dim = config.hidden_size // config.num_attention_heads

return config


def get_hf_config(model_path: Path) -> AutoConfig: