support for mixing multiple tokenizers/models
lbeurerkellner committed Sep 19, 2023
1 parent 65db2cf commit afc641d
Showing 21 changed files with 286 additions and 157 deletions.
6 changes: 3 additions & 3 deletions src/lmql/__init__.py
@@ -28,7 +28,7 @@
from lmql.runtime.output_writer import headless, printing, silent, stream
from lmql.runtime.interpreter import LMQLResult

from lmql.models.model import model, LMQLModel
from lmql.models.model import model, LMQLModelDescriptor
from lmql.runtime.loop import main
import lmql.runtime.decorators as decorators

@@ -44,9 +44,9 @@ def autoconnect():
def set_backend(backend):
model_registry.backend_configuration = backend

def set_default_model(model: Union[str, LMQLModel]):
def set_default_model(model: Union[str, LMQLModelDescriptor]):
"""
Sets the model instance to be used when no 'from' clause or @lmql.query(model=<model>) are used.
Sets the model instance to be used when no 'from' clause or @lmql.query(model=<model>) are specified.
This applies globally in the current process.
"""
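For illustration, a minimal usage sketch of the process-wide default: set_default_model accepts either a plain identifier string or a descriptor built with lmql.model(...). The identifiers below are placeholders taken from examples elsewhere in this commit.

import lmql

# either form works, since set_default_model takes Union[str, LMQLModelDescriptor]
lmql.set_default_model("openai/gpt-3.5-turbo-instruct")

# or pass a descriptor, e.g. an in-process `transformers` model
lmql.set_default_model(lmql.model("local:gpt2", cuda=True))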
9 changes: 6 additions & 3 deletions src/lmql/lib/actions.py
@@ -1,9 +1,11 @@
import lmql
import ast

from .prompts.wiki_prompt import EXAMPLES as WIKI_EXAMPLES
from .prompts.inline_use import INLINE_USE_PROMPT
from .prompts.inline_code import INLINE_CODE_PROMPT

import warnings

from dataclasses import dataclass

async def wiki(q: str, lookup = None):
@@ -95,6 +97,8 @@ async def fct_call(fcts):
print("unknown action", [action], list(action_fcts.keys()))
" Unknown action: {action}{DELIMITER_END}"
result = ""
if CALL.lstrip().startswith("<<"):
warnings.warn("Detected a rare failure case where LMQL Actions function calling does not work as intended. Please share your query with the LMQL developers, so we can fix this.", RuntimeWarning)
return "(error)"
else:
try:
@@ -113,8 +117,7 @@ async def inline_segment(fcts):
if not SEGMENT.endswith(DELIMITER):
return SEGMENT
else:
"[CALL]" where fct_call(CALL, fcts) and len(TOKENS(CALL)) > 0
result = CALL.split("|", 1)[1]
"[CALL: fct_call(fcts)]" where len(TOKENS(CALL)) > 0
return SEGMENT[:-len(DELIMITER)] + CALL + DELIMITER_END
'''

2 changes: 1 addition & 1 deletion src/lmql/lib/types.py
@@ -81,7 +81,7 @@ def extract_json(s):
@lmql.query
async def single_shot_as_type(s, ty, model="chatgpt"):
'''lmql
argmax(openai_chunksize=1024)
argmax(chunksize=1024)
schema_description = type_schema_description(ty)
"Provided a data schema of the following schema: {schema_description}\n"
"Translate the following into a JSON payload: {s}\n"
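For context, a standalone sketch of a query using the renamed decoder argument; the prompt, variable name and model alias are illustrative, and the decoder clause mirrors the one in the query above.

import lmql

@lmql.query
async def short_answer(question):
    '''lmql
    argmax(chunksize=1024)
        "Q: {question}\n"
        "A: [ANSWER]"
    from
        "chatgpt"
    '''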
6 changes: 5 additions & 1 deletion src/lmql/models/lmtp/backends/random_model.py
@@ -10,10 +10,14 @@ def __init__(self, seed=None, vocab=None, **kwargs):
self.seed = seed
self.kwargs = kwargs

if kwargs.get("verbose", False):
print("['random' model using seed {}]".format(seed))

if vocab is not None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(vocab)
print("['random' model using tokenizer {}]".format(tokenizer))
if kwargs.get("verbose", False):
print("['random' model using tokenizer {}]".format(tokenizer))
self._eos_token_id = tokenizer.eos_token_id
self._vocab_size = tokenizer.vocab_size
else:
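A small usage sketch of the new flag: constructing the 'random' backend with a seed and a vocabulary tokenizer now only prints the diagnostic messages above when verbose is set. This assumes the extra keyword arguments are forwarded to the backend constructor, as its **kwargs signature suggests.

import lmql

# seed/vocab mirror the constructor above; verbose=True enables the added prints
m = lmql.model("random", seed=123, vocab="gpt2", verbose=True)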
14 changes: 5 additions & 9 deletions src/lmql/models/lmtp/lmtp_dcmodel.py
@@ -12,8 +12,9 @@
import lmql.utils.nputil as nputil
import lmql.runtime.masks as masks
from lmql.runtime.token_distribution import TokenDistribution
from lmql.models.model import LMQLModel

from typing import Any, List, Union
from typing import Any, List, Union, Type
import random
import sys
import traceback
@@ -467,11 +468,11 @@ def __del__(self):
if "no current event loop" in str(e): pass
else: raise e

def __call__(self):
def __call__(self) -> LMQLModel:
# reference to factory instance
this = self

class LMTPDcModelCls:
class LMTPAdapterModel(LMQLModel):
def __init__(self) -> None:
self.model_identifier = this.model_identifier
self.served_model = None
@@ -491,8 +492,6 @@ def get_dclib_model(self):
bos_token_id = self.get_tokenizer().bos_token_id
eos_token_id = self.get_tokenizer().eos_token_id

dc.set_dclib_tokenizer(self.get_tokenizer())

inprocess_client_constructor = None

if this.inprocess:
Expand All @@ -512,8 +511,5 @@ async def detokenize(self, input_ids):

def sync_tokenize(self, text):
return self.get_tokenizer()(text)["input_ids"]

def report_metrics(self, metrics):
pass

return LMTPDcModelCls()
return LMTPAdapterModel()
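For readers unfamiliar with the `this = self` idiom used in `__call__` above, here is a reduced, hypothetical sketch of the pattern: the factory captures its configuration, and each call builds a fresh adapter object that (in the real code) subclasses LMQLModel. All names in the sketch are illustrative.

class AdapterFactory:
    def __init__(self, model_identifier):
        self.model_identifier = model_identifier

    def __call__(self):
        this = self  # captured by the inner class, like in the method above

        class Adapter:  # the real adapter subclasses LMQLModel
            def __init__(self):
                self.model_identifier = this.model_identifier

        return Adapter()

factory = AdapterFactory("my-model")
adapter = factory()  # every call yields a new adapter bound to the factory's config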
93 changes: 75 additions & 18 deletions src/lmql/models/model.py
@@ -1,12 +1,19 @@
from lmql.runtime.tokenizer import LMQLTokenizer
from lmql.runtime.dclib.dclib_model import DcModel
from typing import Any
from abc import ABC, abstractmethod
from lmql.models.lmtp.utils import rename_model_args

import warnings

class LMQLModel:
class LMQLModelDescriptor:
"""
Descriptor base class for results of lmql.model(...) calls.
The actual model runs either remotely (lmql serve-model), in-process via self.model or
via an API (e.g. OpenAI).
is accessed via API (e.g. OpenAI).
Use LMQLModelRegistry.get to resolve a descriptor into a usable LMQLModel.
"""

def __init__(self, model_identifier, model=None, **kwargs):
@@ -20,13 +27,57 @@ def __repr__(self) -> str:
return str(self)

def __str__(self):
return "<LMQLModel: {}>".format(self.model_identifier)
return "<LMQLModelDescriptor: {}>".format(self.model_identifier)

LMQLModel.inprocess_instances = {}
LMQLModelDescriptor.inprocess_instances = {}

class LMQLModel(ABC):
"""
Abstract base class for models (interface to integrate LMTP or OpenAI API)
"""

@abstractmethod
def get_tokenizer(self) -> LMQLTokenizer:
"""
Returns the tokenizer used by this model.
"""
raise NotImplementedError()

@abstractmethod
def get_dclib_model(self) -> DcModel:
"""
Returns the dclib DcModel handle to use this model.
This handle is used for token generation and decoding via
its methods like `argmax`, `sample` and `score`.
"""
raise NotImplementedError()

@abstractmethod
async def tokenize(self, text: str) -> Any:
"""
Tokenizes the given text and returns the tokenized input_ids in
the format expected by the model.
"""
raise NotImplementedError()

@abstractmethod
async def detokenize(self, input_ids: Any) -> str:
"""
Detokenizes the given input_ids and returns the detokenized text.
"""
raise NotImplementedError()

@abstractmethod
def sync_tokenize(self, text):
"""
Synchronous version of `tokenize`.
"""
raise NotImplementedError()
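As a rough illustration of this contract, a hypothetical minimal subclass is sketched below; the constructor arguments and the decode call on the tokenizer are assumptions, not part of this commit.

class MyBackendModel(LMQLModel):
    # hypothetical adapter: 'tokenizer' is an LMQLTokenizer, 'dc_model' a DcModel
    def __init__(self, tokenizer, dc_model):
        self._tokenizer = tokenizer
        self._dc_model = dc_model

    def get_tokenizer(self) -> LMQLTokenizer:
        return self._tokenizer

    def get_dclib_model(self) -> DcModel:
        return self._dc_model

    async def tokenize(self, text: str):
        return self.get_tokenizer()(text)["input_ids"]

    async def detokenize(self, input_ids):
        # assumes the tokenizer exposes a decode() method
        return self.get_tokenizer().decode(input_ids)

    def sync_tokenize(self, text):
        return self.get_tokenizer()(text)["input_ids"]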

def inprocess(model_name, use_existing_configuration=False, **kwargs):
"""
Loads a 'transformers' model in-process and returns an LMQLModel object
Loads a 'transformers' model in-process and returns an LMQLModelDescriptor object
to use this in-process model in LMQL.
This is useful when you don't want to spawn a separate 'lmql serve-model' process.
@@ -67,23 +118,29 @@ def inprocess(model_name, use_existing_configuration=False, **kwargs):
else:
cmdline_args += f"--{k} {v} "

if cmdline_args in LMQLModel.inprocess_instances.keys():
warnings.warn("info: reusing existing in-process model.")
model = LMQLModel.inprocess_instances[cmdline_args]
return LMQLModel(model_name, model=model)
if cmdline_args in LMQLModelDescriptor.inprocess_instances.keys():
model = LMQLModelDescriptor.inprocess_instances[cmdline_args]
return LMQLModelDescriptor(model_name, model=model)

if use_existing_configuration:
# find existing match for model_name only
for cmdargs, p in LMQLModel.inprocess_instances.items():
if cmdargs.split(" ")[0] == model_name:
return LMQLModel(model_name, model=p)

kwargs["inprocess"] = True
model = lmtp_model(model_name, **kwargs)
LMQLModel.inprocess_instances[cmdline_args] = model
return LMQLModel(model_name, model=model)
LMQLModelDescriptor.inprocess_instances[cmdline_args] = model
return LMQLModelDescriptor(model_name, model=model)

def model(model_identifier, **kwargs):
"""
Constructs an LMQL model descriptor object to be used in
a `from` clause or as `model=<MODEL>` argument to @lmql.query(...).
Examples:
lmql.model("openai/gpt-3.5-turbo-instruct") # OpenAI API model
lmql.model("random", seed=123) # model that samples tokens at random
lmql.model("llama.cpp:<YOUR_WEIGHTS>.bin") # llama.cpp model
lmql.model("local:gpt2") # load a `transformers` model in process
lmql.model("local:gpt2", cuda=True, load_in_4bit=True) # load a `transformers` model in process with additional arguments
"""
# handle inprocess models
is_inprocess = kwargs.pop("inprocess", False) or model_identifier.startswith("local:")
if is_inprocess and model_identifier.startswith("local:"):
@@ -92,4 +149,4 @@ def model(model_identifier, **kwargs):
if is_inprocess:
return inprocess(model_identifier, **kwargs)
else:
return LMQLModel(model_identifier, **kwargs)
return LMQLModelDescriptor(model_identifier, **kwargs)
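Putting the descriptor API together, a sketch of what this commit works towards: two queries backed by different models, and therefore different tokenizers, coexisting in one process. Model identifiers, prompts and constraints are illustrative.

import lmql

remote = lmql.model("openai/gpt-3.5-turbo-instruct")  # OpenAI API model
local = lmql.model("local:gpt2")                       # in-process `transformers` model

@lmql.query(model=remote)
async def summarize(text):
    '''lmql
    "Summarize in one sentence: {text}\n[SUMMARY]"
    '''

@lmql.query(model=local)
async def continue_text(text):
    '''lmql
    "{text}[CONTINUATION]" where len(TOKENS(CONTINUATION)) < 32
    '''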
40 changes: 22 additions & 18 deletions src/lmql/ops/token_set.py
@@ -12,6 +12,7 @@
from lmql.runtime.caching import cachefile
from lmql.runtime.tokenizer import get_vocab
from lmql.ops.regex import Regex
from lmql.runtime.context import get_tokenizer

class VocabularyMatcher:
"""
@@ -30,14 +31,15 @@ def __init__(self, tokenizer, model_identifier):

self.stats = Stats("VocabularyMatcher")
self.disk_cached = 0
self.cache = {}

@property
def eos_token_id(self):
return self.tokenizer.eos_token_id

@staticmethod
def init(tokenizer):
if VocabularyMatcher._instance is not None:
if tokenizer.name in VocabularyMatcher._instances:
return

# first try to load pickled matcher from cache (faster)
@@ -50,25 +52,28 @@ def init(tokenizer):

try:
with cachefile(matcher_path, "rb") as f:
VocabularyMatcher._instance = pickle.load(f)
VocabularyMatcher._instance.stats = Stats("VocabularyMatcher")
_instance = pickle.load(f)
_instance.stats = Stats("VocabularyMatcher")
except:
VocabularyMatcher._instance = VocabularyMatcher(tokenizer, tokenizer.model_identifier)
_instance = VocabularyMatcher(tokenizer, tokenizer.model_identifier)

try:
with cachefile(cache_path, "rb") as f:
try:
import time
s = time.time()
VocabularyMatcher.cache = pickle.load(f)
VocabularyMatcher._instance.disk_cached = len(VocabularyMatcher.cache)
_instance.cache = pickle.load(f)
_instance.disk_cached = len(_instance.cache)
except:
warnings.warn("Failed to load token mask cache from {}. If the cache is corrupted, please delete it.".format(cache_path))
except:
# no cache file
pass

atexit.register(lambda: VocabularyMatcher._instance.save())
# save in instance pool
VocabularyMatcher._instances[tokenizer.name] = _instance
# save on exit
atexit.register(lambda: _instance.save())

def save(self):
# save cache to disk
@@ -93,28 +98,28 @@ def is_cached(k):
return False

with cachefile(cache_path, "wb") as f:
pickle.dump({k: v for k, v in VocabularyMatcher.cache.items() if is_cached(k)}, f)
pickle.dump({k: v for k, v in self.cache.items() if is_cached(k)}, f)

@staticmethod
def instance():
if VocabularyMatcher._instance is None:
tokenizer = get_tokenizer()
if not tokenizer.name in VocabularyMatcher._instances:
raise Exception("VocabularyMatcher not initialized.")
return VocabularyMatcher._instance
return VocabularyMatcher._instances[tokenizer.name]

@staticmethod
def ensure_ready():
VocabularyMatcher.instance()

@staticmethod
def with_cache(keys, provider):
def with_cache(self, keys, provider):
keys = [k for k in keys if k is not None]
for k in keys:
if k in VocabularyMatcher.cache.keys():
return VocabularyMatcher.cache[k]
if k in self.cache.keys():
return self.cache[k]
else:
result = provider()
for k in keys:
VocabularyMatcher.cache[k] = result
self.cache[k] = result
return result

def mask_cache_name(self, tokens=None, regex=None, minus=None, prefix=None, exact=None, charlen=None, name=None):
@@ -145,7 +150,7 @@ def do_make_mask():

return mask

return VocabularyMatcher.with_cache(cache_keys, do_make_mask)
return self.with_cache(cache_keys, do_make_mask)

def _make_mask_from_regex(self, regex, prefix=False):
regex = regex.replace(" ", self.space_repr)
@@ -257,8 +262,7 @@ def tstr(t):
", ".join([t for t in sorted(list(tokens))]) + ("..." if truncated else "")
)

VocabularyMatcher._instance = None
VocabularyMatcher.cache = {}
VocabularyMatcher._instances = {}

def has_tail(mask):
if mask is None: return False
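The token_set.py changes above replace the single VocabularyMatcher singleton and its class-level cache with one matcher (and one cache) per tokenizer, which is what allows multiple tokenizers to coexist in a process. A reduced, hypothetical sketch of that pattern (only the attribute names visible in the diff are real):

class Matcher:
    _instances = {}  # tokenizer.name -> Matcher, replaces the former _instance singleton

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.cache = {}  # per-instance mask cache, formerly a class-level attribute

    @staticmethod
    def init(tokenizer):
        # one matcher per tokenizer, so several models can coexist in one process
        if tokenizer.name not in Matcher._instances:
            Matcher._instances[tokenizer.name] = Matcher(tokenizer)

    @staticmethod
    def instance(tokenizer):
        if tokenizer.name not in Matcher._instances:
            raise Exception("Matcher not initialized for this tokenizer.")
        return Matcher._instances[tokenizer.name]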
