Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add phone number and zip code custom types #849

Merged
merged 1 commit into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/reference/format.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Type constraints

We can ask completions to be restricted to valid python types:

```python
from outlines import models, generate

model = models.transformers("mistralai/Mistral-7B-v0.1")
generator = generate.format(model, int)
answer = generator("When I was 6 my sister was half my age. Now I’m 70 how old is my sister?")
print(answer)
# 67
```

The following types are currently available:

- int
- float
- bool
- datetime.date
- datetime.time
- datetime.datetime
- We also provide [custom types](types.md)
3 changes: 3 additions & 0 deletions docs/reference/json.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ print(result)

`generation.json` computes an index that helps Outlines guide generation. This can take some time, but only needs to be done once. If you want to generate several times with the same schema make sure that you only call `generate.json` once.

!!! Tip "Custom types"

Outlines provides [custom Pydantic types](types.md) so you do not have to write regular expressions for common types, such as phone numbers or zip codes.

## Using a JSON Schema

Expand Down
55 changes: 41 additions & 14 deletions docs/reference/types.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,49 @@
# Type constraints
# Custom types

We can ask completions to be restricted to valid python types:
Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions:

- Using `outlines.types.ZipCode` will generate valid US Zip(+4) codes.
- Using `outlines.types.PhoneNumber` will generate valid US phone numbers.

You can use these types in Pydantic schemas for JSON-structured generation:

```python
from pydantic import BaseModel

from outlines import models, generate, types


class Client(BaseModel):
name: str
phone_number: types.PhoneNumber
zip_code: types.ZipCode


model = models.transformers("mistralai/Mistral-7B-v0.1")
generator = generate.json(model, Client)
result = generator(
"Create a client profile with the fields name, phone_number and zip_code"
)
print(result)
# name='Tommy' phone_number='129-896-5501' zip_code='50766'
```

Or simply with `outlines.generate.format`:

```python
from outlines import models, generate
from pydantic import BaseModel

from outlines import models, generate, types


model = models.transformers("mistralai/Mistral-7B-v0.1")
generator = generate.format(model, int)
answer = generator("When I was 6 my sister was half my age. Now I’m 70 how old is my sister?")
print(answer)
# 67
generator = generate.format(model, types.PhoneNumber)
result = generator(
"Return a US Phone number: "
)
print(result)
# 334-253-2630
```

The following types are currently available:

- int
- float
- bool
- datetime.date
- datetime.time
- datetime.datetime
We plan on adding many more custom types. If you have found yourself writing regular expressions to generate fields of a given type, or if you could benefit from more specific types don't hesite to [submit a PR](https://github.com/outlines-dev/outlines/pulls) or [open an issue](https://github.com/outlines-dev/outlines/issues/new/choose).
3 changes: 2 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,14 @@ nav:
- Structured generation:
- Classification: reference/choices.md
- Regex: reference/regex.md
- Type constraints: reference/types.md
- Type constraints: reference/format.md
- JSON (function calling): reference/json.md
- JSON mode: reference/json_mode.md
- Grammar: reference/cfg.md
- Custom FSM operations: reference/custom_fsm_ops.md
- Utilities:
- Serve with vLLM: reference/serve/vllm.md
- Custom types: reference/types.md
- Prompt templating: reference/prompting.md
- Outlines functions: reference/functions.md
- Models:
Expand Down
1 change: 1 addition & 0 deletions outlines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import outlines.generate
import outlines.grammars
import outlines.models
import outlines.types
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
Expand Down
12 changes: 12 additions & 0 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import datetime
from typing import Protocol, Tuple, Type, Union

from typing_extensions import _AnnotatedAlias, get_args

INTEGER = r"[+-]?(0|[1-9][0-9]*)"
BOOLEAN = "(True|False)"
FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?"
Expand All @@ -17,6 +19,16 @@ def __call__(


def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
# If it is a custom type
if isinstance(python_type, _AnnotatedAlias):
json_schema = get_args(python_type)[1].json_schema
type_class = get_args(python_type)[0]

regex_str = json_schema["pattern"]
format_fn = lambda x: type_class(x)

return regex_str, format_fn

if python_type == float:

def float_format_fn(sequence: str) -> float:
Expand Down
2 changes: 2 additions & 0 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .phone_numbers import PhoneNumber
from .zip_codes import ZipCode
16 changes: 16 additions & 0 deletions outlines/types/phone_numbers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Phone number types.

We currently only support US phone numbers. We can however imagine having custom types
for each country, for instance leveraging the `phonenumbers` library.

"""
from pydantic import WithJsonSchema
from typing_extensions import Annotated

US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}"


PhoneNumber = Annotated[
str,
WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}),
]
13 changes: 13 additions & 0 deletions outlines/types/zip_codes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Zip code types.

We currently only support US Zip Codes.

"""
from pydantic import WithJsonSchema
from typing_extensions import Annotated

# This matches Zip and Zip+4 codes
US_ZIP_CODE = r"\d{5}(?:-\d{4})?"


ZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})]
34 changes: 34 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import re

import pytest
from pydantic import BaseModel

from outlines import types
from outlines.fsm.types import python_types_to_regex


@pytest.mark.parametrize(
"custom_type,test_string,should_match",
[
(types.PhoneNumber, "12", False),
(types.PhoneNumber, "(123) 123-1234", True),
(types.PhoneNumber, "123-123-1234", True),
(types.ZipCode, "12", False),
(types.ZipCode, "12345", True),
(types.ZipCode, "12345-1234", True),
],
)
def test_phone_number(custom_type, test_string, should_match):
class Model(BaseModel):
attr: custom_type

schema = Model.model_json_schema()
assert schema["properties"]["attr"]["type"] == "string"
regex_str = schema["properties"]["attr"]["pattern"]
does_match = re.match(regex_str, test_string) is not None
assert does_match is should_match

regex_str, format_fn = python_types_to_regex(custom_type)
assert isinstance(format_fn(1), str)
does_match = re.match(regex_str, test_string) is not None
assert does_match is should_match
Loading