Add llama 3.2 1b and 3b (#11335)
* add llama 3.2 1b and 3b

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* add recipe to init

Signed-off-by: Chen Cui <[email protected]>

* fix path

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>
cuichenx authored Nov 22, 2024
1 parent a153b8c commit d033737
Showing 6 changed files with 588 additions and 3 deletions.
4 changes: 4 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -73,6 +73,8 @@
Llama31Config8B,
Llama31Config70B,
Llama31Config405B,
Llama32Config1B,
Llama32Config3B,
LlamaConfig,
LlamaModel,
MaskedTokenLossReduction,
@@ -171,6 +173,8 @@
"Llama31Config8B",
"Llama31Config70B",
"Llama31Config405B",
"Llama32Config1B",
"Llama32Config3B",
"CodeLlamaConfig7B",
"CodeLlamaConfig13B",
"CodeLlamaConfig34B",
4 changes: 4 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -59,6 +59,8 @@
Llama31Config8B,
Llama31Config70B,
Llama31Config405B,
Llama32Config1B,
Llama32Config3B,
LlamaConfig,
LlamaModel,
)
@@ -134,6 +136,8 @@
"Llama31Config8B",
"Llama31Config70B",
"Llama31Config405B",
"Llama32Config1B",
"Llama32Config3B",
"NemotronConfig",
"Nemotron3Config4B",
"Nemotron3Config8B",
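With both __init__.py updates, the new configs are re-exported from the top-level llm namespace. A quick import check (a sketch, assuming an install that includes this commit):

from nemo.collections.llm import Llama32Config1B, Llama32Config3B

# The architecture hyperparameters come from the dataclass defaults added in llama.py below.
print(Llama32Config1B().num_layers, Llama32Config3B().num_layers)  # 16 and 28 per this commit
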
39 changes: 36 additions & 3 deletions nemo/collections/llm/gpt/model/llama.py
@@ -14,6 +14,7 @@

import math
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional

@@ -86,7 +87,7 @@ class Llama2Config70B(LlamaConfig):


@dataclass
class Llama3Config(GPTConfig):
class Llama3Config(LlamaConfig):
num_query_groups: int = 8
hidden_dropout: float = 0.0
attention_dropout: float = 0.0
@@ -182,6 +183,32 @@ class Llama31Config405B(Llama31Config):
make_vocab_size_divisible_by: int = 128


@dataclass
class Llama32Config1B(Llama31Config):
scale_factor: int = 32
share_embeddings_and_output_weights: bool = True
rotary_base: int = 500_000
num_layers: int = 16
hidden_size: int = 2048
ffn_hidden_size: int = 8192
num_attention_heads: int = 32
num_query_groups: int = 8
make_vocab_size_divisible_by: int = 128


@dataclass
class Llama32Config3B(Llama31Config):
scale_factor: int = 32
share_embeddings_and_output_weights: bool = True
rotary_base: int = 500_000
num_layers: int = 28
hidden_size: int = 3072
ffn_hidden_size: int = 8192
num_attention_heads: int = 24
num_query_groups: int = 8
make_vocab_size_divisible_by: int = 128


@dataclass
class CodeLlamaConfig7B(Llama2Config7B):
rotary_base: int = 1_000_000
@@ -252,6 +279,9 @@ def convert_state(self, source, target):
"model.norm.weight": "decoder.final_layernorm.weight",
"lm_head.weight": "output_layer.weight",
}
if getattr(source.config, "tie_word_embeddings", False):
# Llama 3.2 1B and 3B models tie the input and output embeddings, so the HF checkpoint has no separate lm_head weight
del mapping["lm_head.weight"]

return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1])

@@ -275,7 +305,7 @@ def make_vocab_size_divisible_by(vocab_size):

if getattr(source, 'rope_scaling', None) is not None and source.rope_scaling.get('rope_type') == 'llama3':
# Apply Llama3.1 customize rope scaling
cls = Llama31Config
cls = partial(Llama31Config, scale_factor=source.rope_scaling.get("factor", 8.0))
else:
cls = LlamaConfig
output = cls(
@@ -289,7 +319,7 @@ def make_vocab_size_divisible_by(vocab_size):
rotary_base=source.rope_theta,
gated_linear_unit=True,
make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
share_embeddings_and_output_weights=False,
share_embeddings_and_output_weights=getattr(source, "tie_word_embeddings", False),
fp16=(dtype_from_hf(source) == torch.float16),
bf16=(dtype_from_hf(source) == torch.bfloat16),
params_dtype=dtype_from_hf(source),
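As an aside, the switch to functools.partial in the importer hunk above threads the Hugging Face rope-scaling factor into Llama31Config while the rest of the construction path stays unchanged. A reduced illustration of the idea (names are simplified stand-ins, not the actual importer):

from dataclasses import dataclass
from functools import partial

@dataclass
class Llama31ConfigLike:  # hypothetical stand-in for Llama31Config
    scale_factor: float = 8.0
    num_layers: int = 0

rope_scaling = {"rope_type": "llama3", "factor": 32.0}  # e.g. what a Llama 3.2 HF config reports
cls = partial(Llama31ConfigLike, scale_factor=rope_scaling.get("factor", 8.0))
cfg = cls(num_layers=16)  # remaining keyword arguments are passed as before
print(cfg.scale_factor)   # 32.0
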
@@ -355,6 +385,7 @@ def config(self) -> "HFLlamaConfig":
num_key_value_heads=source.num_query_groups,
rope_theta=source.rotary_base,
vocab_size=self.tokenizer.vocab_size,
tie_word_embeddings=source.share_embeddings_and_output_weights,
)


@@ -509,6 +540,8 @@ def apply_rope_scaling(
"Llama31Config8B",
"Llama31Config70B",
"Llama31Config405B",
"Llama32Config1B",
"Llama32Config3B",
"CodeLlamaConfig7B",
"CodeLlamaConfig13B",
"CodeLlamaConfig34B",
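For orientation, the new config classes plug into the existing LlamaModel wrapper the same way the Llama 3/3.1 configs do. A minimal sketch (not part of this commit; trainer, tokenizer, and parallelism setup are omitted and depend on your environment):

from nemo.collections.llm import Llama32Config1B, LlamaModel

# Llama 3.2 1B: 16 layers, hidden size 2048, tied input/output embeddings,
# and Llama 3.1-style rope scaling with scale_factor=32, per the config above.
model = LlamaModel(config=Llama32Config1B())
# The underlying Megatron module is built later, when a trainer/strategy configures the model.
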
4 changes: 4 additions & 0 deletions nemo/collections/llm/recipes/__init__.py
@@ -33,6 +33,8 @@
llama31_8b,
llama31_70b,
llama31_405b,
llama32_1b,
llama32_3b,
mamba2_1_3b,
mamba2_2_7b,
mamba2_8b,
@@ -89,6 +91,8 @@
"llama31_8b",
"llama31_70b",
"llama31_405b",
"llama32_1b",
"llama32_3b",
"mamba2_130m",
"mamba2_370m",
"mamba2_780m",
The remaining two changed files, the new llama32_1b and llama32_3b recipe modules under nemo/collections/llm/recipes/, are not expanded in this view.
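
Those recipe modules follow the pattern of the existing llama3_*/llama31_* recipes, which expose pretrain and finetune recipe factories. A hedged sketch of how the 1B recipe might be used (the factory name and arguments mirror the other llama recipes and may differ in your NeMo version):

from nemo.collections.llm.recipes import llama32_1b

# Assumed API, by analogy with the other llama recipes; check the module for the exact signature.
recipe = llama32_1b.pretrain_recipe(name="llama32_1b_pretrain", num_nodes=1, num_gpus_per_node=8)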
