Pipeline example (#43)
* Added HF pipeline example

* Pipeline example added and docs updated
arinaruck authored Jun 9, 2024
1 parent 308066c commit 8f8717a
Showing 2 changed files with 114 additions and 0 deletions.
40 changes: 40 additions & 0 deletions README.md
@@ -92,8 +92,48 @@ if __name__ == "__main__":
'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}
'This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }
"""
```

Alternatively, you can use `transformers-cfg` to perform grammar-constrained decoding with the Hugging Face `pipeline` API.

<details>
<summary>Click here to see an example, or check it out in `examples/pipeline_json.py` </summary>

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

model_id = "mistralai/Mistral-7B-v0.1"  # any causal LM on the Hub works here
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Load the model onto the chosen device
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Load the JSON grammar and build a grammar-constrained logits processor
with open("examples/grammars/json.ebnf", "r") as file:
    grammar_str = file.read()

grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

# Initialize the text-generation pipeline; the model is already on
# `device`, so no extra device placement argument is needed
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=50,
    batch_size=2,
)

generations = pipe(
    [
        "This is a valid json string for http request: ",
        "This is a valid json string for shopping cart: ",
    ],
    do_sample=False,
    logits_processor=[grammar_processor],
)
```
</details>
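
Because decoding is constrained by the JSON grammar, each continuation can be parsed directly. Below is a minimal post-processing sketch for the `generations` result above, assuming generation finishes within `max_length` so the JSON is well formed (by default, the pipeline echoes the prompt inside `generated_text`):

```python
import json

prompts = [
    "This is a valid json string for http request: ",
    "This is a valid json string for shopping cart: ",
]

# For batched input, the pipeline returns one list of candidates per prompt;
# each candidate is a dict with a "generated_text" key
for prompt, candidates in zip(prompts, generations):
    text = candidates[0]["generated_text"]
    payload = text[len(prompt):]  # drop the echoed prompt
    print(json.loads(payload))    # raises json.JSONDecodeError if truncated
```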


## 💡Why should I use transformers-CFG?

74 changes: 74 additions & 0 deletions examples/pipeline_json.py
@@ -0,0 +1,74 @@
import torch
import argparse
from transformers import pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor


def parse_args():
    parser = argparse.ArgumentParser(
        description="Generate JSON strings with the Hugging Face pipeline"
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default="/dlabdata1/llm_hub/Mistral-7B-v0.1",
        help="Model ID or path of the model to load",
    )
    parser.add_argument("--device", type=str, help="Device to put the model on")
    return parser.parse_args()


def main():
    args = parse_args()
    model_id = args.model_id

    # Use the requested device if given, otherwise prefer GPU over CPU
    device = torch.device(
        args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    )
    print(f"Using device: {device}")

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    # Load the model onto the chosen device
    model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

    # Load the JSON grammar and build a grammar-constrained logits processor
    with open("examples/grammars/json.ebnf", "r") as file:
        grammar_str = file.read()

    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Initialize the text-generation pipeline; the model is already on
    # `device`, so no extra device placement argument is needed
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=50,
        batch_size=2,
    )

    generations = pipe(
        [
            "This is a valid json string for http request: ",
            "This is a valid json string for shopping cart: ",
        ],
        do_sample=False,
        logits_processor=[grammar_processor],
    )

    print(generations)

"""
This is a valid json string for http request: {"name":"John","age":30,"city":"New York"}
This is a valid json string for shopping cart: {"items":[{"id":"1","quantity":"1"},{"id":"2","quantity":"2"}]}
"""


if __name__ == "__main__":
    main()
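
A hypothetical invocation, assuming a model you have access to (the default `--model-id` points to a cluster-specific path, so you will likely want to override it): `python examples/pipeline_json.py --model-id mistralai/Mistral-7B-v0.1 --device cuda`.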
