Skip to content

Commit

Permalink
Load config from checkpoint_dir if weights are not required (#755)
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrei-Aksionov authored Nov 20, 2023
1 parent f6174d9 commit c85bf01
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
9 changes: 9 additions & 0 deletions lit_gpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
json_kwargs.update(kwargs)
return cls(**json_kwargs)

@classmethod
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
"""Automatically load `lit_config.json` and if it doesn't exist - a matching config from `lit_gpt/config.py`."""
if (config_path := path / "lit_config.json").is_file():
return cls.from_json(config_path, **kwargs)
if (model_name := path.name) in name_to_config:
return cls.from_name(model_name, **kwargs)
raise FileNotFoundError(f"For {str(path)!r} neither 'lit_config.json' nor matching config exists.")

@property
def mlp_class(self) -> Type:
# `self._mlp_class` cannot be the type to keep the config json serializable
Expand Down
2 changes: 1 addition & 1 deletion scripts/prepare_redpajama.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def prepare(
match: str = "",
) -> None:
"""Prepare the "Red Pajama" dataset. We assume tokenizer has been trained."""
config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_checkpoint(checkpoint_dir)

prepare_fn = prepare_sample if sample else prepare_full
prepare_fn(
Expand Down
32 changes: 32 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,35 @@ def test_nonexisting_name():

with pytest.raises(ValueError, match="not a supported"):
Config.from_name("foobar")


def test_from_checkpoint(tmp_path):
from lit_gpt import Config

# 1. Neither `lit_config.py` nor matching config exists.
with pytest.raises(FileNotFoundError, match="neither 'lit_config.json' nor matching config exists"):
Config.from_checkpoint(tmp_path / "non_existing_checkpoint")

# 2. If `lit_config.py` doesn't exists, but there is a matching config in `lit_gpt/config.py`.
config = Config.from_checkpoint(tmp_path / "pythia-70m")
assert config.name == "pythia-70m"
assert config.block_size == 2048
assert config.n_layer == 6

# 3. If only `lit_config.py` exists.
config_data = {"name": "pythia-70m", "block_size": 24, "n_layer": 2}
with open(tmp_path / "lit_config.json", "w") as file:
json.dump(config_data, file)
config = Config.from_checkpoint(tmp_path)
assert config.name == "pythia-70m"
assert config.block_size == 24
assert config.n_layer == 2

# 4. Both `lit_config.py` and a matching config exist, but `lit_config.py` supersedes matching config
(tmp_path / "pythia-70m").mkdir()
with open(tmp_path / "pythia-70m/lit_config.json", "w") as file:
json.dump(config_data, file)
config = Config.from_checkpoint(tmp_path / "pythia-70m")
assert config.name == "pythia-70m"
assert config.block_size == 24
assert config.n_layer == 2

0 comments on commit c85bf01

Please sign in to comment.