diff --git a/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_llama.py b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_llama.py
index 7df0e13..eb1c1cd 100644
--- a/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_llama.py
+++ b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_llama.py
@@ -47,8 +47,7 @@ def convert_checkpoint(p):
         "self_attention.dense.weight": (1, "self_attn.o_proj.weight", 1, 0),
         "post_attention_layernorm.weight": (0, "post_attention_layernorm.weight", None, 0),
         "self_attention.core_attention.rotary_emb.inv_freq": (0, "self_attn.rotary_emb.inv_freq", None, 0),
-        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj.weight", 0, 0),
-        "mlp.dense_h_to_4h_2.weight": (1, "mlp.up_proj.weight", 0, 0),
+        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj_up_proj.weight", 0, 0),
         "mlp.dense_4h_to_h.weight": (1, "mlp.down_proj.weight", 1, 0),
         "model.language_model.encoder.final_layernorm.weight": (0, "model.norm.weight", None, 0),
         "model.language_model.output_layer.weight": (1, "lm_head.weight", 0, 0),
@@ -78,9 +77,15 @@ def convert_checkpoint(p):
         v = model_llama[f'model.layers.{i}.self_attn.v_proj.weight']
         model_llama[f'model.layers.{i}.self_attn.query_key_value.weight'] = torch.cat([q, k, v], dim=0)
 
+        gate_proj = model_llama[f'model.layers.{i}.mlp.gate_proj.weight']
+        up_proj = model_llama[f'model.layers.{i}.mlp.up_proj.weight']
+        model_llama[f'model.layers.{i}.mlp.gate_proj_up_proj.weight'] = torch.cat([gate_proj, up_proj], dim=0)
+
         model_llama.pop(f'model.layers.{i}.self_attn.q_proj.weight')
         model_llama.pop(f'model.layers.{i}.self_attn.k_proj.weight')
         model_llama.pop(f'model.layers.{i}.self_attn.v_proj.weight')
+        model_llama.pop(f'model.layers.{i}.mlp.gate_proj.weight')
+        model_llama.pop(f'model.layers.{i}.mlp.up_proj.weight')
 
     for p in range(PP):
         for i in range(TP):
diff --git a/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_llama_70b.py b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_llama_70b.py
index 02a82cf..d666514 100644
--- a/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_llama_70b.py
+++ b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_llama_70b.py
@@ -83,8 +83,7 @@ def convert_checkpoint(p, args, config):
         "self_attention.dense.weight": (1, "self_attn.o_proj.weight", 1, 0),
         "post_attention_layernorm.weight": (0, "post_attention_layernorm.weight", None, 0),
         "self_attention.core_attention.rotary_emb.inv_freq": (0, "self_attn.rotary_emb.inv_freq", None, 0),
-        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj.weight", 0, 0),
-        "mlp.dense_h_to_4h_2.weight": (1, "mlp.up_proj.weight", 0, 0),
+        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj_up_proj.weight", 0, 0),
         "mlp.dense_4h_to_h.weight": (1, "mlp.down_proj.weight", 1, 0),
         "model.language_model.encoder.final_layernorm.weight": (0, "model.norm.weight", None, 0),
         "model.language_model.output_layer.weight": (1, "lm_head.weight", 0, 0),
@@ -102,7 +101,7 @@ def convert_checkpoint(p, args, config):
     model_llama = {}
     for _path in model_paths:
         print(f'Loading {_path}')
-        ts = torch.load(_path)
+        ts = torch.load(_path, map_location='cpu')
         model_llama.update(ts)
     print(len(model_llama))
 
@@ -117,9 +116,15 @@ def convert_checkpoint(p, args, config):
         model_llama[f'model.layers.{i}.self_attn.query.weight'] = q
         model_llama[f'model.layers.{i}.self_attn.key_value.weight'] = torch.cat([k, v], dim=0)
 
