
llm.generate fixes #10983

Merged · 6 commits · Oct 23, 2024
2 changes: 1 addition & 1 deletion nemo/collections/llm/api.py
@@ -436,7 +436,7 @@ def export_ckpt(
 def generate(
     path: Union[Path, str],
     prompts: list[str],
-    trainer: Optional[nl.Trainer] = None,
+    trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
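With this change, trainer becomes a required argument of llm.generate rather than an Optional that could fall back to a trainer restored from the checkpoint context. A minimal usage sketch under that reading; the checkpoint path and the trainer/strategy settings below are illustrative assumptions, not taken from this PR:

import nemo.lightning as nl
from nemo.collections import llm

# Illustrative inference trainer; device and strategy settings are assumptions.
trainer = nl.Trainer(
    devices=1,
    accelerator="gpu",
    strategy=nl.MegatronStrategy(),
)

results = llm.generate(
    path="/checkpoints/my_model",  # hypothetical checkpoint directory
    prompts=["What is machine learning?"],
    trainer=trainer,  # now required; None is no longer accepted
)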
16 changes: 12 additions & 4 deletions nemo/collections/llm/inference/base.py
@@ -16,6 +16,7 @@
 import nemo.lightning as nl
 from nemo.lightning import io
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
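The newly imported ckpt_to_context_subdir is used further down so that io.load_context reads the serialized model context from the checkpoint's context subdirectory rather than from the checkpoint root. A hypothetical sketch of the assumed behavior, not the library's actual implementation:

from pathlib import Path

def ckpt_to_context_subdir_sketch(path: Path) -> Path:
    # Assumption: NeMo 2.0 checkpoints keep their serialized io context under "context/".
    return Path(path) / "context"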
@@ -44,6 +45,7 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.
         load_optim_state=False,
     )
     trainer.strategy.restore_config = restore_config
+    trainer.strategy._setup_optimizers = False
     trainer.ckpt_path = None
     trainer.strategy.connect(model)
     if trainer.strategy.launcher is not None:
@@ -61,16 +63,22 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.
 def setup_model_and_tokenizer(
     path: Path,
-    trainer: Optional[nl.Trainer] = None,
+    trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
 ) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
-    model: io.TrainerContext = io.load_context(path=path, subpath="model")
-    trainer = trainer or io.load_context(path=path, subpath="trainer")
+    model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
     _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

     # This is to get the MCore model required in GPTInferenceWrapper.
[Review discussion on the mcore_model unwrapping change]

Collaborator: Why is this removed?

Author: I found that the top-level model is also an MCoreGPTModel and can be passed in directly to generate without problems, but I might be missing something. Do you have suggestions? I can add it back with logic that finds the correct MCoreGPTModel without subclassing.

Collaborator: I think the class expects an MCore object, so it's better to pass that to stay compatible with MCore's API.

Author: Sounds good! I will add the logic for finding the exact MCoreGPTModel class. Thanks for the suggestion!
-    mcore_model = model.module.module.module
+    mcore_model = model
+    while mcore_model:
+        if type(mcore_model) is MCoreGPTModel:
+            break
+        mcore_model = getattr(mcore_model, "module", None)
+    if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
+        raise ValueError("Exact McoreGPTModel instance not found in the model structure.")

     inference_wrapped_model = GPTInferenceWrapper(
         mcore_model,
         InferenceWrapperConfig(
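The loop added in the last hunk replaces the hard-coded triple unwrap (model.module.module.module) with a traversal of the .module chain that stops at the exact MCoreGPTModel and raises if none is found. A self-contained toy illustration of that traversal; the wrapper classes here are hypothetical stand-ins (in NeMo the chain would be Lightning, DDP, or precision wrappers), and only the loop logic mirrors the diff:

class Target:
    # Stands in for MCoreGPTModel.
    pass

class Wrapper:
    # Stands in for a DDP- or Float16Module-style wrapper exposing .module.
    def __init__(self, module):
        self.module = module

model = Wrapper(Wrapper(Target()))

# Follow .module until the exact target type is found, as in the diff.
obj = model
while obj:
    if type(obj) is Target:
        break
    obj = getattr(obj, "module", None)

if obj is None or type(obj) is not Target:
    raise ValueError("Exact Target instance not found in the wrapper chain.")
print(type(obj).__name__)  # -> Target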