Remove config_path from convert_lit_checkpoint.py (#1056)
Co-authored-by: awaelchli <[email protected]>
rasbt and awaelchli authored Mar 8, 2024
1 parent 5e5a27a commit 3aa7beb
Showing 3 changed files with 22 additions and 23 deletions.
litgpt/scripts/convert_lit_checkpoint.py (9 changes: 5 additions & 4 deletions)
@@ -239,10 +239,11 @@ def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None:


@torch.inference_mode()
-def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path: Path) -> None:
-    config = Config.from_json(config_path)
+def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
+    config = Config.from_json(checkpoint_dir / "lit_config.json")

-    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    output_path = output_dir / "model.pth"

    if "falcon" in config.name:
        copy_fn = partial(copy_weights_falcon, config.name)

@@ -257,7 +258,7 @@ def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path
    # initialize a new empty state dict to hold our new weights
    sd = {}
    with incremental_save(output_path) as saver:
-        lit_weights = lazy_load(checkpoint_path)
+        lit_weights = lazy_load(checkpoint_dir / "lit_model.pth")
        lit_weights = lit_weights.get("model", lit_weights)
        check_conversion_supported(lit_weights)
        copy_fn(sd, lit_weights, saver=saver)

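A minimal usage sketch of the new call signature, assuming an installed `litgpt` checkout and a checkpoint directory that already contains `lit_model.pth` and `lit_config.json`; the `checkpoints/repo_id` and `out/converted` paths below are placeholders:

```python
from pathlib import Path

from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint

# Placeholder paths: any LitGPT checkpoint directory that holds
# lit_model.pth and lit_config.json should work here.
checkpoint_dir = Path("checkpoints/repo_id")
output_dir = Path("out/converted")

# The config is now read from checkpoint_dir/lit_config.json instead of a
# separate --config_path; the converted weights land in output_dir/model.pth.
convert_lit_checkpoint(checkpoint_dir, output_dir)

assert (output_dir / "model.pth").is_file()
```
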
tests/test_convert_lit_checkpoint.py (15 changes: 8 additions & 7 deletions)
@@ -20,20 +20,21 @@ def test_convert_lit_checkpoint(tmp_path):

    ours_config = Config.from_name("Llama-2-7b-hf", block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128)
    ours_model = GPT(ours_config)
-    checkpoint_path = tmp_path / "foo.ckpt"
-    config_path = tmp_path / "foo.json"
+    checkpoint_path = tmp_path / "lit_model.pth"
+    config_path = tmp_path / "lit_config.json"
    torch.save(ours_model.state_dict(), checkpoint_path)
    with open(config_path, "w") as fp:
        json.dump(asdict(ours_config), fp)
-    output_path = tmp_path / "generated.bin"
+    output_dir = tmp_path / "out_dir"

-    convert_lit_checkpoint(checkpoint_path, output_path, config_path)
-    assert set(os.listdir(tmp_path)) == {"foo.ckpt", "foo.json", "generated.bin"}
+    convert_lit_checkpoint(checkpoint_path.parent, output_dir)
+    assert set(os.listdir(tmp_path)) == {"lit_model.pth", "lit_config.json", "out_dir"}
+    assert os.path.isfile(output_dir / "model.pth")

    # check checkpoint is unwrapped
    torch.save({"model": ours_model.state_dict()}, checkpoint_path)
-    convert_lit_checkpoint(checkpoint_path, output_path, config_path)
-    converted_sd = torch.load(output_path)
+    convert_lit_checkpoint(checkpoint_path.parent, output_dir)
+    converted_sd = torch.load(output_dir / "model.pth")
    assert "model" not in converted_sd

tutorials/convert_lit_models.md (21 changes: 9 additions & 12 deletions)
@@ -6,21 +6,19 @@ We provide a helpful script to convert models LitGPT models back to their equiva

```sh
python litgpt/scripts/convert_lit_checkpoint.py \
-    --checkpoint_path checkpoints/repo_id/lit_model.pth \
-    --output_path output_path/converted.pth \
-    --config_path checkpoints/repo_id/config.json
+    --checkpoint_dir checkpoint_dir \
+    --output_dir converted_dir
```

-These paths are just placeholders, you will need to customize them based on which finetuning or pretraining script you ran and it's configuration.
+These paths are just placeholders, you will need to customize them based on which finetuning or pretraining script you ran and its configuration.

### Loading converted LitGPT checkpoints into transformers

If you want to load the converted checkpoints into a `transformers` model, please make sure you copied the original `config.json` file into the folder that contains the `converted.pth` file saved via `--output_path` above.

For example,

```bash
-cp checkpoints/repo_id/config.json output_path/config.json
+cp checkpoints/repo_id/config.json converted/config.json
```

Then, you can load the checkpoint file in a Python session as follows:
@@ -30,9 +28,9 @@
import torch
from transformers import AutoModel


-state_dict = torch.load("output_path/converted.pth")
+state_dict = torch.load("output_dir/model.pth")
model = AutoModel.from_pretrained(
"output_path/", local_files_only=True, state_dict=state_dict
"output_dir/", local_files_only=True, state_dict=state_dict
)
```

@@ -105,9 +103,8 @@ python scripts/merge_lora.py \

```bash
python litgpt/scripts/convert_lit_checkpoint.py \
-    --checkpoint_path $finetuned_dir/final/lit_model.pth \
-    --output_path out/hf-tinyllama/converted_model.pth \
-    --config_path checkpoints/$repo_id/lit_config.json
+    --checkpoint_dir $finetuned_dir/final/ \
+    --output_dir out/hf-tinyllama/converted \
```


@@ -117,6 +114,6 @@ python litgpt/scripts/convert_lit_checkpoint.py \
import torch
from transformers import AutoModel

-state_dict = torch.load('out/hf-tinyllama/converted_model.pth')
+state_dict = torch.load('out/hf-tinyllama/converted/model.pth')
model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", state_dict=state_dict)
```
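
Taken together, the steps above amount to something like the following sketch; `finetuned_dir/final`, the `converted` output directory, and the copied `config.json` location are placeholders to adapt to your own run, and the import path is inferred from the script location:

```python
import shutil
from pathlib import Path

import torch
from transformers import AutoModel

from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint

finetuned_dir = Path("finetuned_dir/final")         # placeholder: finetuned LitGPT checkpoint directory
converted_dir = Path("out/hf-tinyllama/converted")  # placeholder: where model.pth will be written

# 1) Convert the LitGPT checkpoint; this writes converted_dir/model.pth.
convert_lit_checkpoint(finetuned_dir, converted_dir)

# 2) Copy the original HF config next to the converted weights so that
#    transformers can resolve the architecture from the local folder.
shutil.copy("checkpoints/repo_id/config.json", converted_dir / "config.json")

# 3) Load the converted state dict into a transformers model.
state_dict = torch.load(converted_dir / "model.pth")
model = AutoModel.from_pretrained(str(converted_dir), local_files_only=True, state_dict=state_dict)
```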
