diff --git a/nemo/collections/llm/gpt/model/gemma2.py b/nemo/collections/llm/gpt/model/gemma2.py index eb3e16b5b60c..cae47a8665e8 100644 --- a/nemo/collections/llm/gpt/model/gemma2.py +++ b/nemo/collections/llm/gpt/model/gemma2.py @@ -24,7 +24,6 @@ from nemo.collections.llm.fn.activation import openai_gelu from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config -from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_spec import get_gemma2_layer_spec from nemo.lightning import OptimizerModule, io, teardown from nemo.lightning.pytorch.utils import dtype_from_hf @@ -36,6 +35,8 @@ def gemma2_layer_spec(config: "GPTConfig") -> ModuleSpec: + from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_spec import get_gemma2_layer_spec + return get_gemma2_layer_spec()