
Commit

update others
awaelchli committed Mar 5, 2024
1 parent d7e4feb commit b1f3037
Showing 8 changed files with 19 additions and 6 deletions.
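Taken together, the diffs below wire a `save_hyperparameters` helper from `lit_gpt.utils` into the fine-tuning and pretraining scripts: right after each checkpoint is written, the resolved arguments of the script's `setup` entry point are dumped to a `hyperparameters.yaml` next to the weights, and the tests are updated to expect that file. The helper's implementation is not part of this commit; the sketch below is only a rough illustration of what such a function could do, assuming it recovers `--key value` pairs from `sys.argv` (which would also explain why the tests now patch `sys.argv`) and that PyYAML is available.

```python
# Hypothetical sketch, NOT the lit_gpt.utils implementation referenced by this commit.
# Assumes the script's CLI passes options as "--key value" pairs that map onto
# keyword parameters of the `setup` entry point, and that PyYAML is installed.
import inspect
import sys
from pathlib import Path
from typing import Callable

import yaml


def save_hyperparameters(fn: Callable, out_dir: Path) -> None:
    """Dump the resolved arguments of ``fn`` to ``out_dir/hyperparameters.yaml``."""
    # Start from the defaults declared in the function signature ...
    hparams = {
        name: param.default
        for name, param in inspect.signature(fn).parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    # ... then overlay anything given on the command line as "--key value" pairs.
    argv = sys.argv[1:]
    for key, value in zip(argv[::2], argv[1::2]):
        if key.startswith("--"):
            hparams[key.removeprefix("--")] = value
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "hyperparameters.yaml", "w") as f:
        # str() keeps non-YAML-native defaults (paths, dataclasses, ...) serializable
        yaml.safe_dump({k: str(v) for k, v in hparams.items()}, f)
```

In the diffs, the call sites pass the module-level `setup` function and the checkpoint's parent directory, e.g. `save_hyperparameters(setup, save_path.parent)`.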
3 changes: 3 additions & 0 deletions finetune/adapter.py
@@ -33,6 +33,7 @@
CycleIterator,
parse_devices,
copy_config_files,
save_hyperparameters,
)


@@ -145,6 +146,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
save_adapter_checkpoint(fabric, model, save_path)
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)


def fit(
@@ -224,6 +226,7 @@ def fit(
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_adapter_checkpoint(fabric, model, checkpoint_file)
copy_config_files(checkpoint_dir, checkpoint_file.parent)
save_hyperparameters(setup, checkpoint_file.parent)


# the adapter "kv cache" cannot be initialized under `inference_mode`
3 changes: 3 additions & 0 deletions finetune/adapter_v2.py
@@ -33,6 +33,7 @@
CycleIterator,
parse_devices,
copy_config_files,
save_hyperparameters,
)


@@ -145,6 +146,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
save_adapter_v2_checkpoint(fabric, model, save_path)
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)


def fit(
@@ -224,6 +226,7 @@ def fit(
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_adapter_v2_checkpoint(fabric, model, checkpoint_file)
copy_config_files(checkpoint_dir, checkpoint_file.parent)
save_hyperparameters(setup, checkpoint_file.parent)


# the adapter "kv cache" cannot be initialized under `inference_mode`
3 changes: 3 additions & 0 deletions finetune/lora.py
@@ -33,6 +33,7 @@
CycleIterator,
parse_devices,
copy_config_files,
save_hyperparameters,
)


@@ -176,6 +177,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
save_lora_checkpoint(fabric, model, save_path)
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)


def fit(
@@ -255,6 +257,7 @@ def fit(
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_lora_checkpoint(fabric, model, checkpoint_file)
copy_config_files(checkpoint_dir, checkpoint_file.parent)
save_hyperparameters(setup, checkpoint_file.parent)


# FSDP has issues with `inference_mode`
3 changes: 2 additions & 1 deletion lit_gpt/pretrain.py
@@ -23,7 +23,7 @@
from lit_gpt.args import EvalArgs, TrainArgs
from lit_gpt.data import LitDataModule, TinyLlama
from lit_gpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters, parse_devices, copy_config_files
from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters, parse_devices, copy_config_files, save_hyperparameters


def setup(
@@ -269,6 +269,7 @@ def fit(
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
fabric.save(checkpoint_file, state)
save_hyperparameters(setup, checkpoint_file.parent)
if tokenizer_dir is not None:
copy_config_files(tokenizer_dir, checkpoint_file.parent)

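For the pretraining script this means every periodic checkpoint directory now carries a `hyperparameters.yaml` next to `lit_model.pth` (the updated assertion in `tests/test_pretrain.py` below checks exactly that). Reading it back needs nothing beyond a YAML parser; the following is only an illustration, with a hypothetical checkpoint directory name.

```python
# Illustration only: reading the hyperparameters saved next to a checkpoint.
# Assumes hyperparameters.yaml is a flat YAML mapping written during training.
from pathlib import Path

import yaml

checkpoint_dir = Path("out/step-00001000")  # hypothetical checkpoint directory
with open(checkpoint_dir / "hyperparameters.yaml") as f:
    hparams = yaml.safe_load(f)
print(hparams)
```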
3 changes: 2 additions & 1 deletion tests/test_adapter.py
@@ -68,7 +68,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)

out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout):
with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py"]):
module.setup(
data=Alpaca(
download_dir=alpaca_path.parent,
@@ -93,6 +93,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
"lit_config.json",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
}
assert (out_dir / "version_0" / "metrics.csv").is_file()

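The only other change in the tests is the new `mock.patch("sys.argv", [...])` context around `module.setup(...)`. A plausible reason, not stated in the commit, is that `save_hyperparameters` re-reads the command line, and under pytest `sys.argv` contains pytest's own options, so each test presents a clean argv; the same pattern repeats in the remaining test diffs below. A minimal illustration:

```python
# Minimal illustration of the sys.argv patching pattern used in the updated tests.
# Anything that re-parses the command line would otherwise see pytest's own flags.
import sys
from unittest import mock

with mock.patch("sys.argv", ["adapter.py"]):
    # inside the context, code inspecting sys.argv sees only the script name
    assert sys.argv == ["adapter.py"]
```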
3 changes: 2 additions & 1 deletion tests/test_adapter_v2.py
@@ -91,7 +91,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa

out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout):
with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py"]):
module.setup(
data=Alpaca(
download_dir=alpaca_path.parent,
@@ -116,6 +116,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
"lit_config.json",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
}
assert (out_dir / "version_0" / "metrics.csv").is_file()

3 changes: 2 additions & 1 deletion tests/test_lora.py
@@ -198,7 +198,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):

out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout):
with redirect_stdout(stdout), mock.patch("sys.argv", ["lora.py"]):
module.setup(
data=Alpaca(
download_dir=alpaca_path.parent,
@@ -223,6 +223,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
"lit_config.json",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
}
assert (out_dir / "version_0" / "metrics.csv").is_file()

4 changes: 2 additions & 2 deletions tests/test_pretrain.py
@@ -27,7 +27,7 @@ def test_pretrain(tmp_path, monkeypatch):

out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout):
with redirect_stdout(stdout), mock.patch("sys.argv", ["pretrain.py"]):
pretrain.setup(
devices=2,
model=model_config,
@@ -44,7 +44,7 @@ def test_pretrain(tmp_path, monkeypatch):
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
# the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here
assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {"lit_model.pth"}
assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth", "hyperparameters.yaml"}

# logs only appear on rank 0
logs = stdout.getvalue()
