diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py
index d1ceb462a2a8..196e89fdb33e 100644
--- a/nemo/collections/llm/inference/base.py
+++ b/nemo/collections/llm/inference/base.py
@@ -13,10 +13,10 @@
 )
 from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
 from pytorch_lightning.trainer.states import TrainerFn
-from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 
 import nemo.lightning as nl
 from nemo.lightning import io
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
 
@@ -71,10 +71,10 @@ def setup_model_and_tokenizer(
     model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
     trainer = trainer or io.load_context(path=ckpt_to_context_subdir(path), subpath="trainer")
     if tensor_parallel_size > 0:
-        trainer.strategy.tensor_model_parallel_size=tensor_parallel_size
-        trainer.devices=tensor_parallel_size
-        trainer.strategy.parallel_devices=[torch.device(f"cuda:{i}") for i in range(tensor_parallel_size)]
-        trainer.strategy.launcher.num_processes=len(trainer.strategy.parallel_devices)
+        trainer.strategy.tensor_model_parallel_size = tensor_parallel_size
+        trainer.devices = tensor_parallel_size
+        trainer.strategy.parallel_devices = [torch.device(f"cuda:{i}") for i in range(tensor_parallel_size)]
+        trainer.strategy.launcher.num_processes = len(trainer.strategy.parallel_devices)
     _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)
 
     inference_wrapped_model = GPTInferenceWrapper(
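
For context, a minimal usage sketch of what the second hunk enables: forcing the tensor-parallel layout at inference setup instead of inheriting it from the checkpoint. The checkpoint path is hypothetical, and the two-element return value is an assumption based on the function body shown above, not a confirmed signature; passing `trainer=None` is inferred from the `trainer = trainer or io.load_context(...)` fallback in the hunk.

```python
# Hypothetical invocation of setup_model_and_tokenizer with the TP override.
from nemo.collections.llm.inference.base import setup_model_and_tokenizer

model, tokenizer = setup_model_and_tokenizer(
    path="/checkpoints/gpt-ckpt",  # hypothetical checkpoint directory
    trainer=None,                  # None: trainer is loaded from the checkpoint context
    tensor_parallel_size=2,        # >0: overrides TP size, device list, and launcher process count
)
```

Since the new code guards on `tensor_parallel_size > 0`, passing a non-positive value skips the override entirely and keeps whatever parallel configuration the restored trainer carries.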