Skip to content

Commit

Permalink
Gemma2 in Nemo2 with Recipes (NVIDIA#11037)
Browse files Browse the repository at this point in the history
* add gemma2 in nemo2.0 and 2b recipe

* Apply isort and black reformatting

Signed-off-by: suiyoubi <[email protected]>

* Fix gemma1 inference bug

* add more recipes

* minor fix

* recipe fix

* Apply isort and black reformatting

Signed-off-by: suiyoubi <[email protected]>

* merge fix

---------

Signed-off-by: suiyoubi <[email protected]>
Co-authored-by: suiyoubi <[email protected]>
  • Loading branch information
2 people authored and lilyw97 committed Nov 13, 2024
1 parent 0203809 commit 5eb19f0
Show file tree
Hide file tree
Showing 11 changed files with 1,217 additions and 5 deletions.
10 changes: 10 additions & 0 deletions nemo/collections/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
CodeLlamaConfig13B,
CodeLlamaConfig34B,
CodeLlamaConfig70B,
Gemma2Config,
Gemma2Config2B,
Gemma2Config9B,
Gemma2Config27B,
Gemma2Model,
GemmaConfig,
GemmaConfig2B,
GemmaConfig7B,
Expand Down Expand Up @@ -165,6 +170,11 @@
"CodeGemmaConfig2B",
"CodeGemmaConfig7B",
"GemmaModel",
"Gemma2Model",
"Gemma2Config9B",
"Gemma2Config",
"Gemma2Config27B",
"Gemma2Config2B",
"Baichuan2Config",
"Baichuan2Config7B",
"Baichuan2Model",
Expand Down
12 changes: 12 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@
GemmaConfig7B,
GemmaModel,
)
from nemo.collections.llm.gpt.model.gemma2 import (
Gemma2Config,
Gemma2Config2B,
Gemma2Config9B,
Gemma2Config27B,
Gemma2Model,
)
from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM
from nemo.collections.llm.gpt.model.llama import (
CodeLlamaConfig7B,
Expand Down Expand Up @@ -142,6 +149,11 @@
"CodeGemmaConfig2B",
"CodeGemmaConfig7B",
"GemmaModel",
"Gemma2Config",
"Gemma2Config27B",
"Gemma2Config2B",
"Gemma2Config9B",
"Gemma2Model",
"LlamaModel",
"Baichuan2Config",
"Baichuan2Config7B",
Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/llm/gpt/model/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import TYPE_CHECKING, Annotated, Callable, Optional

import torch
from megatron.core import parallel_state
from torch import nn

from nemo.collections.llm.fn.activation import openai_gelu
Expand Down Expand Up @@ -95,7 +96,8 @@ def configure_model(self):
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import EmbeddingScalingMixin

super().configure_model()
extend_instance(self.module.embedding, EmbeddingScalingMixin)
if parallel_state.is_pipeline_first_stage():
extend_instance(self.module.embedding, EmbeddingScalingMixin)


@io.model_importer(GemmaModel, "hf")
Expand Down Expand Up @@ -160,7 +162,7 @@ def make_vocab_size_divisible_by(vocab_size):
rotary_base=source.rope_theta,
gated_linear_unit=True,
make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
share_embeddings_and_output_weights=False,
share_embeddings_and_output_weights=True,
fp16=(dtype_from_hf(source) == torch.float16),
bf16=(dtype_from_hf(source) == torch.bfloat16),
params_dtype=dtype_from_hf(source),
Expand Down
Loading

0 comments on commit 5eb19f0

Please sign in to comment.