Grammar: Initial Formatron regex and JSON schema implementation
* Replace LMFE's regex and JSON schema filters with Formatron's
* Remove Outlines EBNF filter in preparation for Formatron KBNF filter
* TODO: Implement Formatron KBNF filter
DocShotgun committed Nov 23, 2024
1 parent aa4ccd0 commit 0836a93
Showing 3 changed files with 39 additions and 122 deletions.
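
For context, the construction pattern this commit switches to works roughly as in the sketch below. This is an illustrative, hedged example rather than part of the diff: model and tokenizer stand in for an already-loaded ExLlamaV2 model/tokenizer pair, and the example pattern is arbitrary.

    # Sketch of the Formatron-based filter construction adopted in this commit.
    # Assumes `model` and `tokenizer` are an already-loaded ExLlamaV2 model/tokenizer.
    from formatron.formatter import FormatterBuilder
    from formatron.integrations.exllamav2 import create_formatter_filter

    pattern = r"(yes|no)"  # arbitrary example regex

    f = FormatterBuilder()
    f.append_line(f"{f.regex(pattern)}")          # constrain generation to the pattern
    lmfilter = create_formatter_filter(model, tokenizer, f)  # yields an ExLlamaV2Filter
    # The resulting filter is then passed to the ExLlamaV2 generator's filter list.
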
156 changes: 37 additions & 119 deletions backends/exllamav2/grammar.py
@@ -1,110 +1,20 @@
import traceback
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import (
JsonSchemaParser,
RegexParser,
TokenEnforcer,
CharacterLevelParser,
)
from lmformatenforcer.integrations.exllamav2 import (
build_token_enforcer_tokenizer_data,
)
from exllamav2.generator.filters import ExLlamaV2Filter
from loguru import logger
from typing import List
from functools import lru_cache


class OutlinesTokenizerWrapper:
"""Wrapper for Outlines tokenizer"""

def __init__(self, tokenizer):
self.tokenizer = tokenizer
id_to_piece = self.tokenizer.get_id_to_piece_list()
self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())

def convert_token_to_string(self, token):
return token

def decode(self, tokens):
s = ""
id_to_piece = self.tokenizer.get_id_to_piece_list()
for t in tokens:
s += id_to_piece[t]
return s


class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
"""Filter class for context-free grammar via outlines"""

def __init__(self, model, tokenizer, grammar):
from outlines.fsm.fsm import CFGFSM

super().__init__(model, tokenizer)

self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
self.state = self.fsm.first_state

def begin(self, prefix_str=""):
self.state = self.fsm.first_state

def feed(self, token):
self.state = self.fsm.next_state(self.state, token.item())

def next(self):
return self.fsm.allowed_token_ids(self.state), set()

def use_background_worker(self):
return True


@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
return build_token_enforcer_tokenizer_data(tokenizer)


class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
"""Filter class for LMFE"""

token_sequence: List[int]

def __init__(
self,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
character_level_parser: CharacterLevelParser,
):
super().__init__(model, tokenizer)
tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
self.token_sequence = []

def begin(self, prefix_str: str):
self.token_sequence = []

def feed(self, token):
self.token_sequence.append(int(token[0][0]))

def next(self):
allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
if not hasattr(self, "allow_return_type_list"):
return set(allowed_tokens), set()
else:
return sorted(allowed_tokens), []

def use_background_worker(self):
return True
from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema
from formatron.integrations.exllamav2 import create_formatter_filter


def clear_grammar_func_cache():
"""Flush tokenizer_data cache to avoid holding references to
tokenizers after unloading a model"""

_get_lmfe_tokenizer_data.cache_clear()
# TODO: Unsure if this is needed with formatron
pass


class ExLlamaV2Grammar:
@@ -117,15 +27,24 @@ def __init__(self):

def add_json_schema_filter(
self,
json_schema: dict,
schema: dict,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
):
"""Adds an ExllamaV2 filter based on a JSON schema."""

# Create the parser
try:
schema_parser = JsonSchemaParser(json_schema)
# Add fields required by formatron if not present
if "$id" not in schema:
schema["$id"] = "https://example.com/example.json"
if "$schema" not in schema:
schema["$schema"] = "http://json-schema.org/draft-07/schema#"

# Validate schema and create formatter
schema = json_schema.create_schema(schema)
f = FormatterBuilder()
f.append_line(f"{f.json(schema)}")
except Exception:
traceback.print_exc()
logger.error(
@@ -135,14 +54,10 @@ def add_json_schema_filter(

return

# Allow JSON objects or JSON arrays at the top level
json_prefixes = ["[", "{"]

lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
lmfilter = create_formatter_filter(model, tokenizer, f)

# Append the filters
self.filters.extend([lmfilter, prefix_filter])
self.filters.append(lmfilter)

def add_regex_filter(
self,
@@ -154,7 +69,9 @@ def add_regex_filter(

# Create the parser
try:
pattern_parser = RegexParser(pattern)
# Validate regex and create formatter
f = FormatterBuilder()
f.append_line(f"{f.regex(pattern)}")
except Exception:
traceback.print_exc()
logger.error(
@@ -164,32 +81,33 @@ def add_regex_filter(

return

lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
lmfilter = create_formatter_filter(model, tokenizer, f)

# Append the filters
self.filters.append(lmfilter)

def add_ebnf_filter(
def add_kbnf_filter(
self,
ebnf_string: str,
kbnf_string: str,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
):
"""
Add an EBNF grammar filter.
Possibly replace outlines with an in-house solution in the future.
"""
"""Adds an ExllamaV2 filter based on KBNF grammar."""

# Create the parser
try:
ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
except ImportError:
# Validate KBNF and create formatter
f = FormatterBuilder()
# TODO: Implement this
except Exception:
logger.error(
"Skipping EBNF parsing because Outlines is not installed.\n"
"Please run the following command in your environment "
"to install extra packages:\n"
"pip install -U .[extras]"
"Skipping because the KBNF string couldn't be parsed. "
"Please read the above error for more information."
)

return

self.filters.append(ebnf_filter)
lmfilter = create_formatter_filter(model, tokenizer, f)

# Append the filters
self.filters.append(lmfilter)
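
Taken together, the rewritten handler is driven roughly as in the hedged sketch below. This is illustrative only and not part of the diff: the import path simply mirrors the file location, the example schema and pattern are made up, the add_regex_filter argument order is assumed to match the JSON case, and model/tokenizer are assumed to be an already-loaded ExLlamaV2 pair.

    # Hypothetical usage of the rewritten grammar handler (illustrative only).
    from backends.exllamav2.grammar import ExLlamaV2Grammar

    grammar_handler = ExLlamaV2Grammar()

    # JSON-schema-constrained output; "$id"/"$schema" are filled in by the
    # handler when missing before Formatron validates the schema.
    example_schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }
    grammar_handler.add_json_schema_filter(example_schema, model, tokenizer)

    # Regex-constrained output.
    grammar_handler.add_regex_filter(r"(yes|no)", model, tokenizer)

    # grammar_handler.filters now holds Formatron-backed ExLlamaV2Filter objects
    # that are passed on to the ExLlamaV2 generator.
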
2 changes: 1 addition & 1 deletion backends/exllamav2/model.py
@@ -1194,7 +1194,7 @@ async def generate_gen(
# Add EBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
if grammar_string:
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)

# Set banned strings
banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
"sse-starlette",
"packaging",
"tokenizers",
"lm-format-enforcer >= 0.9.6",
"formatron",
"aiofiles",
"aiohttp",
"async_lru",
@@ -53,7 +53,6 @@
[project.optional-dependencies]
extras = [
# Heavy dependencies that aren't for everyday use
"outlines",
"infinity-emb",
"sentence-transformers",
]
