Pickle support exploration

torymur committed Nov 21, 2024
1 parent 9ddf5e7 commit cfa3070
Showing 6 changed files with 197 additions and 28 deletions.
6 changes: 6 additions & 0 deletions python/outlines_core/fsm/outlines_core_rs.pyi
@@ -103,3 +103,9 @@ class Index:
def get_initial_state(self) -> int:
"""Returns the ID of the initial state of the input FSM automata."""
...
def __setstate__(self, state: bytes) -> None:
"""Restores the Index from the JSON bytes produced by __getstate__."""
...
def __getstate__(self) -> bytes:
"""Serializes the Index to JSON bytes for pickling."""
...
def __getnewargs__(self) -> Tuple["FSMInfo", "Vocabulary", int, Set[str]]:
"""Returns placeholder constructor arguments used by pickle before __setstate__."""
...
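
Taken together, these stubs expose the standard pickle protocol on Index. A minimal usage sketch (not part of this commit), assuming an Index built as in the test at the end of this diff:

    import pickle
    from outlines_core.fsm.outlines_core_rs import Index

    def roundtrip(index: Index) -> Index:
        # pickle drives __getnewargs__, __getstate__ and __setstate__ for us
        return pickle.loads(pickle.dumps(index))
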
5 changes: 3 additions & 2 deletions src/index.rs
@@ -3,9 +3,10 @@ use crate::prelude::{State, TransitionKey};
use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

#[derive(Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct FSMInfo {
pub(crate) initial: State,
pub(crate) finals: HashSet<State>,
@@ -32,7 +33,7 @@ impl FSMInfo {
}
}

#[derive(Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Index {
initial: u32,
finals: HashSet<u32>,
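
These derives let FSMInfo and Index round-trip through serde_json, which is exactly the payload the Python-side __getstate__/__setstate__ below exchange. A sketch (not part of this commit), assuming an index built as in the test at the end of this diff:

    import json

    state = index.__getstate__()  # bytes containing serde_json output
    print(json.loads(state))      # the pickled payload is plain JSON
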
170 changes: 145 additions & 25 deletions src/python_bindings/mod.rs
@@ -5,14 +5,16 @@ use crate::regex::get_token_transition_keys;
use crate::regex::get_vocabulary_transition_keys;
use crate::regex::state_scan_tokens;
use crate::regex::walk_fsm;
use pyo3::exceptions::PyValueError;
use pyo3::exceptions::{PyException, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;
use pyo3::types::{PyBytes, PyDict};
use pyo3::{wrap_pyfunction, Python};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};

#[pyclass(name = "FSMInfo")]
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
#[pyclass(module = "outlines_core.fsm.outlines_core_rs", name = "FSMInfo")]
pub struct PyFSMInfo {
#[pyo3(get)]
initial: State,
@@ -39,8 +41,8 @@ impl From<FSMInfo> for PyFSMInfo {
}

// FIXME: could be costly, confirm if FSMInfo will actually be part of the interface
impl From<&PyFSMInfo> for FSMInfo {
fn from(fsm_info: &PyFSMInfo) -> Self {
impl From<PyFSMInfo> for FSMInfo {
fn from(fsm_info: PyFSMInfo) -> Self {
FSMInfo {
initial: fsm_info.initial,
finals: fsm_info.finals.clone(),
@@ -70,43 +72,132 @@ impl PyFSMInfo {
)
.into()
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self)
.map_err(|e| PyException::new_err(format!("Failed to pickle FSMInfo: {}", e)))?;
Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&[u8]>(py) {
Ok(s) => {
*self = serde_json::from_slice(s).map_err(|e| {
PyException::new_err(format!("Failed to unpickle FSMInfo: {}", e))
})?;
Ok(())
}
Err(e) => Err(e),
}
}

fn __getnewargs__(
&self,
) -> PyResult<(
State,
HashSet<State>,
HashMap<(State, TransitionKey), State>,
TransitionKey,
HashMap<String, TransitionKey>,
)> {
Ok((
self.initial,
self.finals.clone(),
self.transitions.clone(),
self.alphabet_anything_value,
self.alphabet_symbol_mapping.clone(),
))
}
}

#[pyclass(name = "Index")]
pub struct PyIndex(Index);
#[derive(Serialize, Deserialize, Debug, Clone)]
#[pyclass(module = "outlines_core.fsm.outlines_core_rs", name = "Index")]
pub struct PyIndex {
#[pyo3(get)]
fsm_info: PyFSMInfo,
#[pyo3(get)]
vocabulary: PyVocabulary,
#[pyo3(get)]
frozen_tokens: HashSet<String>,
eos_token_id: u32,
inner: Option<Index>,
}

#[pymethods]
impl PyIndex {
#[new]
fn new(
fsm_info: &PyFSMInfo,
vocabulary: &PyVocabulary,
fsm_info: PyFSMInfo,
vocabulary: PyVocabulary,
eos_token_id: u32,
frozen_tokens: HashSet<String>,
) -> PyResult<Self> {
Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
.map(PyIndex)
.map_err(Into::into)
) -> Self {
Self {
fsm_info,
vocabulary,
eos_token_id,
frozen_tokens,
inner: None,
}
}

fn get_allowed_tokens(&self, state: u32) -> Option<Vec<u32>> {
self.0.allowed_tokens(state)
fn build(&mut self) -> PyResult<()> {
let fsm_info: FSMInfo = self.fsm_info.clone().into();
let index = Index::new(
&fsm_info, &self.vocabulary.0, self.eos_token_id, self.frozen_tokens.clone()
);
self.inner = Some(index?);
Ok(())
}

fn get_next_state(&self, state: u32, token_id: u32) -> Option<u32> {
self.0.next_state(state, token_id)
fn get_allowed_tokens(&self, state: u32) -> Option<Vec<u32>> {
match &self.inner {
Some(i) => i.allowed_tokens(state),
None => None,
}
}

fn is_final_state(&self, state: u32) -> bool {
self.0.is_final(state)
// fn get_next_state(&self, state: u32, token_id: u32) -> Option<u32> {
// self.0.next_state(state, token_id)
// }

// fn is_final_state(&self, state: u32) -> bool {
// self.0.is_final(state)
// }

// fn get_transitions(&self) -> HashMap<u32, HashMap<u32, u32>> {
// self.0.transitions().clone()
// }

// fn get_initial_state(&self) -> u32 {
// self.0.initial()
// }

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self)
.map_err(|e| PyException::new_err(format!("Failed to pickle Index: {}", e)))?;
Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
}

fn get_transitions(&self) -> HashMap<u32, HashMap<u32, u32>> {
self.0.transitions().clone()
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&[u8]>(py) {
Ok(s) => {
*self = serde_json::from_slice(s).map_err(|e| {
PyException::new_err(format!("Failed to unpickle Index: {}", e))
})?;
Ok(())
}
Err(e) => Err(e),
}
}

fn get_initial_state(&self) -> u32 {
self.0.initial()
fn __getnewargs__(&self) -> PyResult<(PyFSMInfo, PyVocabulary, u32, HashSet<String>)> {
Ok((
PyFSMInfo::default(),
PyVocabulary::default(),
0,
HashSet::default(),
))
}
}
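
Note that construction is now two-step: new() only stores its inputs, and build() creates the inner Rust Index, which lets pickle's __getnewargs__ produce an empty shell before __setstate__ fills it in. A usage sketch, where fsm_info, vocabulary, and initial_state are hypothetical values prepared as in the test at the end of this diff:

    index = Index(fsm_info, vocabulary, 4, frozenset())
    index.build()  # constructs the inner automaton
    tokens = index.get_allowed_tokens(initial_state)  # None if build() was skipped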

@@ -256,7 +347,8 @@ pub fn create_fsm_index_end_to_end_py<'py>(
Ok(states_to_token_subsets)
}

#[pyclass(name = "Vocabulary")]
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
#[pyclass(module = "outlines_core.fsm.outlines_core_rs", name = "Vocabulary")]
pub struct PyVocabulary(Vocabulary);

#[pymethods]
@@ -266,13 +358,41 @@ impl PyVocabulary {
PyVocabulary(Vocabulary::from(map))
}

#[new]
#[pyo3(signature = (eos_token_id=None))]
fn new(eos_token_id: Option<u32>) -> Self {
PyVocabulary(Vocabulary::new(eos_token_id))
}

fn __repr__(&self) -> String {
format!("{:#?}", self.0)
}

fn __str__(&self) -> String {
format!("{}", self.0)
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self)
.map_err(|e| PyException::new_err(format!("Failed to pickle Vocabulary: {}", e)))?;
Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&[u8]>(py) {
Ok(s) => {
*self = serde_json::from_slice(s).map_err(|e| {
PyException::new_err(format!("Failed to unpickle Vocabulary: {}", e))
})?;
Ok(())
}
Err(e) => Err(e),
}
}

fn __getnewargs__(&self) -> PyResult<(Option<u32>,)> {
Ok((None,))
}
}
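
The same JSON-bytes pattern backs Vocabulary. A sketch of the round-trip pickle performs (not part of this commit), assuming the compiled module is importable:

    import pickle
    from outlines_core.fsm.outlines_core_rs import Vocabulary

    vocab = Vocabulary(4)          # eos_token_id=4
    data = pickle.dumps(vocab)     # __getstate__ -> serde_json bytes
    restored = pickle.loads(data)  # Vocabulary(None), then __setstate__(data)
    assert isinstance(restored, Vocabulary)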

#[pymodule]
4 changes: 3 additions & 1 deletion src/vocabulary/mod.rs
@@ -1,5 +1,6 @@
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use tokenizers::normalizers::Sequence;
use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};

@@ -25,7 +26,8 @@ mod processor
/// .insert("2", 2)
/// .insert("0", 3);
/// ```
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Vocabulary {
// TODO: Option is temp for back compatibility
eos_token_id: Option<TokenId>,
Empty file added tests/integration/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions tests/integration/test_pickle.py
@@ -0,0 +1,40 @@
def test_pickle_support():
import pickle
import interegular
from outlines_core.fsm.outlines_core_rs import Index, Vocabulary

from outlines_core.fsm.regex import (
make_deterministic_fsm,
reduced_vocabulary,
)

class MockTokenizer:
vocabulary = {"a": 1, "b": 2, "z": 3, "eos": 4}
special_tokens = {"eos"}
eos_token_id = 4

def convert_token_to_string(self, token):
return token

tokenizer = MockTokenizer()

pattern = r"z[ab]z"
regex_pattern = interegular.parse_pattern(pattern)
interegular_fsm = regex_pattern.to_fsm().reduce()
regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
tokens_to_token_ids, _ = reduced_vocabulary(tokenizer)

vocabulary = Vocabulary.from_dict(tokens_to_token_ids)
fsm_info = regex_fsm.fsm_info

index = Index(fsm_info, vocabulary, 4, frozenset())

pickled = pickle.dumps(index)
restored = pickle.loads(pickled)

assert isinstance(restored, Index)
# Field-level checks are pending accessors on the rebuilt Index wrapper:
# assert index.initial == restored.initial
# assert index.finals == restored.finals
# assert index.states_to_token_subsets == restored.states_to_token_subsets
# assert index.eos_token_id == restored.eos_token_id
