diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py
index 67bd61fe89..bd5cfa0a8e 100644
--- a/litgpt/scripts/convert_lit_checkpoint.py
+++ b/litgpt/scripts/convert_lit_checkpoint.py
@@ -239,10 +239,11 @@ def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None:
 
 
 @torch.inference_mode()
-def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path: Path) -> None:
-    config = Config.from_json(config_path)
+def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
+    config = Config.from_json(checkpoint_dir / "lit_config.json")
 
-    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    output_path = output_dir / "model.pth"
 
     if "falcon" in config.name:
         copy_fn = partial(copy_weights_falcon, config.name)
@@ -257,7 +258,7 @@ def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path
     # initialize a new empty state dict to hold our new weights
     sd = {}
     with incremental_save(output_path) as saver:
-        lit_weights = lazy_load(checkpoint_path)
+        lit_weights = lazy_load(checkpoint_dir / "lit_model.pth")
         lit_weights = lit_weights.get("model", lit_weights)
         check_conversion_supported(lit_weights)
         copy_fn(sd, lit_weights, saver=saver)
diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py
index f7995c79f2..8940940f72 100644
--- a/tests/test_convert_lit_checkpoint.py
+++ b/tests/test_convert_lit_checkpoint.py
@@ -20,20 +20,21 @@ def test_convert_lit_checkpoint(tmp_path):
 
     ours_config = Config.from_name("Llama-2-7b-hf", block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128)
     ours_model = GPT(ours_config)
-    checkpoint_path = tmp_path / "foo.ckpt"
-    config_path = tmp_path / "foo.json"
+    checkpoint_path = tmp_path / "lit_model.pth"
+    config_path = tmp_path / "lit_config.json"
     torch.save(ours_model.state_dict(), checkpoint_path)
     with open(config_path, "w") as fp:
         json.dump(asdict(ours_config), fp)
-    output_path = tmp_path / "generated.bin"
+    output_dir = tmp_path / "out_dir"
 
-    convert_lit_checkpoint(checkpoint_path, output_path, config_path)
-    assert set(os.listdir(tmp_path)) == {"foo.ckpt", "foo.json", "generated.bin"}
+    convert_lit_checkpoint(checkpoint_path.parent, output_dir)
+    assert set(os.listdir(tmp_path)) == {"lit_model.pth", "lit_config.json", "out_dir"}
+    assert os.path.isfile(output_dir / "model.pth")
 
     # check checkpoint is unwrapped
     torch.save({"model": ours_model.state_dict()}, checkpoint_path)
-    convert_lit_checkpoint(checkpoint_path, output_path, config_path)
-    converted_sd = torch.load(output_path)
+    convert_lit_checkpoint(checkpoint_path.parent, output_dir)
+    converted_sd = torch.load(output_dir / "model.pth")
     assert "model" not in converted_sd
diff --git a/tutorials/convert_lit_models.md b/tutorials/convert_lit_models.md
index d4b88cf6fd..beba3b32e4 100644
--- a/tutorials/convert_lit_models.md
+++ b/tutorials/convert_lit_models.md
@@ -6,21 +6,19 @@ We provide a helpful script to convert models LitGPT models back to their equiva
 
 ```sh
 python litgpt/scripts/convert_lit_checkpoint.py \
-    --checkpoint_path checkpoints/repo_id/lit_model.pth \
-    --output_path output_path/converted.pth \
-    --config_path checkpoints/repo_id/config.json
+    --checkpoint_dir checkpoint_dir \
+    --output_dir converted_dir
 ```
 
-These paths are just placeholders, you will need to customize them based on which finetuning or pretraining script you ran and it's configuration. 
+These paths are just placeholders, you will need to customize them based on which finetuning or pretraining script you ran and its configuration.
 
 ### Loading converted LitGPT checkpoints into transformers
 
-If you want to load the converted checkpoints into a `transformers` model, please make sure you copied the original `config.json` file into the folder that contains the `converted.pth` file saved via `--output_path` above.
 
 For example,
 
 ```bash
-cp checkpoints/repo_id/config.json output_path/config.json
+cp checkpoints/repo_id/config.json converted/config.json
 ```
 
 Then, you can load the checkpoint file in a Python session as follows:
@@ -30,9 +28,9 @@
 import torch
 from transformers import AutoModel
 
-state_dict = torch.load("output_path/converted.pth")
+state_dict = torch.load("output_dir/model.pth")
 model = AutoModel.from_pretrained(
-    "output_path/", local_files_only=True, state_dict=state_dict
+    "output_dir/", local_files_only=True, state_dict=state_dict
 )
 ```
@@ -105,9 +103,8 @@ python scripts/merge_lora.py \
 ```
 
 ```bash
 python litgpt/scripts/convert_lit_checkpoint.py \
-    --checkpoint_path $finetuned_dir/final/lit_model.pth \
-    --output_path out/hf-tinyllama/converted_model.pth \
-    --config_path checkpoints/$repo_id/lit_config.json
+    --checkpoint_dir $finetuned_dir/final/ \
+    --output_dir out/hf-tinyllama/converted \
 ```
@@ -117,6 +114,6 @@
 import torch
 from transformers import AutoModel
 
-state_dict = torch.load('out/hf-tinyllama/converted_model.pth')
+state_dict = torch.load('out/hf-tinyllama/converted/model.pth')
 model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", state_dict=state_dict)
 ```
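As a sanity check on the new interface, here is a minimal sketch of the directory-based flow after this change. It is not part of the patch: the `checkpoints/repo_id` and `converted_dir` paths are placeholders, it assumes `convert_lit_checkpoint` is importable from `litgpt.scripts.convert_lit_checkpoint` (as the test does), that the checkpoint directory already contains `lit_model.pth` and `lit_config.json`, and that the original `config.json` was copied into the output directory as the tutorial describes.

```python
from pathlib import Path

import torch
from transformers import AutoModel

from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint

# Placeholder paths: checkpoint_dir must contain lit_model.pth and lit_config.json.
checkpoint_dir = Path("checkpoints/repo_id")
output_dir = Path("converted_dir")

# New signature: pass directories; the converted weights are written to output_dir / "model.pth".
convert_lit_checkpoint(checkpoint_dir, output_dir)

# Load the converted weights into a transformers model, as in the tutorial above
# (assumes config.json was copied into output_dir beforehand).
state_dict = torch.load(output_dir / "model.pth")
model = AutoModel.from_pretrained(str(output_dir), local_files_only=True, state_dict=state_dict)
```

Fixing the output filename to `model.pth` inside `--output_dir` is what lets the updated test assert on `output_dir / "model.pth"` directly instead of threading an explicit `--output_path` through every call.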