Skip to content

Commit

Permalink
Merge branch 'main' into rasbt-patch-2
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt authored Apr 12, 2024
2 parents 91d394f + d6e91ee commit 0167fee
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 6 deletions.
1 change: 0 additions & 1 deletion .github/workflows/cpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ defaults:

env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
UV_HTTP_TIMEOUT: 500

jobs:
cpu-tests:
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ Use, Finetune, pretrain, deploy over 20+ LLMs ([full list](tutorials/download_mo

| Model | Model size | Author | Reference |
|----|----|----|----|
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
Expand Down
3 changes: 2 additions & 1 deletion extensions/thunder/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from litgpt.utils import (
CLI,
CycleIterator,
capture_hparams,
choose_logger,
chunked_cross_entropy,
copy_config_files,
Expand Down Expand Up @@ -97,7 +98,7 @@ def setup(
executors: If using Thunder, the executors to enable.
strategy: If desired, the strategy to use.
"""
hparams = locals()
hparams = capture_hparams()
data = TinyLlama() if data is None else data
if model_config is not None and model_name is not None:
raise ValueError("Only one of `model_name` or `model_config` can be set.")
Expand Down
26 changes: 26 additions & 0 deletions litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,32 @@ def norm_class(self) -> Type:
copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it"
configs.append(copy)

##################
# Google CodeGemma
##################
codegemma = [
# https://huggingface.co/google/codegemma-7b-it/blob/main/config.json
dict(
name="CodeGemma-7b-it",
hf_config=dict(org="google", name="codegemma-7b-it"),
scale_embeddings=True,
vocab_size=256000,
padding_multiple=64,
n_embd=3072,
n_layer=28,
n_head=16,
head_size=256,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
intermediate_size=24576,
),
]
configs.extend(codegemma)


##########################
# Stability AI FreeWilly2
Expand Down
3 changes: 2 additions & 1 deletion litgpt/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from litgpt.utils import (
CLI,
CycleIterator,
capture_hparams,
choose_logger,
chunked_cross_entropy,
copy_config_files,
Expand Down Expand Up @@ -87,7 +88,7 @@ def setup(
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""
hparams = locals()
hparams = capture_hparams()
data = TinyLlama() if data is None else data
if model_config is not None and model_name is not None:
raise ValueError("Only one of `model_name` or `model_config` can be set.")
Expand Down
4 changes: 2 additions & 2 deletions litgpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:

class Phi2(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return f"Instruct:{prompt}\nOutput:"
return f"Instruct: {prompt}\nOutput:"


class TinyLlama(PromptStyle):
Expand Down Expand Up @@ -330,7 +330,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Phi2()
if re.search(r"tiny-llama.*chat", model_name):
return TinyLlama()
if re.search(r"Gemma.*-it", model_name):
if re.search(r"(Code)?Gemma.*-it", model_name):
return Gemma()
return Default()

Expand Down
18 changes: 17 additions & 1 deletion litgpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""Utility functions for training and inference."""
import inspect
import math
import pickle
import shutil
import sys
from dataclasses import asdict
from dataclasses import asdict, is_dataclass
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union
Expand Down Expand Up @@ -404,6 +405,21 @@ def CLI(*args: Any, **kwargs: Any) -> Any:
return CLI(*args, **kwargs)


def capture_hparams() -> Dict[str, Any]:
"""Captures the local variables ('hyperparameters') from where this function gets called."""
caller_frame = inspect.currentframe().f_back
locals_of_caller = caller_frame.f_locals
hparams = {}
for name, value in locals_of_caller.items():
if value is None or isinstance(value, (int, float, str, bool, Path)):
hparams[name] = value
elif is_dataclass(value):
hparams[name] = asdict(value)
else:
hparams[name] = str(value)
return hparams


def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
"""Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
from jsonargparse import capture_parser
Expand Down
23 changes: 23 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import asdict

import os
from contextlib import redirect_stderr
Expand All @@ -18,9 +19,11 @@
from lightning_utilities.core.imports import RequirementCache

from litgpt import GPT
from litgpt.args import TrainArgs
from litgpt.utils import (
CLI,
CycleIterator,
capture_hparams,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
Expand Down Expand Up @@ -219,6 +222,26 @@ def test_copy_config_files(fake_checkpoint_dir, tmp_path):
assert expected.issubset(contents)


def test_capture_hparams():
integer = 1
string = "string"
boolean = True
none = None
path = Path("/path")
dataclass = TrainArgs()
other = torch.nn.Linear(1, 1)
hparams = capture_hparams()
assert hparams == {
"integer": integer,
"string": string,
"boolean": boolean,
"none": none,
"path": path,
"dataclass": asdict(dataclass),
"other": str(other),
}


def _test_function(out_dir: Path, foo: bool = False, bar: int = 1):
save_hyperparameters(_test_function, out_dir)

Expand Down
2 changes: 2 additions & 0 deletions tutorials/download_model_weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.

| Model | Model size | Reference |
|----------------------------------------------|------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|
| CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) |
Expand Down Expand Up @@ -84,6 +85,7 @@ garage-bAInd/Platypus2-70B
garage-bAInd/Platypus2-70B-instruct
garage-bAInd/Platypus2-7B
garage-bAInd/Stable-Platypus2-13B
google/codegemma-7b-it
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
Expand Down

0 comments on commit 0167fee

Please sign in to comment.