Merge branch 'main' into fft_resume
awaelchli committed Jan 19, 2024
2 parents 56f024e + 0f021f3 commit e8d9a0b
Showing 97 changed files with 3,772 additions and 1,515 deletions.
2 changes: 1 addition & 1 deletion .github/azure-gpu-test.yml
@@ -43,7 +43,7 @@ jobs:
displayName: "Image info & NVIDIA"
- script: |
pip install -r requirements-all.txt pytest pytest-rerunfailures transformers einops protobuf
pip install -r requirements-all.txt pytest pytest-rerunfailures transformers>=4.36.0 einops protobuf
displayName: 'Install dependencies'
- bash: |
2 changes: 1 addition & 1 deletion .github/workflows/cpu-tests.yml
@@ -61,7 +61,7 @@ jobs:
- name: Install all dependencies
run: |
pip install -r requirements-all.txt pytest pytest-rerunfailures pytest-timeout transformers einops protobuf
pip install -r requirements-all.txt pytest pytest-rerunfailures pytest-timeout transformers>=4.36.0 einops protobuf
pip list
- name: Run tests without the package installed
66 changes: 43 additions & 23 deletions README.md
@@ -26,24 +26,27 @@ Hackable [implementation](lit_gpt/model.py) of state-of-the-art open-source larg

Supports the following popular model checkpoints:

| Model and usage | Model size | Reference |
|--------------------------------------------------------------------------------|------------------------------------|--------------------------------------------------------------------------------------------------|
| Meta AI [Llama 2](tutorials/download_llama_2.md) | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Stability AI [FreeWilly2](tutorials/download_freewilly_2.md) (Stable Beluga 2) | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Stability AI StableCode | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TII UAE [Falcon](tutorials/download_falcon.md) | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) |
| OpenLM Research [OpenLLaMA](tutorials/download_openllama.md) | 3B, 7B, 13B | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| LMSYS [Vicuna](tutorials/download_vicuna.md) | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |
| LMSYS [LongChat](tutorials/download_longchat.md) | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| Together [RedPajama-INCITE](tutorials/download_redpajama_incite.md) | 3B, 7B | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| EleutherAI [Pythia](tutorials/download_pythia.md) | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| StabilityAI [StableLM](tutorials/download_stablelm.md) | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| Platypus | 7B, 13B, 70B | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| NousResearch Nous-Hermes | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) |
| Meta AI [Code Llama](tutorials/download_code_llama.md) | 7B, 13B, 34B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Microsoft Research [phi-1.5](tutorials/download_phi15.md) | 1.3B | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Mistral AI [Mistral](tutorials/download_mistral.md) | 7B | [Mistral website](https://mistral.ai/) |
| [TinyLlama](tutorials/download_tinyllama.md) | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama)
| Model and usage | Model size | Reference |
|-----------------------------------------------------------------------------------|------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|
| EleutherAI [Pythia](tutorials/download_pythia.md) | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| LMSYS [LongChat](tutorials/download_longchat.md) | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| LMSYS [Vicuna](tutorials/download_vicuna.md) | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |
| Meta AI [Code Llama](tutorials/download_code_llama.md) | 7B, 13B, 34B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Meta AI [Llama 2](tutorials/download_llama_2.md) | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Mistral AI [Mistral and Mixtral](tutorials/download_mistral.md) | 7B | [Mistral website](https://mistral.ai/) |
| Microsoft Research [Phi](tutorials/download_phi.md) | 1.3B, 2.7B | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| NousResearch Nous-Hermes | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) |
| OpenLM Research [OpenLLaMA](tutorials/download_openllama.md) | 3B, 7B, 13B | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Platypus | 7B, 13B, 70B | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Stability AI StableCode | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| Stability AI [FreeWilly2](tutorials/download_freewilly_2.md) (Stable Beluga 2) | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Stability AI [StableLM](tutorials/download_stablelm.md) | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| Stability AI [StableLM Zephyr](tutorials/download_stablelm.md) | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TII UAE [Falcon](tutorials/download_falcon.md) | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) |
| [TinyLlama](tutorials/download_tinyllama.md) | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
| Together [RedPajama-INCITE](tutorials/download_redpajama_incite.md) | 3B, 7B | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| Trelis [Function Calling Llama 2](tutorials/download_function_calling_llama_2.md) | 7B | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| databricks [Dolly](tutorials/download_dolly.md) | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |

This implementation extends on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it's **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/)**.

@@ -107,10 +110,13 @@ Install with all dependencies (including quantization, sentencepiece, tokenizers
pip install -r requirements-all.txt
```

**(Optional) Use Flash Attention 2 (only available in PyTorch 2.2)**
**(Optional) Use Flash Attention 2**

Flash Attention 2 will be used automatically if PyTorch 2.2 (or higher) is installed.
Currently, that requires installing PyTorch nightly, which you can get by running:

```bash
pip uninstall -y torch
pip uninstall -y torch torchvision torchaudio torchtext
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
```
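
To confirm which build ended up installed, you can print the PyTorch version; per the note above, any 2.2+ build (including nightlies) means Flash Attention 2 is used automatically:

```bash
# Check the installed PyTorch version; 2.2 or newer enables Flash Attention 2 automatically
python -c "import torch; print(torch.__version__)"
```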

@@ -142,7 +148,7 @@ python chat/base.py

### Run large models on smaller consumer devices

We support 4-bit quantization (as in QLoRA), (bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq, gptq.int4) and 8-bit quantization (bnb.int8) for inference by following [this guide](tutorials/quantize.md).
We support 4-bit quantization (as in QLoRA), (bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq) and 8-bit quantization (bnb.int8) for inference by following [this guide](tutorials/quantize.md).
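
As a minimal sketch of how a quantization mode is selected (assuming a checkpoint has already been downloaded and converted as described in the guide), the inference scripts take a `quantize` argument that accepts one of the identifiers above, typically passed as a command-line flag:

```bash
# Sketch: run the chat script with 4-bit NF4 quantization (QLoRA-style).
# The checkpoint directory is illustrative; use whichever checkpoint you converted.
python chat/base.py \
  --checkpoint_dir checkpoints/stabilityai/stablelm-tuned-alpha-3b \
  --quantize bnb.nf4
```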

 

@@ -222,7 +228,7 @@ Follow this guide to start pretraining on

## Supported datasets

Lit-GPT includes a variety of dataset preparation scripts for finetuning and pretraining. Additional information about the datasets and dataset preparation is provided in the [Preparing Datasets](https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/prepare_dataset.md) tutorial.
Lit-GPT includes a variety of dataset preparation scripts for finetuning and pretraining. Additional information about the datasets and dataset preparation is provided in the [Preparing Datasets](tutorials/prepare_dataset.md) tutorial.
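
As an illustration (the script name, flags, and paths here are assumptions; the tutorial documents the exact commands for each dataset), a preparation script is run once per dataset before finetuning:

```bash
# Hypothetical example: prepare the Alpaca instruction dataset for finetuning.
# Script path, flags, and checkpoint directory are assumptions; see tutorials/prepare_dataset.md.
python scripts/prepare_alpaca.py \
  --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
  --destination_path data/alpaca
```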


 
@@ -265,12 +271,26 @@ Don't forget to [join our Discord](https://discord.gg/VptPCZkGNa)!
- [@karpathy](https://github.com/karpathy) for [nanoGPT](https://github.com/karpathy/nanoGPT)
- [@EleutherAI](https://github.com/EleutherAI) for [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) and the [Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)
- [@TimDettmers](https://github.com/TimDettmers) for [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [@IST-DASLab](https://github.com/IST-DASLab) for [GPTQ](https://github.com/IST-DASLab/gptq)
- [@Microsoft](https://github.com/microsoft) for [LoRA](https://github.com/microsoft/LoRA)
- [@tridao](https://github.com/tridao) for [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)

 

## Citation

If you use Lit-GPT in your research, please cite the following work:

```bibtex
@misc{lit-gpt-2023,
author = {Lightning AI},
title = {Lit-GPT},
howpublished = {\url{https://github.com/Lightning-AI/lit-gpt}},
year = {2023},
}
```

 

## License

Lit-GPT is released under the [Apache 2.0](https://github.com/Lightning-AI/lit-gpt/blob/main/LICENSE) license.
69 changes: 51 additions & 18 deletions chat/base.py
@@ -1,6 +1,9 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import re
import sys
import time
from json import dumps
from pathlib import Path
from typing import Iterator, List, Literal, Optional, Tuple

@@ -14,12 +17,7 @@

from generate.base import next_token
from lit_gpt import GPT, Config, Tokenizer
from lit_gpt.utils import (
check_valid_checkpoint_dir,
get_default_supported_precision,
gptq_quantization,
load_checkpoint,
)
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint


@torch.inference_mode()
@@ -87,6 +85,7 @@ def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.
decoded_so_far = ""
try:
for token in token_stream:
so_far = so_far.to(device=token.device)
so_far = torch.cat((so_far, token.view(-1)))
decoded_new = tokenizer.decode(so_far)
fabric.print(decoded_new[len(decoded_so_far) :], end="", flush=True)
@@ -100,12 +99,13 @@ def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.
return tokens_generated


@torch.inference_mode()
def main(
*,
top_k: Optional[int] = 200,
temperature: float = 0.8,
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-tuned-alpha-3b"),
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[str] = None,
compile: bool = False,
) -> None:
@@ -119,9 +119,9 @@ def main(
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
- gptq.int4: 4-bit quantization from GPTQ
for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to use compilation to speed up token generation. Will increase startup time.
"""
precision = precision or get_default_supported_precision(training=False)

@@ -139,16 +139,10 @@

config = Config.from_json(checkpoint_dir / "lit_config.json")

if quantize == "gptq.int4":
model_file = "lit_model_gptq.4bit.pth"
if not (checkpoint_dir / model_file).is_file():
raise ValueError("Please run `python quantize/gptq.py` first")
else:
model_file = "lit_model.pth"
checkpoint_path = checkpoint_dir / model_file
checkpoint_path = checkpoint_dir / "lit_model.pth"

fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
with fabric.init_module(empty_init=True):
model = GPT(config)
# enable the kv cache
model.set_kv_cache(batch_size=1)
@@ -187,7 +181,9 @@ def main(
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(
f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec, {tokens_generated} tokens", file=sys.stderr
f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
f" {tokens_generated} tokens",
file=sys.stderr,
)
fabric.print()

@@ -209,6 +205,12 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl
[tokenizer.token_to_id("<|USER|>")],
)
return system_prompt, stop_tokens

if re.search(r"stabilityai/stablelm-zephyr-3b", checkpoint_name):
system_prompt = "<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n"
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

if re.search(r"togethercomputer.*Chat", checkpoint_name):
system_prompt = "<human>: {prompt}\n<bot>:"
lt, gt = tokenizer.token_to_id("<"), tokenizer.token_to_id(">:")
@@ -255,6 +257,32 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl
)
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

if re.search("Llama-2-7b-chat-hf-function-calling-v2", checkpoint_name):
# Has to be before the llama config
b_func, e_func = "<FUNCTIONS>", "</FUNCTIONS>\n\n"
b_inst, e_inst = "[INST]", "[/INST]"
b_sys, e_sys = "<<SYS>>\n", "\n<</SYS>>\n\n"
# This is an example for how to format functions for the model
function_metadata = {
"function": "search_bing",
"description": (
"Search the web for content on Bing. This allows users to search online/the internet/the web for"
" content."
),
"arguments": [{"name": "query", "type": "string", "description": "The search query string"}],
}

system_prompt = (
"You are a helpful, respectful and honest assistant. Always answer as helpfully as"
"possible. Your only response should be JSON formatted functions"
)
# replace the curly braces with double curly braces to escape them
function_list = dumps(function_metadata).replace("{", "{{").replace("}", "}}")
system_prompt = f"{b_func}{function_list.strip()}{e_func}{b_inst}{b_sys}{system_prompt.strip()}{e_sys}{'{prompt}'}{e_inst}\n\n"
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

if re.search("Llama-2.*-chat", checkpoint_name):
b_inst, e_inst = "[INST]", "[/INST]"
b_sys, e_sys = "<<SYS>>\n", "\n<</SYS>>\n\n"
@@ -304,7 +332,7 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

if re.search("phi", checkpoint_name):
if re.search("phi-1", checkpoint_name):
system_prompt = "{prompt}\n\nAnswer:"

stop_tokens = (
@@ -317,6 +345,11 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl
)
return system_prompt, stop_tokens

if re.search("phi-2", checkpoint_name):
system_prompt = "Instruct:{prompt}\nOutput:"
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

if re.search(r"TinyLlama.*Chat", checkpoint_name):
system_prompt = (
"<|system|>\n"