+        gate_proj = model_llama[f'model.layers.{i}.mlp.gate_proj.weight']
+        up_proj = model_llama[f'model.layers.{i}.mlp.up_proj.weight']
+        model_llama[f'model.layers.{i}.mlp.gate_proj_up_proj.weight'] = torch.cat([gate_proj, up_proj], dim=0)
+
         model_llama.pop(f'model.layers.{i}.self_attn.q_proj.weight')
         model_llama.pop(f'model.layers.{i}.self_attn.k_proj.weight')
         model_llama.pop(f'model.layers.{i}.self_attn.v_proj.weight')
+        model_llama.pop(f'model.layers.{i}.mlp.gate_proj.weight')
+        model_llama.pop(f'model.layers.{i}.mlp.up_proj.weight')
 
     for i in range(TP):
 
diff --git a/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_nemo_checkpoint_to_hf_llama.py b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_nemo_checkpoint_to_hf_llama.py
index edb8db0..f40458c 100644
--- a/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_nemo_checkpoint_to_hf_llama.py
+++ b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_nemo_checkpoint_to_hf_llama.py
@@ -30,6 +30,7 @@ def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_head
     param = param.view(*input_shape)
     return param
 
+
 def get_tp_pp_degree(path_to_checkpoints):
 
     dir_name = PurePath(path_to_checkpoints).name
@@ -125,8 +126,7 @@ def convert_checkpoint(config_file,
         "self_attention.dense.weight": (1, "self_attn.o_proj.weight", 1, 0),
         "post_attention_layernorm.weight": (0, "post_attention_layernorm.weight", None, 0),
         "self_attention.core_attention.rotary_emb.inv_freq": (0, "self_attn.rotary_emb.inv_freq", None, 0),
-        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj.weight", 0, 0),
-        "mlp.dense_h_to_4h_2.weight": (1, "mlp.up_proj.weight", 0, 0),
+        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj_up_proj.weight", 0, 0),
         "mlp.dense_4h_to_h.weight": (1, "mlp.down_proj.weight", 1, 0),
         "final_layernorm.weight": (0, "model.norm.weight", None, 0),
         "output_layer.weight": (1, "lm_head.weight", 0, 0), # this is shared
@@ -217,6 +217,13 @@ def convert_checkpoint(config_file,
                 hf_model[hf_key_q], hf_model[hf_key_k], hf_model[hf_key_v] = torch.split(hf_model[hf_key], size_per_seg, dim=0)
                 hf_model.pop(hf_key)
 
+            if "dense_h_to_4h" in k:
+                hf_key_gate_proj = f"{br_key}{ln_idx}.mlp.gate_proj.weight"
+                hf_key_up_proj = f"{br_key}{ln_idx}.mlp.up_proj.weight"
+                size_per_seg = hf_model[hf_key].shape[0] // 2
+                hf_model[hf_key_gate_proj], hf_model[hf_key_up_proj] = torch.split(hf_model[hf_key], size_per_seg, dim=0)
+                hf_model.pop(hf_key)
+
     path = Path(output_path)
     path.mkdir(parents=True, exist_ok=True)
     torch.save(hf_model, str(path)+"/pytorch_model.bin")
diff --git a/nemo/nemo/collections/nlp/parts/serialization.py b/nemo/nemo/collections/nlp/parts/serialization.py
index 665389f..27f85e9 100644
--- a/nemo/nemo/collections/nlp/parts/serialization.py
+++ b/nemo/nemo/collections/nlp/parts/serialization.py
@@ -102,7 +102,7 @@ def save(data, path, should_save=True, partition=0, num_partitions=1, saver=None
             Default: False
     """
     if saver is None:
-        saver = SimplerSaver()
+        saver = SimpleSaver()
     ref_data = _rewrite_data(_get_tensors_folder(path), data, should_save, partition, num_partitions, saver)
     if should_save:
         saver.add_save_task(ref_data, path)