From 91a5a5bb09848939bb2be27c29304c3131d340ee Mon Sep 17 00:00:00 2001
From: rasbt
Date: Wed, 12 Jun 2024 13:54:54 -0500
Subject: [PATCH] apply changes

---
 litgpt/scripts/convert_hf_checkpoint.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py
index 3222661999..edca3f3850 100644
--- a/litgpt/scripts/convert_hf_checkpoint.py
+++ b/litgpt/scripts/convert_hf_checkpoint.py
@@ -337,10 +337,18 @@ def convert_hf_checkpoint(
 
     # Load the json file containing weight mapping
     pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_safetensor_map_json_path = checkpoint_dir / "model.safetensors.index.json"
     if pytorch_bin_map_json_path.is_file():  # not all checkpoints have this file
         with open(pytorch_bin_map_json_path, encoding="utf-8") as json_map:
             bin_index = json.load(json_map)
         bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    elif model_safetensor_map_json_path.is_file():
+        with open(model_safetensor_map_json_path, encoding="utf-8") as json_map:
+            bin_index = json.load(json_map)
+        bin_files = {
+            checkpoint_dir / Path(bin).with_suffix(".bin")
+            for bin in bin_index["weight_map"].values()
+        }
     else:
         bin_files = set(checkpoint_dir.glob("*.bin"))
     # some checkpoints serialize the training arguments