Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
Signed-off-by: HuiyingLi <[email protected]>
  • Loading branch information
HuiyingLi committed Oct 22, 2024
1 parent 851686f commit 302008e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions nemo/collections/llm/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
)
from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
from pytorch_lightning.trainer.states import TrainerFn
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir

import nemo.lightning as nl
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.lightning.pytorch.strategies.utils import RestoreConfig

Expand Down Expand Up @@ -71,10 +71,10 @@ def setup_model_and_tokenizer(
model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
trainer = trainer or io.load_context(path=ckpt_to_context_subdir(path), subpath="trainer")
if tensor_parallel_size > 0:
trainer.strategy.tensor_model_parallel_size=tensor_parallel_size
trainer.devices=tensor_parallel_size
trainer.strategy.parallel_devices=[torch.device(f"cuda:{i}") for i in range(tensor_parallel_size)]
trainer.strategy.launcher.num_processes=len(trainer.strategy.parallel_devices)
trainer.strategy.tensor_model_parallel_size = tensor_parallel_size
trainer.devices = tensor_parallel_size
trainer.strategy.parallel_devices = [torch.device(f"cuda:{i}") for i in range(tensor_parallel_size)]
trainer.strategy.launcher.num_processes = len(trainer.strategy.parallel_devices)
_setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

inference_wrapped_model = GPTInferenceWrapper(
Expand Down

0 comments on commit 302008e

Please sign in to comment.