diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py
index 3222661999..edca3f3850 100644
--- a/litgpt/scripts/convert_hf_checkpoint.py
+++ b/litgpt/scripts/convert_hf_checkpoint.py
@@ -337,10 +337,18 @@ def convert_hf_checkpoint(
 
     # Load the json file containing weight mapping
     pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_safetensor_map_json_path = checkpoint_dir / "model.safetensors.index.json"
     if pytorch_bin_map_json_path.is_file():  # not all checkpoints have this file
         with open(pytorch_bin_map_json_path, encoding="utf-8") as json_map:
             bin_index = json.load(json_map)
         bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    elif model_safetensor_map_json_path.is_file():
+        with open(model_safetensor_map_json_path, encoding="utf-8") as json_map:
+            bin_index = json.load(json_map)
+        bin_files = {
+            checkpoint_dir / Path(bin).with_suffix(".bin")
+            for bin in bin_index["weight_map"].values()
+        }
     else:
         bin_files = set(checkpoint_dir.glob("*.bin"))
         # some checkpoints serialize the training arguments
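
For reference, a minimal sketch of what the new elif branch computes from a safetensors index. The weight_map entries and checkpoint_dir below are illustrative, not taken from a real checkpoint; the assumption behind the .with_suffix(".bin") rewrite is that the shards named in model.safetensors.index.json are present on disk as .bin files by the time the converter runs.

from pathlib import Path

# Illustrative "weight_map" as it would appear in a hypothetical
# model.safetensors.index.json for a two-shard checkpoint.
weight_map = {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}

checkpoint_dir = Path("checkpoints/some-org/some-model")  # hypothetical path

# Same rewrite as in the diff: keep the shard name, swap .safetensors for .bin.
bin_files = {
    checkpoint_dir / Path(shard).with_suffix(".bin")
    for shard in weight_map.values()
}

print(sorted(bin_files))
# [PosixPath('checkpoints/some-org/some-model/model-00001-of-00002.bin'),
#  PosixPath('checkpoints/some-org/some-model/model-00002-of-00002.bin')]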