Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load config from checkpoint_dir #755

Merged
merged 8 commits into from
Nov 20, 2023
9 changes: 9 additions & 0 deletions lit_gpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
json_kwargs.update(kwargs)
return cls(**json_kwargs)

@classmethod
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
    """Build a ``Config`` from a checkpoint directory.

    Prefers a ``lit_config.json`` file inside *path*; if absent, falls back to
    the predefined config in ``lit_gpt/config.py`` whose name matches the
    directory name.

    Raises:
        FileNotFoundError: if neither configuration source is available.
    """
    config_path = path / "lit_config.json"
    if config_path.is_file():
        # An explicit per-checkpoint config always wins over the built-in ones.
        return cls.from_json(config_path, **kwargs)
    model_name = path.name
    if model_name in name_to_config:
        return cls.from_name(model_name, **kwargs)
    raise FileNotFoundError(f"For {str(path)!r} neither 'lit_config.json' nor matching config exists.")

@property
def mlp_class(self) -> Type:
# `self._mlp_class` cannot be the type to keep the config json serializable
Expand Down
2 changes: 1 addition & 1 deletion scripts/prepare_redpajama.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def prepare(
match: str = "",
) -> None:
"""Prepare the "Red Pajama" dataset. We assume tokenizer has been trained."""
config = Config.from_json(checkpoint_dir / "lit_config.json")
config = Config.from_checkpoint(checkpoint_dir)

prepare_fn = prepare_sample if sample else prepare_full
prepare_fn(
Expand Down
32 changes: 32 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,35 @@ def test_nonexisting_name():

with pytest.raises(ValueError, match="not a supported"):
Config.from_name("foobar")


def test_from_checkpoint(tmp_path):
    """`Config.from_checkpoint` prefers `lit_config.json` and falls back to a name-matched config."""
    from lit_gpt import Config

    # 1. Neither `lit_config.json` nor a matching config exists.
    with pytest.raises(FileNotFoundError, match="neither 'lit_config.json' nor matching config exists"):
        Config.from_checkpoint(tmp_path / "non_existing_checkpoint")

    # 2. If `lit_config.json` doesn't exist, but there is a matching config in `lit_gpt/config.py`.
    config = Config.from_checkpoint(tmp_path / "pythia-70m")
    assert config.name == "pythia-70m"
    assert config.block_size == 2048
    assert config.n_layer == 6

    # 3. If only `lit_config.json` exists.
    config_data = {"name": "pythia-70m", "block_size": 24, "n_layer": 2}
    with open(tmp_path / "lit_config.json", "w") as file:
        json.dump(config_data, file)
    config = Config.from_checkpoint(tmp_path)
    assert config.name == "pythia-70m"
    assert config.block_size == 24
    assert config.n_layer == 2

    # 4. Both `lit_config.json` and a matching config exist, but `lit_config.json` supersedes the matching config.
    (tmp_path / "pythia-70m").mkdir()
    with open(tmp_path / "pythia-70m/lit_config.json", "w") as file:
        json.dump(config_data, file)
    config = Config.from_checkpoint(tmp_path / "pythia-70m")
    assert config.name == "pythia-70m"
    assert config.block_size == 24
    assert config.n_layer == 2