address comments, require user to provide trainer
Signed-off-by: HuiyingLi <[email protected]>
HuiyingLi committed Oct 22, 2024
1 parent a4cb8c6 commit 795e7d6
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 0 additions & 2 deletions nemo/collections/llm/api.py
@@ -437,7 +437,6 @@ def generate(
     path: Union[Path, str],
     prompts: list[str],
     trainer: Optional[nl.Trainer] = None,
-    tensor_parallel_size: int = -1,
     params_dtype: torch.dtype = torch.bfloat16,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
@@ -450,7 +449,6 @@ def generate(
     inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
         path=path,
         trainer=trainer,
-        tensor_parallel_size=tensor_parallel_size,
         params_dtype=params_dtype,
         inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
     )
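Caller-side impact, as a hedged sketch: with tensor_parallel_size removed, parallelism is configured on the trainer the caller supplies. The checkpoint path, prompt, and device/parallel sizes below are placeholders, and the nl.Trainer/nl.MegatronStrategy usage is an assumption about the surrounding NeMo 2.0 API, not part of this commit:

import nemo.lightning as nl
from nemo.collections import llm

# Parallelism now lives on the user-supplied trainer, not on generate().
trainer = nl.Trainer(
    devices=2,  # placeholder: one process per GPU
    accelerator="gpu",
    strategy=nl.MegatronStrategy(tensor_model_parallel_size=2),  # assumed API
)
results = llm.generate(
    path="/checkpoints/my_model",  # placeholder checkpoint directory
    prompts=["What is tensor parallelism?"],
    trainer=trainer,  # now the caller's responsibility to provide
    max_batch_size=4,
)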
20 changes: 11 additions & 9 deletions nemo/collections/llm/inference/base.py
@@ -63,22 +63,24 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.
 
 def setup_model_and_tokenizer(
     path: Path,
-    trainer: Optional[nl.Trainer] = None,
-    tensor_parallel_size: int = -1,
+    trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
 ) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
     model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
-    trainer = trainer or io.load_context(path=ckpt_to_context_subdir(path), subpath="trainer")
-    if tensor_parallel_size > 0:
-        trainer.strategy.tensor_model_parallel_size = tensor_parallel_size
-        trainer.devices = tensor_parallel_size
-        trainer.strategy.parallel_devices = [torch.device(f"cuda:{i}") for i in range(tensor_parallel_size)]
-        trainer.strategy.launcher.num_processes = len(trainer.strategy.parallel_devices)
     _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)
 
+    # This is to get the MCore model required in GPTInferenceWrapper.
+    mcore_model = model
+    while mcore_model:
+        if type(mcore_model) is MCoreGPTModel:
+            break
+        mcore_model = getattr(mcore_model, "module", None)
+    if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
+        raise ValueError("Exact MCoreGPTModel instance not found in the model structure.")
+
     inference_wrapped_model = GPTInferenceWrapper(
-        model,
+        mcore_model,
         InferenceWrapperConfig(
             hidden_size=model.config.hidden_size,
             params_dtype=params_dtype,
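The unwrapping loop introduced above follows a general pattern: the restored Lightning module can nest the Megatron-Core model inside wrapper layers (e.g., DDP or precision wrappers), each exposing the next layer as .module. A standalone sketch of the same idea; unwrap_to is a hypothetical helper for illustration, not part of NeMo:

def unwrap_to(model, target_cls):
    """Follow .module attributes until an exact instance of target_cls is found."""
    current = model
    while current is not None:
        # Exact type check (not isinstance), mirroring the diff's
        # `type(...) is MCoreGPTModel` test: a subclass wrapper fails the
        # check and the loop descends into its .module instead.
        if type(current) is target_cls:
            return current
        current = getattr(current, "module", None)
    raise ValueError(f"Exact {target_cls.__name__} instance not found in the model structure.")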
