Skip to content

Commit

Permalink
apply changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Jun 12, 2024
1 parent 8f65463 commit 91a5a5b
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions litgpt/scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,18 @@ def convert_hf_checkpoint(

# Load the json file containing weight mapping
pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
model_safetensor_map_json_path = checkpoint_dir / "model.safetensors.index.json"
if pytorch_bin_map_json_path.is_file(): # not all checkpoints have this file
with open(pytorch_bin_map_json_path, encoding="utf-8") as json_map:
bin_index = json.load(json_map)
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
elif model_safetensor_map_json_path.is_file():
with open(model_safetensor_map_json_path, encoding="utf-8") as json_map:
bin_index = json.load(json_map)
bin_files = {
checkpoint_dir / Path(bin).with_suffix(".bin")
for bin in bin_index["weight_map"].values()
}
else:
bin_files = set(checkpoint_dir.glob("*.bin"))
# some checkpoints serialize the training arguments
Expand Down

0 comments on commit 91a5a5b

Please sign in to comment.