change to cpu convert
Signed-off-by: HuiyingLi <[email protected]>
HuiyingLi committed Oct 25, 2024
1 parent 4ea3505 commit dc7a173
Showing 1 changed file with 29 additions and 34 deletions.
63 changes: 29 additions & 34 deletions scripts/checkpoint_converters/convert_nemo1_to_nemo2.py
@@ -9,32 +9,31 @@
from omegaconf import OmegaConf
from transformers import AutoTokenizer as HFAutoTokenizer

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.io.pl import TrainerContext
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

"""
Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format.
Example usage:
a. Convert a .nemo checkpoint in tp1
a. Convert a .nemo checkpoint
python /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \
--input_path=Meta-Llama-3-8B.nemo \
--output_path=your_output_dir \
--model_id=meta-llama/Meta-Llama-3-8B
b. Convert a .nemo checkpoint in tp4
b. Convert a .nemo checkpoint
torchrun --nproc_per_node=4 /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \
--input_path=Mixtral-8x7B.nemo \
--output_path=your_output_dir \
--model_id=mistralai/Mixtral-8x7B-v0.1 \
--tp_size=4
c. Convert a model weight directory. The checkpoint should be similar to the `model_weights` subdir obtained after extracting the .nemo file.
Please also provide tokenizer_library and tokenizer_path when loading from a weight directory (see the sketch below).
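For case (c), a hedged sketch of the command, combining the flags documented in get_args() below; the weight-directory path, model_id, and tokenizer paths here are placeholders, not the file's own example:
python /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \
    --input_path=your_model_weights_dir \
    --output_path=your_output_dir \
    --model_id=meta-llama/Meta-Llama-3-8B \
    --tokenizer_library=sentencepiece \
    --tokenizer_path=path/to/tokenizer.model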
@@ -60,7 +59,6 @@ def get_args():
parser.add_argument("--model_id", type=str, default=None, required=True, help="Hugging Face model id for the model")
parser.add_argument("--tokenizer_path", type=str, default=None, required=False, help="Path to tokenizer. If not provided, will 1. try instantiate from nemo1 config 2. pull AutoTokenizer from Hugging Face according to model_id if 1 fails")
parser.add_argument("--tokenizer_library", type=str, default=None, required=False, help="Tokenizer library, e.g. `sentencepiece`, `megatron`. Defaults to `sentencepiece`")
parser.add_argument("--tp_size", type=int, default=1, required=False, help="TP size for loading the base model, increase if OOM")
args = parser.parse_args()
return args

@@ -69,11 +67,14 @@ def get_nemo2_model(model_id, tokenizer) -> llm.GPTModel:
model_config_mapping = {
"meta-llama/Meta-Llama-3-8B": (llm.LlamaModel , llm.Llama3Config8B),
"mistralai/Mixtral-8x7B-v0.1": (llm.MixtralModel, llm.MixtralConfig8x7B),
"nvidia/nemotron-3-8b-base-4k": (llm.NemotronModel, llm.Nemotron3Config8B)
"nvidia/nemotron-3-8b-base-4k": (llm.NemotronModel, llm.Nemotron3Config8B),
"nemotron4-340b": (llm.NemotronModel, llm.Nemotron4Config340B),
}
if model_id not in model_config_mapping:
raise ValueError(f"Unsupported model_id: '{model_id}'. Please provide a valid model_id from {list(model_config_mapping.keys())}.")
model_cls, config_cls = model_config_mapping[model_id]
if config_cls is llm.Nemotron4Config340B:
config_cls.bf16=True
return model_cls(config_cls(), tokenizer=tokenizer)
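As a minimal usage sketch of the mapping changed above, assuming a tokenizer object already produced by get_tokenizer() (the call below is illustrative, not part of the commit):
# Hypothetical illustration: resolve a supported model_id into a NeMo 2.0 model.
# `tokenizer` stands in for the AutoTokenizer returned by get_tokenizer().
model = get_nemo2_model("nvidia/nemotron-3-8b-base-4k", tokenizer=tokenizer)
# An unsupported model_id raises ValueError listing the supported keys.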


@@ -84,11 +85,10 @@ def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer:
cfg = OmegaConf.load(f"{tmp_dir}/model_config.yaml")
tokenizer_lib = cfg.tokenizer.library
tokenizer_model = cfg.tokenizer.get("model") and cfg.tokenizer.get("model").split("nemo:", 1)[-1]
if is_global_rank_zero():
if tokenizer_model:
shutil.copy(f"{tmp_dir}/{tokenizer_model}", f"{tokenizer_tmp_dir}/{tokenizer_model}")
elif cfg.tokenizer.library=="huggingface":
HFAutoTokenizer.from_pretrained(cfg.tokenizer.type).save_pretrained(tokenizer_tmp_dir)
if tokenizer_model:
shutil.copy(f"{tmp_dir}/{tokenizer_model}", f"{tokenizer_tmp_dir}/{tokenizer_model}")
elif cfg.tokenizer.library=="huggingface":
HFAutoTokenizer.from_pretrained(cfg.tokenizer.type).save_pretrained(tokenizer_tmp_dir)
tokenizer_model = f"{tokenizer_tmp_dir}/{tokenizer_model}" if tokenizer_model else None
else:
if args.tokenizer_path:  # not a .nemo file; only a weight dir needs an explicit tokenizer lib and path
@@ -107,33 +107,29 @@ def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer:


def main() -> None:
tokenizer_tmp_dir = "/tmp/nemo_tokenizer"
tokenizer = get_tokenizer(Path(args.input_path), Path(tokenizer_tmp_dir))
tokenizer_tmp_dir = Path("/tmp/nemo_tokenizer")
tokenizer_tmp_dir.mkdir(parents=True, exist_ok=True)
tokenizer = get_tokenizer(Path(args.input_path), tokenizer_tmp_dir)
model = get_nemo2_model(args.model_id, tokenizer=tokenizer)
model.optim = None

strategy = nl.MegatronStrategy(
tensor_model_parallel_size=args.tp_size,
setup_optimizers=False,
init_model_parallel=False
trainer = Trainer(
devices=1,
accelerator="cpu",
strategy=MegatronStrategy(ddp="pytorch", setup_optimizers=False, plugins=bf16_mixed())
)

trainer = nl.Trainer(
devices=args.tp_size,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)
trainer.strategy.connect(model)
trainer.strategy.setup_environment()
trainer.strategy._setup_optimizers = False
trainer.strategy._init_model_parallel = False
trainer.strategy.setup(trainer)
trainer.model.configure_model()
if not model.state_dict():
with _strategy_lib.megatron_cpu_init_context(model.config):
model.configure_model()

trainer.strategy.setup(trainer)

logging.info(f"loading checkpoint {args.input_path}")

sharded_state_dict = {"state_dict": trainer.model.sharded_state_dict()}
sharded_state_dict = {"state_dict": trainer.strategy.megatron_parallel.sharded_state_dict()}

for key in list(sharded_state_dict['state_dict'].keys()):
new_key = key.replace('module', 'model', 1)
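The loop above rewrites the leading 'module' prefix of each sharded-state-dict key to 'model'. The visible hunk ends mid-loop, so the following is only a sketch of how such a rename is commonly completed, not necessarily the file's exact code:
# Sketch (assumed continuation): re-key each entry under its renamed key.
for key in list(sharded_state_dict['state_dict'].keys()):
    new_key = key.replace('module', 'model', 1)
    sharded_state_dict['state_dict'][new_key] = sharded_state_dict['state_dict'].pop(key)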
@@ -157,11 +153,10 @@ def skip_fp8_load(x):
trainer.model.module.load_state_dict(model_ckpt['state_dict'])
trainer.save_checkpoint(ckpt_to_weights_subdir(args.output_path))

if is_global_rank_zero():
#Corresponding to Connector: on_import_ckpt
if hasattr(trainer.model, "__io__") and hasattr(trainer.model.tokenizer, '__io__'):
trainer.model.__io__.tokenizer = trainer.model.tokenizer.__io__
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(args.output_path), yaml_attrs=["model"])
#Corresponding to Connector: on_import_ckpt
if hasattr(trainer.model, "__io__") and hasattr(trainer.model.tokenizer, '__io__'):
trainer.model.__io__.tokenizer = trainer.model.tokenizer.__io__
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(args.output_path), yaml_attrs=["model"])

#remove tmp dir
if os.path.isdir(tokenizer_tmp_dir):
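The cleanup block is truncated in this view; removing a temporary directory is typically done with shutil.rmtree, so as a hedged sketch (not necessarily the file's exact lines):
# Sketch: delete the temporary tokenizer directory created by get_tokenizer().
if os.path.isdir(tokenizer_tmp_dir):
    shutil.rmtree(tokenizer_tmp_dir)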
