llm.generate fixes (#10983)
* fix context path, disable optimizer init, add tp

Signed-off-by: HuiyingLi <[email protected]>

* format

Signed-off-by: HuiyingLi <[email protected]>

* address comments, require user to provide trainer

Signed-off-by: HuiyingLi <[email protected]>

* minor fix

Signed-off-by: HuiyingLi <[email protected]>

* minor fixes

Signed-off-by: HuiyingLi <[email protected]>

---------

Signed-off-by: HuiyingLi <[email protected]>
HuiyingLi authored Oct 23, 2024
1 parent 9251d1c commit ed37d19
Showing 2 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/llm/api.py
@@ -436,7 +436,7 @@ def export_ckpt(
 def generate(
     path: Union[Path, str],
     prompts: list[str],
-    trainer: Optional[nl.Trainer] = None,
+    trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
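With this change, trainer is a required argument of llm.generate rather than an optional one. A minimal usage sketch under the new signature might look like the following; the specific Trainer/MegatronStrategy settings and the checkpoint path are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch: generate() now requires an explicit trainer.
# The strategy/devices settings and checkpoint path are assumptions for illustration.
import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

trainer = nl.Trainer(
    devices=1,
    accelerator="gpu",
    strategy=MegatronStrategy(tensor_model_parallel_size=1),
)

results = llm.generate(
    path="/checkpoints/my_model",     # hypothetical checkpoint directory
    prompts=["Hello, how are you?"],
    trainer=trainer,                  # now required rather than Optional
    max_batch_size=4,
)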
16 changes: 12 additions & 4 deletions nemo/collections/llm/inference/base.py
@@ -16,6 +16,7 @@

 import nemo.lightning as nl
 from nemo.lightning import io
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig

@@ -44,6 +45,7 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.
         load_optim_state=False,
     )
     trainer.strategy.restore_config = restore_config
+    trainer.strategy._setup_optimizers = False
     trainer.ckpt_path = None
     trainer.strategy.connect(model)
     if trainer.strategy.launcher is not None:
@@ -61,16 +63,22 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.

 def setup_model_and_tokenizer(
     path: Path,
-    trainer: Optional[nl.Trainer] = None,
+    trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
 ) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
-    model: io.TrainerContext = io.load_context(path=path, subpath="model")
-    trainer = trainer or io.load_context(path=path, subpath="trainer")
+    model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
     _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

     # This is to get the MCore model required in GPTInferenceWrapper.
-    mcore_model = model.module.module.module
+    mcore_model = model
+    while mcore_model:
+        if type(mcore_model) is MCoreGPTModel:
+            break
+        mcore_model = getattr(mcore_model, "module", None)
+    if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
+        raise ValueError("Exact McoreGPTModel instance not found in the model structure.")

     inference_wrapped_model = GPTInferenceWrapper(
         mcore_model,
         InferenceWrapperConfig(
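The hard-coded model.module.module.module lookup is replaced by a loop that walks nested .module wrappers until the exact MCoreGPTModel is found. A self-contained sketch of the same unwrapping pattern, using stand-in classes instead of the real Lightning/Megatron wrappers:

# Stand-in classes to illustrate the unwrapping loop added above;
# the real code walks Lightning/Megatron wrappers around MCoreGPTModel.
class MCoreGPTModel:
    pass

class Wrapper:
    def __init__(self, module):
        self.module = module

def unwrap(model):
    # Follow .module attributes until the exact MCoreGPTModel instance is reached.
    while model is not None:
        if type(model) is MCoreGPTModel:
            return model
        model = getattr(model, "module", None)
    raise ValueError("Exact MCoreGPTModel instance not found in the model structure.")

wrapped = Wrapper(Wrapper(Wrapper(MCoreGPTModel())))
print(type(unwrap(wrapped)).__name__)  # MCoreGPTModel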
