From 66646b83737d9a0facb1d8d714e0424fc86ec21a Mon Sep 17 00:00:00 2001 From: cuichenx Date: Mon, 1 Jul 2024 21:32:33 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: cuichenx --- .../convert_gpt_nemo_to_mcore.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py b/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py index d60da66f2c77..1f8c69b5b240 100644 --- a/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py +++ b/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py @@ -128,9 +128,9 @@ def build_key_mapping(nemo_cfg): f"{model_str}.decoder.final_layernorm.weight": "model.language_model.encoder.final_layernorm.weight", } if has_layernorm_bias: - mcore_to_nemo_mapping[ - f"{model_str}.decoder.final_layernorm.bias" - ] = "model.language_model.encoder.final_layernorm.bias" + mcore_to_nemo_mapping[f"{model_str}.decoder.final_layernorm.bias"] = ( + "model.language_model.encoder.final_layernorm.bias" + ) if not nemo_cfg.get("share_embeddings_and_output_weights", True): mcore_to_nemo_mapping[f"{model_str}.output_layer.weight"] = "model.language_model.output_layer.weight" @@ -138,9 +138,9 @@ def build_key_mapping(nemo_cfg): if nemo_cfg.get("position_embedding_type", 'learned_absolute') == 'rope': mcore_to_nemo_mapping[f"{model_str}.rotary_pos_emb.inv_freq"] = "model.language_model.rotary_pos_emb.inv_freq" else: - mcore_to_nemo_mapping[ - f"{model_str}.embedding.position_embeddings.weight" - ] = "model.language_model.embedding.position_embeddings.weight" + mcore_to_nemo_mapping[f"{model_str}.embedding.position_embeddings.weight"] = ( + "model.language_model.embedding.position_embeddings.weight" + ) nemo_prefix = "model.language_model.encoder.layers" mcore_prefix = f"{model_str}.decoder.layers" @@ -338,5 +338,7 @@ def run_sanity_checks(nemo_file, mcore_file, cpu_only=False, ignore_if_missing=t try: run_sanity_checks(input_nemo_file, output_nemo_file, cpu_only=cpu_only, ignore_if_missing=ignore_if_missing) except torch.cuda.OutOfMemoryError: - logging.info("✅ Conversion was successful, but could not run sanity check due to torch.cuda.OutOfMemoryError.") + logging.info( + "✅ Conversion was successful, but could not run sanity check due to torch.cuda.OutOfMemoryError." + ) logging.info("Please run the script with the same command again to run sanity check.")