Skip to content

Commit

Permalink
make to_regex() work
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikKaum committed Aug 26, 2024
1 parent ad1c225 commit f028300
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 47 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ crate-type = ["cdylib"]
anyhow = "1.0.86"
pyo3 = { version = "0.22.0", features = ["extension-module"] }
regex = "1.10.6"
serde-pyobject = "0.4.0"
serde_json = { version ="1.0.125", features = ["preserve_order"] }

[profile.release]
Expand Down
1 change: 1 addition & 0 deletions python/outlines_core/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
UUID,
WHITESPACE,
build_regex_from_schema,
to_regex,
)


Expand Down
11 changes: 11 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ mod regex;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;
use regex::_walk_fsm;
use regex::create_fsm_index_end_to_end;
use regex::get_token_transition_keys;
use regex::get_vocabulary_transition_keys;
use regex::state_scan_tokens;
use regex::FSMInfo;
use serde_json::Value;

#[pymodule]
fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand All @@ -34,6 +36,7 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add("WHITESPACE", json_schema::WHITESPACE)?;

m.add_function(wrap_pyfunction!(build_regex_from_schema, m)?)?;
m.add_function(wrap_pyfunction!(to_regex, m)?)?;

Ok(())
}
Expand All @@ -44,3 +47,11 @@ pub fn build_regex_from_schema(json: String, whitespace_pattern: Option<&str>) -
json_schema::build_regex_from_schema(&json, whitespace_pattern)
.map_err(|e| PyValueError::new_err(e.to_string()))
}

/// Builds a regular expression from a JSON schema supplied as a Python `dict`.
///
/// The dict is first converted into a [`serde_json::Value`] and then handed to
/// `json_schema::to_regex`. The converted value is passed both as the schema
/// to translate and as the third argument (presumably the root document used
/// while walking nested schemas — confirm against `json_schema::to_regex`).
///
/// # Arguments
///
/// * `json` - The schema as a Python dictionary.
/// * `whitespace_pattern` - Optional pattern forwarded unchanged to
///   `json_schema::to_regex`; defaults to `None` on the Python side.
///
/// # Errors
///
/// Returns a `PyValueError` when the dictionary cannot be converted into a
/// JSON value, or when `json_schema::to_regex` reports an error.
#[pyfunction(name = "to_regex")]
#[pyo3(signature = (json, whitespace_pattern=None))]
pub fn to_regex(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyResult<String> {
    // The dict comes straight from Python callers, so a failed conversion is
    // a recoverable input error: report it as ValueError instead of panicking
    // (a panic would reach Python as `PanicException`, which `except Exception`
    // does not catch).
    let json_value: Value = serde_pyobject::from_pyobject(json)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    json_schema::to_regex(&json_value, whitespace_pattern, &json_value)
        .map_err(|e| PyValueError::new_err(e.to_string()))
}
95 changes: 48 additions & 47 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import interegular
import pytest
from outlines_core.fsm.json_schema import ( # to_regex,
from outlines_core.fsm.json_schema import (
BOOLEAN,
DATE,
DATE_TIME,
Expand All @@ -18,6 +18,7 @@
WHITESPACE,
build_regex_from_schema,
get_schema_from_signature,
to_regex,
)
from pydantic import BaseModel, Field, constr

Expand Down Expand Up @@ -56,56 +57,56 @@ class User(BaseModel):
assert isinstance(schedule, str)


# @pytest.mark.parametrize(
# "pattern,does_match",
# [
# ({"integer": "0"}, True),
# ({"integer": "1"}, True),
# ({"integer": "-1"}, True),
# ({"integer": "01"}, False),
# ({"integer": "1.3"}, False),
# ({"integer": "t"}, False),
# ],
# )
# def test_match_integer(pattern, does_match):
# step = {"title": "Foo", "type": "integer"}
# regex = to_regex(None, step)
# assert regex == INTEGER
@pytest.mark.parametrize(
    "pattern,does_match",
    [
        ({"integer": "0"}, True),
        ({"integer": "1"}, True),
        ({"integer": "-1"}, True),
        ({"integer": "01"}, False),
        ({"integer": "1.3"}, False),
        ({"integer": "t"}, False),
    ],
)
def test_match_integer(pattern, does_match):
    """Check the regex produced for an integer schema against sample strings.

    The regex returned by ``to_regex`` for an ``integer``-typed schema must
    equal the ``INTEGER`` constant, and must fully match the candidate string
    exactly when ``does_match`` is true.
    """
    schema = {"title": "Foo", "type": "integer"}
    generated = to_regex(schema)
    assert generated == INTEGER

    candidate = pattern["integer"]
    found = re.fullmatch(generated, candidate)
    if not does_match:
        # Rejected inputs must produce no match at all.
        assert found is None
        return
    # Accepted inputs must be consumed in their entirety.
    assert found[0] == candidate
    assert found.span() == (0, len(candidate))


# @pytest.mark.parametrize(
# "pattern,does_match",
# [
# ({"number": "1"}, True),
# ({"number": "0"}, True),
# ({"number": "01"}, False),
# ({"number": ".3"}, False),
# ({"number": "1.3"}, True),
# ({"number": "-1.3"}, True),
# ({"number": "1.3e9"}, False),
# ({"number": "1.3e+9"}, True),
# ],
# )
# def test_match_number(pattern, does_match):
# step = {"title": "Foo", "type": "number"}
# regex = to_regex(None, step)
# assert regex == NUMBER
@pytest.mark.parametrize(
    "pattern,does_match",
    [
        ({"number": "1"}, True),
        ({"number": "0"}, True),
        ({"number": "01"}, False),
        ({"number": ".3"}, False),
        ({"number": "1.3"}, True),
        ({"number": "-1.3"}, True),
        ({"number": "1.3e9"}, False),
        ({"number": "1.3e+9"}, True),
    ],
)
def test_match_number(pattern, does_match):
    """Check the regex produced for a number schema against sample strings.

    The regex returned by ``to_regex`` for a ``number``-typed schema must
    equal the ``NUMBER`` constant, and must fully match the candidate string
    exactly when ``does_match`` is true.
    """
    schema = {"title": "Foo", "type": "number"}
    generated = to_regex(schema)
    assert generated == NUMBER

    candidate = pattern["number"]
    found = re.fullmatch(generated, candidate)
    if not does_match:
        # Rejected inputs must produce no match at all.
        assert found is None
        return
    # Accepted inputs must be consumed in their entirety.
    assert found[0] == candidate
    assert found.span() == (0, len(candidate))


@pytest.mark.parametrize(
Expand Down

0 comments on commit f028300

Please sign in to comment.