Nemo2 batcheval (#11158)
* initial draft for eval api

Signed-off-by: HuiyingLi <[email protected]>

* add dp to generate

Signed-off-by: HuiyingLi <[email protected]>

* Apply isort and black reformatting

Signed-off-by: HuiyingLi <[email protected]>

* add top_k=1 to default inference params to get deterministic output

Signed-off-by: HuiyingLi <[email protected]>

* change name

Signed-off-by: HuiyingLi <[email protected]>

* add eval dataset and write-to-file support to llm.generate

Signed-off-by: HuiyingLi <[email protected]>

* support standalone input jsonl

Signed-off-by: HuiyingLi <[email protected]>

---------

Signed-off-by: HuiyingLi <[email protected]>
Signed-off-by: HuiyingLi <[email protected]>
Co-authored-by: HuiyingLi <[email protected]>
HuiyingLi and HuiyingLi authored Nov 19, 2024
1 parent 425a4dd commit b817158
Showing 2 changed files with 54 additions and 5 deletions.
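
The change below adds a batch-evaluation path to llm.generate: inputs can now come from a standalone JSONL file (or a data module's test set), generation is sharded across data-parallel ranks, and predictions can be written back to disk. A minimal usage sketch, assuming a NeMo 2.0 checkpoint, an illustrative nl.Trainer setup, and hypothetical file paths; only the parameter names (path, trainer, input_dataset, output_path, text_only) come from the diff itself:

# Sketch only: checkpoint path, data paths, and trainer config are illustrative assumptions.
from nemo import lightning as nl
from nemo.collections import llm

trainer = nl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=nl.MegatronStrategy(tensor_model_parallel_size=1),
    plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)

# Each line of the input JSONL is expected to carry an "input" field and, optionally,
# an "output" field that is copied to "label" in the prediction file.
results = llm.generate(
    path="/checkpoints/my-model-nemo2",           # hypothetical checkpoint directory
    trainer=trainer,
    input_dataset="/data/eval/test.jsonl",        # hypothetical evaluation file
    output_path="/data/eval/predictions.jsonl",   # one JSON record written per sample
    text_only=True,
)
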
57 changes: 53 additions & 4 deletions nemo/collections/llm/api.py
@@ -20,7 +20,9 @@
 import lightning.pytorch as pl
 import nemo_run as run
 import torch
+from megatron.core import parallel_state
 from rich.console import Console
+from torch.distributed import all_gather_object
 from typing_extensions import Annotated

 import nemo.lightning as nl
@@ -36,6 +38,8 @@
 from nemo.lightning.base import NEMO_MODELS_CACHE
 from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
 from nemo.utils import logging
+from nemo.utils.get_rank import is_global_rank_zero
+

 if TYPE_CHECKING:
     from megatron.core.inference.common_inference_params import CommonInferenceParams
@@ -662,16 +666,18 @@ def export_ckpt(
 @run.cli.entrypoint(name="generate", namespace="llm")
 def generate(
     path: Union[Path, str],
-    prompts: list[str],
     trainer: nl.Trainer,
+    prompts: Optional[list[str]] = None,
     encoder_prompts: Optional[list[str]] = None,
+    input_dataset: Optional[Union[pl.LightningDataModule, str]] = None,
     params_dtype: torch.dtype = torch.bfloat16,
     add_BOS: bool = False,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
     inference_batch_times_seqlen_threshold: int = 1000,
     inference_params: Optional["CommonInferenceParams"] = None,
     text_only: bool = False,
+    output_path: Optional[Union[Path, str]] = None,
 ) -> list[Union["InferenceRequest", str]]:
     """
     Generates text using a NeMo LLM model.
@@ -725,6 +731,8 @@ def generate(
         prompts (list[str]): The list of prompts to generate text for.
         trainer (nl.Trainer): The trainer object.
         encoder_prompts (Optional[list[str]], optional): The list of encoder prompts. Defaults to None.
+        input_dataset (Optional[Union[pl.LightningDataModule, str]], optional): The input data module or JSONL file.
+            If a data module is provided, its test set is used for generation. Defaults to None.
         params_dtype (torch.dtype, optional): The data type of the model parameters. Defaults to torch.bfloat16.
         add_BOS (bool, optional): Whether to add the beginning of sequence token. Defaults to False.
         max_batch_size (int, optional): The maximum batch size. Defaults to 4.
@@ -734,31 +742,72 @@ def generate(
         inference_params (Optional["CommonInferenceParams"], optional): The inference parameters defined in
             Mcore's CommonInferenceParams. Defaults to None.
         text_only (bool, optional): Whether to return only the generated text as a string. Defaults to False.
+        output_path (Optional[Union[Path, str]], optional): The path to save the generated text or test dataset
+            predictions. Defaults to None.
     Returns:
         list[Union["InferenceRequest", str]]: A list of generated text,
             either as a string or as an InferenceRequest object.
     """
     from nemo.collections.llm import inference

+    if input_dataset is not None:
+        input_path = input_dataset if isinstance(input_dataset, str) else input_dataset.test_path
+        with open(input_path) as f:
+            dataset = [json.loads(sample) for sample in f.readlines()]
+        inputs = [sample["input"] for sample in dataset]
+    elif prompts is not None:
+        inputs = prompts
+    else:
+        raise ValueError("Either prompts or input_dataset must be provided.")
+
     inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
         path=path,
         trainer=trainer,
         params_dtype=params_dtype,
         inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
     )
-    results = inference.generate(
+
+    dp_size = trainer.strategy.distributed_sampler_kwargs['num_replicas']
+    dp_rank = trainer.strategy.distributed_sampler_kwargs['rank']
+    chunk_size = (len(inputs) + dp_size - 1) // dp_size
+    start_idx = dp_rank * chunk_size
+    end_idx = min(start_idx + chunk_size, len(inputs))
+    inputs_on_this_dp_rank = inputs[start_idx:end_idx]
+
+    results_on_this_dp_rank = inference.generate(
         model=inference_wrapped_model,
         tokenizer=mcore_tokenizer,
-        prompts=prompts,
+        prompts=inputs_on_this_dp_rank,
         encoder_prompts=encoder_prompts,
         add_BOS=add_BOS,
         max_batch_size=max_batch_size,
         random_seed=random_seed,
         inference_params=inference_params,
     )
+    gathered_results = [None] * dp_size

-    return [r.generated_text if text_only else r for r in results]
+    all_gather_object(
+        gathered_results,
+        [r.generated_text if text_only else r for r in results_on_this_dp_rank],
+        group=parallel_state.get_data_parallel_group(),
+    )
+    gathered_results = [result for sublist in gathered_results for result in sublist]
+
+    assert len(gathered_results) == len(inputs)
+
+    if output_path is not None and is_global_rank_zero():
+        with open(output_path, "w") as f:
+            for sample, pred in zip(dataset if input_dataset else inputs, gathered_results):
+                if type(sample) == dict:
+                    sample["label"] = sample.pop("output", None)
+                    sample["prediction"] = pred if text_only else pred.generated_text
+                elif type(sample) == str:
+                    sample = {"input": sample, "prediction": pred if text_only else pred.generated_text}
+                f.write(json.dumps(sample) + "\n")
+        logging.info(f"Predictions written to {output_path}")
+
+    return gathered_results


 def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None:
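
The data-parallel split in the hunk above is a plain ceil-division over the flat input list: each DP rank takes one contiguous chunk, generates locally, and the per-rank result lists are all-gathered and flattened back into input order. A small self-contained sketch of that arithmetic (no distributed setup; the gather is simulated in-process):

# Standalone illustration of the sharding/gather logic added to llm.generate.
inputs = [f"prompt {i}" for i in range(10)]
dp_size = 4

# Same ceil-division chunking as in the diff: the last rank may get a short
# (or empty) slice when len(inputs) is not divisible by dp_size.
chunk_size = (len(inputs) + dp_size - 1) // dp_size  # -> 3
shards = [
    inputs[rank * chunk_size: min((rank + 1) * chunk_size, len(inputs))]
    for rank in range(dp_size)
]
# shards -> [['prompt 0', 'prompt 1', 'prompt 2'], ..., ['prompt 9']]

# Each rank "generates" for its shard; all_gather_object would collect these lists.
gathered = [[f"generated for {p}" for p in shard] for shard in shards]

# Flatten, exactly like the list comprehension in the diff, restoring input order.
flat = [r for sublist in gathered for r in sublist]
assert len(flat) == len(inputs)
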
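Per the file-writing branch above, each input record is a JSON object with an "input" field; an optional "output" field is renamed to "label", and a "prediction" field is added before the record is written to output_path (global rank 0 only). A sketch of the record shapes, with made-up example values:

import json

# Hypothetical line from the input evaluation JSONL file.
input_record = {"input": "Translate to French: good morning", "output": "bonjour"}

# What the corresponding line in output_path would look like after generation:
# "output" is renamed to "label" and the generated text is stored under "prediction".
predicted = dict(input_record)
predicted["label"] = predicted.pop("output", None)
predicted["prediction"] = "bonjour"  # placeholder for the generated text
print(json.dumps(predicted))
# {"input": "Translate to French: good morning", "label": "bonjour", "prediction": "bonjour"}
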
2 changes: 1 addition & 1 deletion nemo/collections/llm/inference/base.py
@@ -242,7 +242,7 @@ def generate(
         text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
     )

-    common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=512)
+    common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=512, top_k=1)

     results = mcore_engine.generate(
         prompts=prompts,
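With the base.py change, omitting inference_params now yields deterministic, greedy-style decoding (top_k=1) rather than the library's stock sampling defaults. Callers who want sampled output can still pass their own parameters; a sketch assuming the usual CommonInferenceParams sampling fields (only num_tokens_to_generate and top_k are confirmed by this diff, the others are assumptions):

from megatron.core.inference.common_inference_params import CommonInferenceParams

# Explicit params override the new top_k=1 default added in this commit.
sampling_params = CommonInferenceParams(
    num_tokens_to_generate=256,
    top_k=50,
    top_p=0.95,       # assumed field; not exercised by this diff
    temperature=0.7,  # assumed field; not exercised by this diff
)

# results = llm.generate(..., inference_params=sampling_params)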
