change to default bf16, add rc2 patch and logging
Signed-off-by: HuiyingLi <[email protected]>
HuiyingLi committed Oct 30, 2024
Parent: dc7a173 · Commit: 746611b
Showing 1 changed file with 12 additions and 6 deletions.
scripts/checkpoint_converters/convert_nemo1_to_nemo2.py: 12 additions & 6 deletions
@@ -18,6 +18,7 @@
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.io.pl import TrainerContext
from nemo.utils import logging
+import torch

"""
Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format.
@@ -30,7 +31,7 @@
--model_id=meta-llama/Meta-Llama-3-8B
b. Convert a .nemo checkpoint
-torchrun --nproc_per_node=4 /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \
+python /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \
--input_path=Mixtral-8x7B.nemo \
--output_path=your_output_dir \
--model_id=mistralai/Mixtral-8x7B-v0.1 \
@@ -47,7 +48,7 @@
"""

def get_args():
parser = ArgumentParser(description="Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format.")
parser = ArgumentParser(description="Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format. This script may download from Hugging Face, make sure you have access to gate repo and have logged into Hugging Face (e.g. huggingface-cli login)")
parser.add_argument(
"--input_path",
type=str,
@@ -73,9 +74,8 @@ def get_nemo2_model(model_id, tokenizer) -> llm.GPTModel:
if model_id not in model_config_mapping:
raise ValueError(f"Unsupported model_id: '{model_id}'. Please provide a valid model_id from {list(model_config_mapping.keys())}.")
model_cls, config_cls = model_config_mapping[model_id]
-if config_cls is llm.Nemotron4Config340B:
-config_cls.bf16=True
-return model_cls(config_cls(), tokenizer=tokenizer)
+# nemo1 ckpts are bf16
+return model_cls(config_cls(bf16=True, params_dtype=torch.bfloat16), tokenizer=tokenizer)


def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer:
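
The hunk above drops the Nemotron-only override and instead forces bf16 defaults for every converted model, since NeMo 1.0 checkpoints are stored in bf16. A minimal standalone sketch of that pattern follows; the single-entry model_config_mapping and the Llama classes are illustrative assumptions, while the real mapping is defined elsewhere in convert_nemo1_to_nemo2.py.

    import torch
    from nemo.collections import llm

    # Assumed, illustrative entry; the script's actual model_config_mapping
    # covers more model ids and lives elsewhere in this file.
    model_config_mapping = {
        "meta-llama/Meta-Llama-3-8B": (llm.LlamaModel, llm.Llama3Config8B),
    }

    def get_nemo2_model_sketch(model_id, tokenizer):
        model_cls, config_cls = model_config_mapping[model_id]
        # NeMo 1.0 checkpoints are bf16, so the dtype is forced for every
        # config class rather than only for Nemotron4Config340B.
        config = config_cls(bf16=True, params_dtype=torch.bfloat16)
        return model_cls(config, tokenizer=tokenizer)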
@@ -149,18 +149,24 @@ def skip_fp8_load(x):
else:
model_ckpt = trainer.strategy.checkpoint_io.load_checkpoint(args.input_path, sharded_state_dict, None)

logging.info(f"Saving checkpoint to {args.output_path}")
model_ckpt['state_dict'] = {k.replace('model', 'module', 1): v for k, v in model_ckpt['state_dict'].items()}
trainer.model.module.load_state_dict(model_ckpt['state_dict'])
trainer.save_checkpoint(ckpt_to_weights_subdir(args.output_path))
+if getattr(trainer.strategy, "async_save", False):
+trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)

#Corresponding to Connector: on_import_ckpt
if hasattr(trainer.model, "__io__") and hasattr(trainer.model.tokenizer, '__io__'):
trainer.model.__io__.tokenizer = trainer.model.tokenizer.__io__
-TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(args.output_path), yaml_attrs=["model"])
+yaml_attrs = ["model"] if "nemotron" not in args.model_id.lower() else [] #Currently, doesn't support producing nemotron's model.yaml
+TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(args.output_path), yaml_attrs=yaml_attrs)

#remove tmp dir
if os.path.isdir(tokenizer_tmp_dir):
shutil.rmtree(tokenizer_tmp_dir)

logging.info(f"NeMo 2.0 checkpoint saved at {args.output_path}")

if __name__ == '__main__':
args = get_args()
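For reference, after the run completes, the sharded weights (written by trainer.save_checkpoint) and the serialized TrainerContext (written by io_dump) sit in separate subdirectories under --output_path. A minimal sketch for checking that both were produced, using the same helpers the script imports; the output path is a placeholder:

    from pathlib import Path

    from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir

    output_path = Path("your_output_dir")  # placeholder, same value passed as --output_path
    # The converter writes weights and context side by side under the output path.
    for subdir in (ckpt_to_weights_subdir(output_path), ckpt_to_context_subdir(output_path)):
        print(subdir, "exists:", Path(subdir).is_dir())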
