
Commit

Merge branch 'Lightning-AI:main' into main
murdadesmaeeli authored Nov 20, 2023
2 parents 9233fdc + 9690b70 commit e4ec206
Showing 11 changed files with 49 additions and 6 deletions.
1 change: 1 addition & 0 deletions chat/base.py
@@ -167,6 +167,7 @@ def main(
     tokenizer = Tokenizer(checkpoint_dir)
     system_prompt, stop_tokens = prompt_config(checkpoint_dir, tokenizer)
 
+    L.seed_everything(1234)
     while True:
         try:
             prompt = input(">> Prompt: ")
1 change: 1 addition & 0 deletions generate/adapter.py
@@ -114,6 +114,7 @@ def main(
     prompt_length = encoded.size(0)
     max_returned_tokens = prompt_length + max_new_tokens
 
+    L.seed_everything(1234)
     with fabric.init_tensor():
         # set the max_seq_length to limit the memory usage to what we need
         model.max_seq_length = max_returned_tokens
1 change: 1 addition & 0 deletions generate/adapter_v2.py
@@ -114,6 +114,7 @@ def main(
     prompt_length = encoded.size(0)
     max_returned_tokens = prompt_length + max_new_tokens
 
+    L.seed_everything(1234)
     with fabric.init_tensor():
         # set the max_seq_length to limit the memory usage to what we need
         model.max_seq_length = max_returned_tokens
2 changes: 1 addition & 1 deletion generate/base.py
@@ -178,11 +178,11 @@ def main(
     prompt_length = encoded.size(0)
     max_returned_tokens = prompt_length + max_new_tokens
 
+    L.seed_everything(1234)
     with fabric.init_tensor():
         # set the max_seq_length to limit the memory usage to what we need
         model.max_seq_length = max_returned_tokens
 
-    L.seed_everything(1234)
     for i in range(num_samples):
         with fabric.init_tensor():
             # enable the kv cache
1 change: 1 addition & 0 deletions generate/full.py
@@ -110,6 +110,7 @@ def main(
     prompt_length = encoded.size(0)
     max_returned_tokens = prompt_length + max_new_tokens
 
+    L.seed_everything(1234)
     with fabric.init_tensor():
         # set the max_seq_length to limit the memory usage to what we need
         model.max_seq_length = max_returned_tokens
1 change: 1 addition & 0 deletions generate/lora.py
@@ -136,6 +136,7 @@ def main(
     prompt_length = encoded.size(0)
     max_returned_tokens = prompt_length + max_new_tokens
 
+    L.seed_everything(1234)
     with fabric.init_tensor():
         # set the max_seq_length to limit the memory usage to what we need
         model.max_seq_length = max_returned_tokens
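The chat and generation scripts above all gain an L.seed_everything(1234) call before any sampling happens, and generate/base.py additionally moves its existing call up, ahead of the fabric.init_tensor() setup block, so repeated runs of a script reproduce the same output. A minimal sketch of the effect outside these scripts, with illustrative logits standing in for a real model:

import lightning as L
import torch

L.seed_everything(1234)  # same seed as the scripts above; seeds Python's random, NumPy, and torch

# any stochastic sampling after this point is repeatable across runs
logits = torch.randn(1, 50_000)               # stand-in for model output logits
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.item())                      # identical on every run with the same seed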
9 changes: 9 additions & 0 deletions lit_gpt/config.py
@@ -116,6 +116,15 @@ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
         json_kwargs.update(kwargs)
         return cls(**json_kwargs)
 
+    @classmethod
+    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
+        """Automatically load `lit_config.json` or, if it doesn't exist, a matching config from `lit_gpt/config.py`."""
+        if (config_path := path / "lit_config.json").is_file():
+            return cls.from_json(config_path, **kwargs)
+        if (model_name := path.name) in name_to_config:
+            return cls.from_name(model_name, **kwargs)
+        raise FileNotFoundError(f"For {str(path)!r} neither 'lit_config.json' nor matching config exists.")
+
     @property
     def mlp_class(self) -> Type:
         # `self._mlp_class` cannot be the type to keep the config json serializable
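The new Config.from_checkpoint classmethod resolves a config in order: a lit_config.json inside the checkpoint directory wins, otherwise the directory name is looked up among the built-in presets in name_to_config, and a FileNotFoundError is raised if neither exists. A hypothetical call site, with an illustrative checkpoint path:

from pathlib import Path

from lit_gpt import Config

checkpoint_dir = Path("checkpoints/EleutherAI/pythia-70m")  # illustrative path

# 1) use checkpoint_dir / "lit_config.json" if present
# 2) otherwise fall back to the built-in preset named after the directory ("pythia-70m")
# 3) otherwise raise FileNotFoundError
config = Config.from_checkpoint(checkpoint_dir)
print(config.name, config.block_size, config.n_layer)

This is the helper that lets scripts/prepare_redpajama.py below drop its hard dependency on lit_config.json.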
2 changes: 1 addition & 1 deletion requirements-all.txt
@@ -8,6 +8,6 @@ zstandard # scripts/prepare_redpajama.py, scripts/prepare_starcoder.py
 pandas # scripts/prepare_csv.py, scripts/prepare_starcoder.py
 pyarrow # scripts/prepare_starcoder.py
 # eval
-git+https://github.com/EleutherAI/lm-evaluation-harness.git@master; python_version > '3.8'
+git+https://github.com/EleutherAI/lm-evaluation-harness.git@115206dc89dad67b8beaa90051fb52db77f0a529
 # scripts/prepare_slimpajama.py, scripts/prepare_starcoder.py, pretrain/tinyllama.py
 lightning[data] @ git+https://github.com/Lightning-AI/lightning@4e72dcc8db6a0cbe94042ddbc310340556e8fee7
2 changes: 1 addition & 1 deletion scripts/prepare_redpajama.py
@@ -147,7 +147,7 @@ def prepare(
     match: str = "",
 ) -> None:
     """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained."""
-    config = Config.from_json(checkpoint_dir / "lit_config.json")
+    config = Config.from_checkpoint(checkpoint_dir)
 
     prepare_fn = prepare_sample if sample else prepare_full
     prepare_fn(
32 changes: 32 additions & 0 deletions tests/test_config.py
@@ -93,3 +93,35 @@ def test_nonexisting_name():

     with pytest.raises(ValueError, match="not a supported"):
         Config.from_name("foobar")
+
+
+def test_from_checkpoint(tmp_path):
+    from lit_gpt import Config
+
+    # 1. Neither `lit_config.json` nor a matching config exists.
+    with pytest.raises(FileNotFoundError, match="neither 'lit_config.json' nor matching config exists"):
+        Config.from_checkpoint(tmp_path / "non_existing_checkpoint")
+
+    # 2. If `lit_config.json` doesn't exist, but there is a matching config in `lit_gpt/config.py`.
+    config = Config.from_checkpoint(tmp_path / "pythia-70m")
+    assert config.name == "pythia-70m"
+    assert config.block_size == 2048
+    assert config.n_layer == 6
+
+    # 3. If only `lit_config.json` exists.
+    config_data = {"name": "pythia-70m", "block_size": 24, "n_layer": 2}
+    with open(tmp_path / "lit_config.json", "w") as file:
+        json.dump(config_data, file)
+    config = Config.from_checkpoint(tmp_path)
+    assert config.name == "pythia-70m"
+    assert config.block_size == 24
+    assert config.n_layer == 2
+
+    # 4. Both `lit_config.json` and a matching config exist; `lit_config.json` supersedes the matching config.
+    (tmp_path / "pythia-70m").mkdir()
+    with open(tmp_path / "pythia-70m/lit_config.json", "w") as file:
+        json.dump(config_data, file)
+    config = Config.from_checkpoint(tmp_path / "pythia-70m")
+    assert config.name == "pythia-70m"
+    assert config.block_size == 24
+    assert config.n_layer == 2
3 changes: 0 additions & 3 deletions tests/test_lm_eval_harness.py
@@ -10,7 +10,6 @@
 from lightning import Fabric
 
 
-@RunIf(min_python="3.9")
 @pytest.mark.xfail(raises=datasets.builder.DatasetGenerationError, strict=False)  # avoid flakes
 def test_run_eval(tmp_path, float_like):
     from eval.lm_eval_harness import EvalHarnessBase
@@ -52,7 +51,6 @@ def test_run_eval(tmp_path, float_like):
     }
 
 
-@RunIf(min_python="3.9")
 def test_eval_script(tmp_path, fake_checkpoint_dir, monkeypatch):
     import eval.lm_eval_harness as module
 
@@ -78,7 +76,6 @@ def test_eval_script(tmp_path, fake_checkpoint_dir, monkeypatch):
     assert (tmp_path / "results.json").read_text() == '{"foo": "test"}'
 
 
-@RunIf(min_python="3.9")
 def test_cli():
     cli_path = Path(__file__).parent.parent / "eval" / "lm_eval_harness.py"
     output = subprocess.check_output([sys.executable, cli_path, "-h"])
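With the eval harness pinned to a fixed commit in requirements-all.txt above (and its python_version > '3.8' marker dropped), these tests lose their @RunIf(min_python="3.9") guards. The repository's actual RunIf helper lives in its test utilities and may differ in detail, but a minimum-Python-version guard of this kind is typically a thin wrapper around pytest.mark.skipif, roughly:

import sys

import pytest


def RunIf(min_python: str):
    """Skip the decorated test unless the interpreter is at least `min_python` (e.g. "3.9")."""
    required = tuple(int(part) for part in min_python.split("."))
    return pytest.mark.skipif(
        sys.version_info < required,
        reason=f"requires Python >= {min_python}",
    )


@RunIf(min_python="3.9")
def test_needs_python_39():
    assert True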
