Merge branch 'main' into onur/log-props
oyilmaz-nvidia authored Oct 28, 2024
2 parents f42954e + 5d3dadb commit ac2403f
Showing 9 changed files with 224 additions and 54 deletions.
19 changes: 11 additions & 8 deletions examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml
@@ -23,19 +23,22 @@ trainer:
model:
tensor_model_parallel_size: 1 # Pruning currently only supports tensor_model_parallel_size=1
pipeline_model_parallel_size: 1
restore_from_path: llama3.1-8b-base.nemo # Nemo file path
sequence_parallel: false # Sequence parallelism is not supported with pipeline parallelism
restore_from_path: llama3.1-8b-instruct.nemo # Nemo file path

## Activation Checkpoint
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'

prune:
calib_dataset: cnn_dailymail # wikitext, cnn_dailymail, or a local dataset
num_calib_size: 512 # number of samples used for calibration
ffn_hidden_size: 3584 # ffn_hidden_size in the pruned model, ffn_hidden_size // 4
num_attention_heads: 8 # num_attention_heads in the pruned model, num_attention_heads // 4
num_query_groups: 4 # num_query_groups in the pruned model, num_query_groups // 2
hidden_size: 2048 # hidden_size in the pruned model, hidden_size // 2
calib_dataset: wikitext # wikitext, cnn_dailymail, or a local dataset
num_calib_size: 1024 # number of samples used for calibration
# pruning constraints (null means no pruning)
ffn_hidden_size: 9216 # ffn_hidden_size in the pruned model
num_attention_heads: null # num_attention_heads in the pruned model
num_query_groups: null # num_query_groups in the pruned model
hidden_size: 3072 # hidden_size (embedding size) in the pruned model
num_layers: null # num_layers (depth) in the pruned model

export:
save_path: llama3.1-8b-base-pruned.nemo # Path where the pruned model will be saved
save_path: llama3.1-8b-instruct-pruned.nemo # Path where the pruned model will be saved
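
For orientation, a minimal sketch of a depth-only pruning setup under the new schema; the numbers are hypothetical, and any constraint left at null is simply not pruned (it is filtered out before reaching the pruning API, as the script below shows).

```yaml
# Hypothetical depth-only pruning: width knobs stay null, only num_layers is set.
prune:
  calib_dataset: wikitext
  num_calib_size: 1024
  ffn_hidden_size: null       # null means "do not prune this dimension"
  num_attention_heads: null
  num_query_groups: null
  hidden_size: null
  num_layers: 24              # hypothetical target depth
```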
29 changes: 15 additions & 14 deletions examples/nlp/language_modeling/megatron_gpt_prune.py
@@ -36,23 +36,23 @@
Example usage:
```
python examples/nlp/language_modeling/megatron_gpt_prune.py \
model.restore_from_path=llama3.1-8b-base.nemo \
model.restore_from_path=llama3.1-8b-instruct.nemo \
model.tensor_model_parallel_size=1 \
model.pipeline_model_parallel_size=8 \
trainer.num_nodes=1 \
trainer.precision=bf16 \
trainer.devices=8 \
prune.ffn_hidden_size=3584 \
prune.num_attention_heads=8 \
prune.num_query_groups=4 \
prune.hidden_size=2048 \
export.save_path=llama3.1-8b-base-pruned.nemo
prune.ffn_hidden_size=9216 \
prune.num_attention_heads=null \
prune.num_query_groups=null \
prune.hidden_size=3072 \
export.save_path=llama3.1-8b-instruct-pruned.nemo
```
where tensor_model_parallel_size must be 1 because of the current prune API limitation
"""


def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512):
def get_calib_data_iter(data="wikitext", batch_size=64, calib_size=512, max_sequence_length=512):
if data == "wikitext":
dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
text_column = "text"
@@ -73,18 +73,12 @@ def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max

@hydra_runner(config_path="conf", config_name="megatron_gpt_prune")
def main(cfg) -> None:
if not torch.cuda.is_available():
raise EnvironmentError("GPU is required for the pruning.")

# Overwrite model config with the one from the model checkpoint and apply pruning modifications
model_cfg = load_config(cfg.model.restore_from_path)
model_cfg.update(cfg.model)
model_cfg.name = "modelopt" # Use modelopt transformer spec for pruning

assert cfg.model.tensor_model_parallel_size == 1, "Pruning currently only supports tensor_model_parallel_size=1"
assert (
not hasattr(cfg.model, "sequence_parallel") or not cfg.model.sequence_parallel
), "Pruning currently does not support sequence parallelism"

trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
model = MegatronGPTModel.restore_from(
@@ -112,7 +106,13 @@ def forward_loop(model):
constraints={
"export_config": {
k: cfg.prune.get(k)
for k in ["ffn_hidden_size", "num_attention_heads", "num_query_groups", "hidden_size"]
for k in [
"ffn_hidden_size",
"num_attention_heads",
"num_query_groups",
"hidden_size",
"num_layers",
]
if cfg.prune.get(k) is not None
},
},
@@ -121,6 +121,7 @@ def forward_loop(model):
)

model_pruned.save_to(cfg.export.save_path)
print(f"Pruned model saved to {cfg.export.save_path}")


if __name__ == '__main__':
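As a standalone illustration of the constraint filtering added above, the dict comprehension keeps only the non-null pruning knobs, so anything left at null in megatron_gpt_prune.yaml never reaches the export_config constraints. The values below are the ones from the example usage.

```python
# Mirror of the filtering logic in megatron_gpt_prune.py, run on a plain dict.
prune_cfg = {
    "ffn_hidden_size": 9216,
    "num_attention_heads": None,
    "num_query_groups": None,
    "hidden_size": 3072,
    "num_layers": None,
}
export_config = {
    k: prune_cfg.get(k)
    for k in ["ffn_hidden_size", "num_attention_heads", "num_query_groups", "hidden_size", "num_layers"]
    if prune_cfg.get(k) is not None
}
assert export_config == {"ffn_hidden_size": 9216, "hidden_size": 3072}
```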
@@ -224,9 +224,12 @@ def text_to_ids(self, text):
ids = self.tokens_to_ids(tokens)
return ids

def ids_to_text(self, ids):
def ids_to_text(self, ids, remove_special_tokens=True):
tokens = self.ids_to_tokens(ids)
tokens_clean = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
if remove_special_tokens:
tokens_clean = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
else:
tokens_clean = tokens
text = self.tokens_to_text(tokens_clean)
return text

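A short sketch of the new remove_special_tokens flag; this assumes the class above is NeMo's HuggingFace AutoTokenizer wrapper and uses an illustrative model name, so treat the import path and values as an assumption rather than part of the change.

```python
# Assumed import path for the HuggingFace-backed NeMo tokenizer; "gpt2" is illustrative.
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

tokenizer = AutoTokenizer("gpt2")
ids = tokenizer.text_to_ids("hello world") + [tokenizer.eos_id]

print(tokenizer.ids_to_text(ids))                               # default: special tokens stripped
print(tokenizer.ids_to_text(ids, remove_special_tokens=False))  # keeps e.g. <|endoftext|>
```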
4 changes: 4 additions & 0 deletions nemo/collections/llm/api.py
@@ -437,7 +437,9 @@ def generate(
path: Union[Path, str],
prompts: list[str],
trainer: nl.Trainer,
encoder_prompts: Optional[list[str]] = None,
params_dtype: torch.dtype = torch.bfloat16,
add_BOS: bool = False,
max_batch_size: int = 4,
random_seed: Optional[int] = None,
inference_batch_times_seqlen_threshold: int = 1000,
@@ -456,6 +458,8 @@
model=inference_wrapped_model,
tokenizer=mcore_tokenizer,
prompts=prompts,
encoder_prompts=encoder_prompts,
add_BOS=add_BOS,
max_batch_size=max_batch_size,
random_seed=random_seed,
inference_params=inference_params,
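A hedged usage sketch of the extended generate() API; the checkpoint path, prompts, and trainer settings are placeholders, not values from this commit. encoder_prompts is only meaningful for encoder-decoder models, and add_BOS controls whether a BOS token is prepended to each prompt.

```python
# Placeholder checkpoint path and trainer settings; adjust for an actual setup.
import nemo.lightning as nl
from nemo.collections.llm.api import generate

trainer = nl.Trainer(
    devices=1,
    accelerator="gpu",
    strategy=nl.MegatronStrategy(tensor_model_parallel_size=1),
)

results = generate(
    path="/checkpoints/my_model",             # hypothetical NeMo 2.0 checkpoint directory
    prompts=["Q: What does structured pruning remove?\nA:"],
    trainer=trainer,
    encoder_prompts=None,                     # pass a list[str] for encoder-decoder models (e.g. T5)
    add_BOS=False,
    max_batch_size=4,
)
```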
22 changes: 22 additions & 0 deletions nemo/collections/llm/gpt/model/base.py
@@ -18,6 +18,8 @@
import pytorch_lightning as L
import torch
import torch.distributed
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -310,6 +312,26 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:

return self.forward_step(batch)

def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor:
# This is to get the MCore model required in GPTInferenceWrapper.
mcore_model = self.module
while mcore_model:
if type(mcore_model) is MCoreGPTModel:
break
mcore_model = getattr(mcore_model, "module", None)
if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
raise ValueError("Exact McoreGPTModel instance not found in the model structure.")

inference_wrapper_config = InferenceWrapperConfig(
hidden_size=mcore_model.config.hidden_size,
params_dtype=params_dtype,
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
padded_vocab_size=self.tokenizer.vocab_size,
)

model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config)
return model_inference_wrapper

@property
def training_loss_reduction(self) -> MaskedTokenLossReduction:
if not self._training_loss_reduction:
63 changes: 35 additions & 28 deletions nemo/collections/llm/inference/base.py
@@ -21,12 +21,16 @@
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
EncoderDecoderTextGenerationController,
)
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
)
from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
from megatron.core.transformer.module import MegatronModule
from pytorch_lightning.trainer.states import TrainerFn

import nemo.lightning as nl
@@ -37,19 +41,31 @@
from nemo.lightning.pytorch.strategies.utils import RestoreConfig


# We need this wrapper since mcore generate uses tokenizer.detokenize, tokenizer.tokenize to encode and decode prompts
# We need this wrapper since mcore generate uses methods/properties such as tokenizer.detokenize, tokenizer.tokenize, tokenizer.bos, tokenizer.pad, etc. to encode and decode prompts
class MCoreTokenizerWrappper:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.eod = tokenizer.eod
self.vocab_size = tokenizer.vocab_size

def detokenize(self, tokens):
return self.tokenizer.ids_to_text(tokens)
def detokenize(self, tokens, remove_special_tokens=False):
return self.tokenizer.ids_to_text(tokens, remove_special_tokens)

def tokenize(self, prompt):
return self.tokenizer.text_to_ids(prompt)

@property
def additional_special_tokens_ids(self):
return self.tokenizer.additional_special_tokens_ids

@property
def bos(self):
return self.tokenizer.bos_id

@property
def pad(self):
return self.tokenizer.pad_id


# TODO: Move to lightning Fabric API.
def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule):
@@ -101,41 +117,30 @@ def setup_model_and_tokenizer(
trainer: nl.Trainer,
params_dtype: torch.dtype = torch.bfloat16,
inference_batch_times_seqlen_threshold: int = 1000,
) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
) -> tuple[MegatronModule, MCoreTokenizerWrappper]:
model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
_setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

# This is to get the MCore model required in GPTInferenceWrapper.
mcore_model = model
while mcore_model:
if type(mcore_model) is MCoreGPTModel:
break
mcore_model = getattr(mcore_model, "module", None)
if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
raise ValueError("Exact McoreGPTModel instance not found in the model structure.")

inference_wrapped_model = GPTInferenceWrapper(
mcore_model,
InferenceWrapperConfig(
hidden_size=mcore_model.config.hidden_size,
params_dtype=params_dtype,
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
padded_vocab_size=model.tokenizer.vocab_size,
),
)

inference_wrapped_model = model.get_inference_wrapper(params_dtype, inference_batch_times_seqlen_threshold)
return inference_wrapped_model, MCoreTokenizerWrappper(model.tokenizer)


def generate(
model: GPTInferenceWrapper,
model: AbstractModelInferenceWrapper,
tokenizer: MCoreTokenizerWrappper,
prompts: list[str],
encoder_prompts: Optional[list[str]] = None,
add_BOS: bool = False,
max_batch_size: int = 4,
random_seed: Optional[int] = None,
inference_params: Optional[CommonInferenceParams] = None,
) -> dict:
text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=model, tokenizer=tokenizer)
if encoder_prompts is not None:
text_generation_controller = EncoderDecoderTextGenerationController(
inference_wrapped_model=model, tokenizer=tokenizer
)
else:
text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=model, tokenizer=tokenizer)
mcore_engine = MCoreEngine(
text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
)
@@ -144,6 +149,8 @@ def generate(

results = mcore_engine.generate(
prompts=prompts,
add_BOS=add_BOS,
encoder_prompts=encoder_prompts,
common_inference_params=common_inference_params,
)

22 changes: 22 additions & 0 deletions nemo/collections/llm/t5/model/t5.py
@@ -19,6 +19,8 @@
import pytorch_lightning as L
import torch
import torch.distributed
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -258,6 +260,26 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:

return self.forward_step(batch)

def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor:
# This is to get the MCore model required in T5InferenceWrapper.
mcore_model = self.module
while mcore_model:
if type(mcore_model) is MCoreT5Model:
break
mcore_model = getattr(mcore_model, "module", None)
if mcore_model is None or type(mcore_model) is not MCoreT5Model:
raise ValueError("Exact MCoreT5Model instance not found in the model structure.")

inference_wrapper_config = InferenceWrapperConfig(
hidden_size=mcore_model.config.hidden_size,
params_dtype=params_dtype,
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
padded_vocab_size=self.tokenizer.vocab_size,
)

model_inference_wrapper = T5InferenceWrapper(mcore_model, inference_wrapper_config)
return model_inference_wrapper

@property
def training_loss_reduction(self) -> MaskedTokenLossReduction:
if not self._training_loss_reduction:
4 changes: 2 additions & 2 deletions nemo/collections/multimodal/data/energon/config.py
@@ -64,7 +64,7 @@ class LLaVATemplateConfig(BaseConversationTemplateConfig):

@dataclass
class MultiModalSampleConfig:
image_token: ImageToken = ImageToken()
image_token: ImageToken = field(default_factory=ImageToken)
ignore_place_holder: int = -100
conversation_template_config: LLaVATemplateConfig = LLaVATemplateConfig()
conversation_template_config: LLaVATemplateConfig = field(default_factory=LLaVATemplateConfig)
image_following_text: bool = True
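
The config.py change above replaces mutable class-level defaults with default_factory. A self-contained sketch of why; the ImageToken fields here are illustrative stand-ins, not the real definition.

```python
from dataclasses import dataclass, field


@dataclass
class ImageToken:
    token_str: str = "<image>"   # hypothetical fields, for illustration only
    token_id: int = -200


@dataclass
class MultiModalSampleConfig:
    # default_factory builds a fresh ImageToken per instance; using ImageToken() directly as a
    # default is rejected by dataclasses on newer Python versions (unhashable mutable default)
    # and would be shared across all instances on older ones.
    image_token: ImageToken = field(default_factory=ImageToken)
    ignore_place_holder: int = -100


a, b = MultiModalSampleConfig(), MultiModalSampleConfig()
assert a.image_token is not b.image_token  # each config owns its own default
```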