diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index e7e660060f54..aaef714ef738 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -20,7 +20,9 @@
 import lightning.pytorch as pl
 import nemo_run as run
 import torch
+from megatron.core import parallel_state
 from rich.console import Console
+from torch.distributed import all_gather_object
 from typing_extensions import Annotated
 
 import nemo.lightning as nl
@@ -36,6 +38,8 @@
 from nemo.lightning.base import NEMO_MODELS_CACHE
 from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
 from nemo.utils import logging
+from nemo.utils.get_rank import is_global_rank_zero
+
 
 if TYPE_CHECKING:
     from megatron.core.inference.common_inference_params import CommonInferenceParams
@@ -662,9 +666,10 @@ def export_ckpt(
 @run.cli.entrypoint(name="generate", namespace="llm")
 def generate(
     path: Union[Path, str],
-    prompts: list[str],
     trainer: nl.Trainer,
+    prompts: Optional[list[str]] = None,
     encoder_prompts: Optional[list[str]] = None,
+    input_dataset: Optional[Union[pl.LightningDataModule, str]] = None,
     params_dtype: torch.dtype = torch.bfloat16,
     add_BOS: bool = False,
     max_batch_size: int = 4,
@@ -672,6 +677,7 @@ def generate(
     inference_batch_times_seqlen_threshold: int = 1000,
     inference_params: Optional["CommonInferenceParams"] = None,
     text_only: bool = False,
+    output_path: Optional[Union[Path, str]] = None,
 ) -> list[Union["InferenceRequest", str]]:
     """
     Generates text using a NeMo LLM model.
@@ -725,6 +731,8 @@ def generate(
         prompts (list[str]): The list of prompts to generate text for.
         trainer (nl.Trainer): The trainer object.
         encoder_prompts (Optional[list[str]], optional): The list of encoder prompts. Defaults to None.
+        input_dataset (Optional[Union[pl.LightningDataModule, str]], optional): The input data module or path to a
+            jsonl file. For data modules, the test set is used for generation. Defaults to None.
         params_dtype (torch.dtype, optional): The data type of the model parameters. Defaults to torch.bfloat16.
         add_BOS (bool, optional): Whether to add the beginning of sequence token. Defaults to False.
         max_batch_size (int, optional): The maximum batch size. Defaults to 4.
@@ -734,6 +742,8 @@ def generate(
         inference_params (Optional["CommonInferenceParams"], optional): The inference parameters defined in Mcore's
            CommonInferenceParams. Defaults to None.
         text_only (bool, optional): Whether to return only the generated text as a string. Defaults to False.
+        output_path (Optional[Union[Path, str]], optional): The path to save the generated text or test dataset
+            predictions. Defaults to None.
 
     Returns:
         list[Union["InferenceRequest", str]]: A list of generated text,
@@ -741,24 +751,63 @@
     """
     from nemo.collections.llm import inference
 
+    if input_dataset is not None:
+        input_path = input_dataset if isinstance(input_dataset, str) else input_dataset.test_path
+        with open(input_path) as f:
+            dataset = [json.loads(sample) for sample in f.readlines()]
+        inputs = [sample["input"] for sample in dataset]
+    elif prompts is not None:
+        inputs = prompts
+    else:
+        raise ValueError("Either prompts or input_dataset must be provided.")
+
     inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
         path=path,
         trainer=trainer,
         params_dtype=params_dtype,
         inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
     )
-    results = inference.generate(
+
+    dp_size = trainer.strategy.distributed_sampler_kwargs['num_replicas']
+    dp_rank = trainer.strategy.distributed_sampler_kwargs['rank']
+    chunk_size = (len(inputs) + dp_size - 1) // dp_size
+    start_idx = dp_rank * chunk_size
+    end_idx = min(start_idx + chunk_size, len(inputs))
+    inputs_on_this_dp_rank = inputs[start_idx:end_idx]
+
+    results_on_this_dp_rank = inference.generate(
         model=inference_wrapped_model,
         tokenizer=mcore_tokenizer,
-        prompts=prompts,
+        prompts=inputs_on_this_dp_rank,
         encoder_prompts=encoder_prompts,
         add_BOS=add_BOS,
         max_batch_size=max_batch_size,
         random_seed=random_seed,
         inference_params=inference_params,
     )
+    gathered_results = [None] * dp_size
 
-    return [r.generated_text if text_only else r for r in results]
+    all_gather_object(
+        gathered_results,
+        [r.generated_text if text_only else r for r in results_on_this_dp_rank],
+        group=parallel_state.get_data_parallel_group(),
+    )
+    gathered_results = [result for sublist in gathered_results for result in sublist]
+
+    assert len(gathered_results) == len(inputs)
+
+    if output_path is not None and is_global_rank_zero():
+        with open(output_path, "w") as f:
+            for sample, pred in zip(dataset if input_dataset else inputs, gathered_results):
+                if type(sample) == dict:
+                    sample["label"] = sample.pop("output", None)
+                    sample["prediction"] = pred if text_only else pred.generated_text
+                elif type(sample) == str:
+                    sample = {"input": sample, "prediction": pred if text_only else pred.generated_text}
+                f.write(json.dumps(sample) + "\n")
+        logging.info(f"Predictions written to {output_path}")
+
+    return gathered_results
 
 
 def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None:
diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py
index 8a3cbc925dad..085d686afa4e 100644
--- a/nemo/collections/llm/inference/base.py
+++ b/nemo/collections/llm/inference/base.py
@@ -242,7 +242,7 @@ def generate(
         text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
     )
-    common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=512)
+    common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=512, top_k=1)
 
     results = mcore_engine.generate(
         prompts=prompts,
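
For readers skimming the api.py hunk: the added sharding splits the prompt list contiguously across data-parallel ranks with ceiling division, so every rank gets at most chunk_size prompts and the last rank may get fewer. A minimal standalone sketch of that arithmetic (the rank/size values here are made up for illustration; in the diff they come from trainer.strategy.distributed_sampler_kwargs):

def shard_for_rank(inputs: list, dp_rank: int, dp_size: int) -> list:
    # Ceiling division: covers all inputs even when len(inputs) % dp_size != 0.
    chunk_size = (len(inputs) + dp_size - 1) // dp_size
    start_idx = dp_rank * chunk_size
    end_idx = min(start_idx + chunk_size, len(inputs))
    return inputs[start_idx:end_idx]

# Example: 10 prompts over 4 data-parallel ranks -> slices of 3, 3, 3 and 1 prompts.
prompts = [f"prompt-{i}" for i in range(10)]
for rank in range(4):
    print(rank, shard_for_rank(prompts, rank, 4))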
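
The input_dataset reading and output_path writing logic implies a JSON-lines layout: each input line must carry an "input" field, an optional "output" field is renamed to "label", and a "prediction" field is added to every written record. A small sketch of that round trip, with placeholder file names:

import json

# Hypothetical test file in the layout generate() reads: one JSON object per line.
with open("test.jsonl", "w") as f:
    f.write(json.dumps({"input": "Translate to French: good morning", "output": "bonjour"}) + "\n")

# After generate(..., input_dataset="test.jsonl", output_path="preds.jsonl", text_only=True),
# rank 0 would write lines shaped like:
# {"input": "Translate to French: good morning", "label": "bonjour", "prediction": "bonjour"}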
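
Finally, a hedged end-to-end sketch of invoking the updated entrypoint. The checkpoint and data paths and the trainer/strategy settings below are placeholders, not part of this diff; note that with the base.py change, greedy decoding (top_k=1) becomes the default whenever no CommonInferenceParams is passed.

import nemo.lightning as nl
from nemo.collections import llm

trainer = nl.Trainer(
    devices=2,
    accelerator="gpu",
    strategy=nl.MegatronStrategy(),  # TP/PP default to 1, so both GPUs act as data-parallel ranks
)

# Ad-hoc prompts: results are returned (as plain strings with text_only=True) on every rank.
results = llm.generate(
    path="/checkpoints/my-model-nemo2",  # placeholder checkpoint path
    trainer=trainer,
    prompts=["What is the capital of France?"],
    text_only=True,
)

# Batch inference over a jsonl test set; rank 0 writes predictions to output_path.
llm.generate(
    path="/checkpoints/my-model-nemo2",
    trainer=trainer,
    input_dataset="test.jsonl",  # placeholder dataset path
    output_path="preds.jsonl",
    text_only=True,
)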