Cherry pick llm.generate (#10998)
* [NeMo-UX] Add llm.generate to nemo.collections.llm (#10471)

* Add llm.generate

Signed-off-by: Hemil Desai <[email protected]>

* Remove comment

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Fix launching with python

Signed-off-by: Hemil Desai <[email protected]>

* PR feedback

Signed-off-by: Hemil Desai <[email protected]>

* PR feedback

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Add assert cp

Signed-off-by: Hemil Desai <[email protected]>

* Add example script

Signed-off-by: Hemil Desai <[email protected]>

* Fix

Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: hemildesai <[email protected]>

* Fix

Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
hemildesai and hemildesai authored Oct 22, 2024
1 parent 931cfbf commit 945cb6b
Showing 5 changed files with 219 additions and 3 deletions.
3 changes: 2 additions & 1 deletion nemo/collections/llm/__init__.py
@@ -18,7 +18,7 @@
safe_import("transformer_engine")

from nemo.collections.llm import peft
from nemo.collections.llm.api import export_ckpt, finetune, import_ckpt, pretrain, train, validate
from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, train, validate
from nemo.collections.llm.gpt.data import (
DollyDataModule,
FineTuningDataModule,
@@ -185,6 +185,7 @@
"pretrain",
"validate",
"finetune",
"generate",
"mock",
"squad",
"dolly",
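With this change, the new entrypoint is re-exported at the collection level, so it can be imported directly as from nemo.collections.llm import generate (or called as llm.generate), alongside train, finetune, and the other API functions listed in __all__.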
43 changes: 41 additions & 2 deletions nemo/collections/llm/api.py
@@ -16,16 +16,23 @@
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import nemo_run as run
import pytorch_lightning as pl
import torch
from typing_extensions import Annotated

import nemo.lightning as nl
from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
from nemo.utils import logging

if TYPE_CHECKING:
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest


TokenizerType = Any


@@ -384,7 +391,7 @@ def deploy(
try:
logging.info("REST service will be started.")
uvicorn.run(
'nemo.deploy.service.rest_model_api:app',
"nemo.deploy.service.rest_model_api:app",
host=rest_service_http_address,
port=rest_service_port,
reload=True,
@@ -425,6 +432,38 @@ def export_ckpt(
return io.export_ckpt(path, target, output_path, overwrite, load_connector)


@run.cli.entrypoint(name="generate", namespace="llm")
def generate(
path: Union[Path, str],
prompts: list[str],
trainer: Optional[nl.Trainer] = None,
params_dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 4,
random_seed: Optional[int] = None,
inference_batch_times_seqlen_threshold: int = 1000,
inference_params: Optional["CommonInferenceParams"] = None,
text_only: bool = False,
) -> list[Union["InferenceRequest", str]]:
from nemo.collections.llm import inference

inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
path=path,
trainer=trainer,
params_dtype=params_dtype,
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
)
results = inference.generate(
model=inference_wrapped_model,
tokenizer=mcore_tokenizer,
prompts=prompts,
max_batch_size=max_batch_size,
random_seed=random_seed,
inference_params=inference_params,
)

return [r.generated_text if text_only else r for r in results]


def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None:
if tokenizer == "data":
_set_with_io(model, "tokenizer", data.tokenizer)
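For illustration, the entrypoint above can be called programmatically once a Trainer with a MegatronStrategy is available. The following is a minimal sketch modeled on the example script added in this PR; the checkpoint path and parallelism settings are placeholders, not values prescribed by the diff:

import torch

import nemo.lightning as nl
from nemo.collections import llm

trainer = nl.Trainer(
    accelerator="gpu",
    devices=1,
    strategy=nl.MegatronStrategy(tensor_model_parallel_size=1),
)

# generate() restores the checkpoint, wraps the MCore model for inference,
# and returns plain strings when text_only=True (InferenceRequest objects otherwise).
outputs = llm.generate(
    path="/path/to/nemo2_checkpoint",  # placeholder path
    prompts=["Hello, how are you?"],
    trainer=trainer,
    params_dtype=torch.bfloat16,
    text_only=True,
)

The same function is also registered as a NeMo-Run CLI entrypoint in the llm namespace via the run.cli.entrypoint decorator shown above.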
3 changes: 3 additions & 0 deletions nemo/collections/llm/inference/__init__.py
@@ -0,0 +1,3 @@
from nemo.collections.llm.inference.base import MCoreTokenizerWrappper, generate, setup_model_and_tokenizer

__all__ = ["MCoreTokenizerWrappper", "setup_model_and_tokenizer", "generate"]
107 changes: 107 additions & 0 deletions nemo/collections/llm/inference/base.py
@@ -0,0 +1,107 @@
from pathlib import Path
from typing import Optional

import pytorch_lightning as pl
import torch
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
)
from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
from pytorch_lightning.trainer.states import TrainerFn

import nemo.lightning as nl
from nemo.lightning import io
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.lightning.pytorch.strategies.utils import RestoreConfig


# We need this wrapper since mcore generate uses tokenizer.detokenize, tokenizer.tokenize to encode and decode prompts
class MCoreTokenizerWrappper:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.eod = tokenizer.eod
self.vocab_size = tokenizer.vocab_size

def detokenize(self, tokens):
return self.tokenizer.ids_to_text(tokens)

def tokenize(self, prompt):
return self.tokenizer.text_to_ids(prompt)


# TODO: Move to lightning Fabric API.
def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule):
assert isinstance(trainer.strategy, MegatronStrategy), "Only MegatronStrategy is supported for trainer.strategy."
assert trainer.strategy.context_parallel_size <= 1, "Context parallelism is not supported for inference."
restore_config = RestoreConfig(
path=path,
load_model_state=True,
load_optim_state=False,
)
trainer.strategy.restore_config = restore_config
trainer.ckpt_path = None
trainer.strategy.connect(model)
if trainer.strategy.launcher is not None:
trainer.strategy.launcher.launch(lambda: None, trainer=trainer)
trainer.strategy.setup_environment()

if not model.state_dict():
model.configure_model()

trainer.state.fn = TrainerFn.TESTING
trainer.strategy.setup_megatron_parallel(trainer=trainer)
trainer.strategy.trainer = trainer
trainer.strategy.selective_restore()


def setup_model_and_tokenizer(
path: Path,
trainer: Optional[nl.Trainer] = None,
params_dtype: torch.dtype = torch.bfloat16,
inference_batch_times_seqlen_threshold: int = 1000,
) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
model: io.TrainerContext = io.load_context(path=path, subpath="model")
trainer = trainer or io.load_context(path=path, subpath="trainer")
_setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

# This is to get the MCore model required in GPTInferenceWrapper.
mcore_model = model.module.module.module
inference_wrapped_model = GPTInferenceWrapper(
mcore_model,
InferenceWrapperConfig(
hidden_size=mcore_model.config.hidden_size,
params_dtype=params_dtype,
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
padded_vocab_size=model.tokenizer.vocab_size,
),
)

return inference_wrapped_model, MCoreTokenizerWrappper(model.tokenizer)


def generate(
model: GPTInferenceWrapper,
tokenizer: MCoreTokenizerWrappper,
prompts: list[str],
max_batch_size: int = 4,
random_seed: Optional[int] = None,
inference_params: Optional[CommonInferenceParams] = None,
) -> dict:
text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=model, tokenizer=tokenizer)
mcore_engine = MCoreEngine(
text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
)

common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=512)

results = mcore_engine.generate(
prompts=prompts,
common_inference_params=common_inference_params,
)

return results
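For completeness, the lower-level helpers in nemo.collections.llm.inference can also be used directly, mirroring what llm.generate does internally. A hedged sketch, assuming a NeMo 2.0 checkpoint at a placeholder path and a distributed environment set up as in the entrypoint above:

from pathlib import Path

import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams

from nemo.collections.llm import inference

# setup_model_and_tokenizer restores the checkpoint (loading the trainer from the
# checkpoint context if none is passed) and returns a GPTInferenceWrapper plus a
# tokenizer wrapper exposing tokenize/detokenize for MCore.
wrapped_model, tokenizer = inference.setup_model_and_tokenizer(
    path=Path("/path/to/nemo2_checkpoint"),  # placeholder path
    params_dtype=torch.bfloat16,
)

results = inference.generate(
    model=wrapped_model,
    tokenizer=tokenizer,
    prompts=["How many r's are in the word 'strawberry'?"],
    max_batch_size=2,
    inference_params=CommonInferenceParams(temperature=0.1, top_k=10, num_tokens_to_generate=64),
)
for r in results:
    print(r.generated_text)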
66 changes: 66 additions & 0 deletions scripts/llm/llama3_generate.py
@@ -0,0 +1,66 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: This script is just an example of using NeMo checkpoints for generating outputs and is subject to change without notice.

import os

import torch
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams

import nemo.lightning as nl
from nemo.collections.llm import api

if __name__ == "__main__":
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=2,
pipeline_model_parallel_size=1,
context_parallel_size=1,
sequence_parallel=False,
setup_optimizers=False,
store_optimizer_states=False,
)

trainer = nl.Trainer(
accelerator="gpu",
devices=2,
num_nodes=1,
strategy=strategy,
plugins=nl.MegatronMixedPrecision(
precision="bf16-mixed",
params_dtype=torch.bfloat16,
pipeline_dtype=torch.bfloat16,
autocast_enabled=False,
grad_reduce_in_fp32=False,
),
)
prompts = [
"Hello, how are you?",
"How many r's are in the word 'strawberry'?",
"Which number is bigger? 10.119 or 10.19?",
]
results = api.generate(
path=os.path.join(os.environ["NEMO_HOME"], "models", "meta-llama/Meta-Llama-3-8B"),
prompts=prompts,
trainer=trainer,
inference_params=CommonInferenceParams(temperature=0.1, top_k=10, num_tokens_to_generate=512),
text_only=True,
)
if torch.distributed.get_rank() == 0:
for i, r in enumerate(results):
print(prompts[i])
print("*" * 50)
print(r)
print("\n\n")
