Fix pip install (NVIDIA#11026)
* Move AutoTokenizer inline

Signed-off-by: Marc Romeyn <[email protected]>

* Move einops to common requirements

Signed-off-by: Marc Romeyn <[email protected]>

* Move AutoTokenizer import to top-level again in fine_tuning

Signed-off-by: Marc Romeyn <[email protected]>

* Move megatron init inside nemo.lightning

Signed-off-by: Marc Romeyn <[email protected]>

* Make megatron_lazy_init_context work when transformer-engine is not installed

Signed-off-by: Marc Romeyn <[email protected]>

* Only import get_nmt_tokenizer when needed

Signed-off-by: Marc Romeyn <[email protected]>

* Apply isort and black reformatting

Signed-off-by: marcromeyn <[email protected]>

---------

Signed-off-by: Marc Romeyn <[email protected]>
Signed-off-by: marcromeyn <[email protected]>
Co-authored-by: marcromeyn <[email protected]>
2 people authored and HuiyingLi committed Nov 15, 2024
1 parent 58cdeb9 commit 9520062
Showing 5 changed files with 440 additions and 13 deletions.
8 changes: 6 additions & 2 deletions nemo/collections/llm/gpt/data/mock.py
@@ -56,9 +56,13 @@ def __init__(
         self.persistent_workers = persistent_workers
         self.create_attention_mask = create_attention_mask or not HAVE_TE
 
-        from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+        if tokenizer is None:
+            from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+
+            self.tokenizer = get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
+        else:
+            self.tokenizer = tokenizer
 
-        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
         self.data_sampler = MegatronDataSampler(
             seq_len=self.seq_length,
             micro_batch_size=micro_batch_size,
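
For readers skimming the diff, here is a minimal standalone sketch of the deferred-import pattern this hunk applies (hypothetical class name, not the NeMo code): the tokenizer utilities are imported only on the branch where no tokenizer was supplied, so a plain pip install and import of the package never pulls in the optional NLP dependencies.

# Minimal sketch of the deferred-import pattern (hypothetical class, not the NeMo code).
from typing import Any, Optional


class DataModuleSketch:
    """Imports the tokenizer factory only when the caller does not provide a tokenizer."""

    def __init__(self, tokenizer: Optional[Any] = None):
        if tokenizer is None:
            # The import lives on this branch only, so importing the data module
            # does not require the optional tokenizer dependencies up front.
            from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

            tokenizer = get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
        self.tokenizer = tokenizer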
30 changes: 20 additions & 10 deletions nemo/lightning/_strategy_lib.py
@@ -22,6 +22,8 @@
 import torch
 from torch import nn
 
+from nemo.lightning.megatron_init import initialize_model_parallel_for_nemo
+
 NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE"

@@ -57,7 +59,6 @@ def init_parallel_ranks(
         seed (int, optional): The seed for random number generation. Defaults to 1234.
         fp8 (bool, optional): Whether to use fp8 precision for model parameters. Defaults to False.
     """
-    from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo
     from nemo.utils import AppState
 
     app_state = AppState()
@@ -169,17 +170,20 @@ def set_model_parallel_attributes(model, parallelism):
 
 @contextmanager
 def megatron_lazy_init_context(config) -> Generator[None, None, None]:
-    from megatron.core.extensions import transformer_engine as _te
+    try:
+        from megatron.core.extensions import transformer_engine as _te
 
-    original = _te._get_extra_te_kwargs  # noqa: SLF001
+        original = _te._get_extra_te_kwargs  # noqa: SLF001
 
-    def _get_extra_te_kwargs_meta(c):
-        """Forces device to meta"""
-        kwargs = original(c)
-        kwargs['device'] = 'meta'
-        return kwargs
+        def _get_extra_te_kwargs_meta(c):
+            """Forces device to meta"""
+            kwargs = original(c)
+            kwargs['device'] = 'meta'
+            return kwargs
 
-    _te._get_extra_te_kwargs = _get_extra_te_kwargs_meta  # noqa: SLF001
+        _te._get_extra_te_kwargs = _get_extra_te_kwargs_meta  # noqa: SLF001
+    except ImportError:
+        pass
 
     _orig_perform_initialization = config.perform_initialization
     _orig_use_cpu_initialization = config.use_cpu_initialization
@@ -189,7 +193,13 @@ def _get_extra_te_kwargs_meta(c):
 
     yield
 
-    _te._get_extra_te_kwargs = original  # noqa: SLF001
+    try:
+        from megatron.core.extensions import transformer_engine as _te
+
+        _te._get_extra_te_kwargs = original  # noqa: SLF001
+    except ImportError:
+        pass
+
     config.perform_initialization = _orig_perform_initialization
     config.use_cpu_initialization = _orig_use_cpu_initialization
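
Taken together, the two hunks above wrap both the transformer-engine monkey-patch and its restoration in try/except ImportError, so the context manager degrades gracefully when the optional package is not installed. A condensed standalone sketch of that pattern follows (hypothetical module name heavy_ext; the real code targets megatron.core.extensions.transformer_engine):

# Standalone sketch of the optional-dependency patch/restore pattern.
# `heavy_ext` is a hypothetical optional module used only for illustration.
from contextlib import contextmanager
from typing import Generator


@contextmanager
def lazy_init_context() -> Generator[None, None, None]:
    try:
        import heavy_ext  # optional dependency; may be missing

        original = heavy_ext.get_extra_kwargs

        def patched(config):
            kwargs = original(config)
            kwargs["device"] = "meta"  # build parameters on the meta device
            return kwargs

        heavy_ext.get_extra_kwargs = patched
    except ImportError:
        pass  # nothing to patch when the optional package is absent

    yield

    try:
        import heavy_ext

        heavy_ext.get_extra_kwargs = original  # restore the original hook
    except ImportError:
        pass  # the patch was never applied, so there is nothing to undo

With this shape, entering the context manager works the same whether or not the optional package is importable.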

