-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into carmocca/fsdp-regular-trasnform
- Loading branch information
Showing
9 changed files
with
167 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
## Minimal LitGPT Generate Examples in Python | ||
|
||
|
||
|
||
The scripts in this folder provide minimal examples showing how to use LitGPT from within Python without the CLI. | ||
|
||
- `generate.py` is a minimal script that uses the `main` function from LitGPT's `generate` utilities | ||
- `generate-step-by-step.py` is a lower-level script using LitGPT utility functions directly instead of relying on the `main` function menntioned above. | ||
|
||
Assuming you downloaded the checkpoint files via | ||
|
||
```bash | ||
litgpt download --repo_id EleutherAI/pythia-1b | ||
``` | ||
|
||
you can run the scripts as follows: | ||
|
||
|
||
```bash | ||
python generate-step-by-step.py | ||
``` | ||
|
||
or | ||
|
||
```bash | ||
python generate.py | ||
``` | ||
|
||
|
||
|
84 changes: 84 additions & 0 deletions
84
tutorials/examples/minimal-generate-scripts/generate-step-by-step.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. | ||
|
||
from pathlib import Path | ||
|
||
import lightning as L | ||
import torch | ||
|
||
from litgpt.prompts import PromptStyle | ||
from litgpt.tokenizer import Tokenizer | ||
from litgpt.utils import load_checkpoint, get_default_supported_precision | ||
from litgpt.generate.base import generate | ||
from litgpt.model import GPT | ||
from litgpt.config import Config | ||
|
||
|
||
def use_model(): | ||
|
||
################### | ||
# Load model | ||
################### | ||
|
||
# run `litgpt download --repo_id EleutherAI/pythia-1b` to download the checkpoint first | ||
checkpoint_dir = Path("checkpoints") / "EleutherAI" / "pythia-1b" | ||
config = Config.from_file(checkpoint_dir / "model_config.yaml") | ||
|
||
precision = get_default_supported_precision(training=False) | ||
device = torch.device("cuda") | ||
|
||
fabric = L.Fabric( | ||
accelerator=device.type, | ||
devices=1, | ||
precision=precision, | ||
) | ||
|
||
checkpoint_path = checkpoint_dir / "lit_model.pth" | ||
tokenizer = Tokenizer(checkpoint_dir) | ||
|
||
prompt_style = PromptStyle.from_config(config) | ||
|
||
with fabric.init_module(empty_init=True): | ||
model = GPT(config) | ||
with fabric.init_tensor(): | ||
model.set_kv_cache(batch_size=1) | ||
|
||
model.eval() | ||
model = fabric.setup_module(model) | ||
load_checkpoint(fabric, model, checkpoint_path) | ||
|
||
device = fabric.device | ||
|
||
################### | ||
# Predict | ||
################### | ||
|
||
prompt = "What do Llamas eat?" | ||
max_new_tokens = 50 | ||
|
||
prompt = prompt_style.apply(prompt) | ||
encoded = tokenizer.encode(prompt, device=device) | ||
|
||
prompt_length = encoded.size(0) | ||
max_returned_tokens = prompt_length + max_new_tokens | ||
|
||
torch.manual_seed(123) | ||
|
||
y = generate( | ||
model, | ||
encoded, | ||
max_returned_tokens, | ||
temperature=0.5, | ||
top_k=200, | ||
top_p=1.0, | ||
eos_id=tokenizer.eos_id | ||
) | ||
|
||
for block in model.transformer.h: | ||
block.attn.kv_cache.reset_parameters() | ||
|
||
decoded_output = tokenizer.decode(y) | ||
print(decoded_output) | ||
|
||
|
||
if __name__ == "__main__": | ||
use_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. | ||
|
||
from pathlib import Path | ||
import torch | ||
from litgpt.generate.base import main | ||
from litgpt.utils import get_default_supported_precision | ||
|
||
|
||
def use_model(): | ||
|
||
# run `litgpt download --repo_id EleutherAI/pythia-1b` to download the checkpoint first | ||
checkpoint_dir = Path("checkpoints") / "EleutherAI" / "pythia-1b" | ||
|
||
torch.manual_seed(123) | ||
|
||
main( | ||
prompt="What food do llamas eat?", | ||
max_new_tokens=50, | ||
temperature=0.5, | ||
top_k=200, | ||
top_p=1.0, | ||
checkpoint_dir=checkpoint_dir, | ||
precision=get_default_supported_precision(training=False), | ||
compile=False | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
use_model() |