wip blob tokens
lbeurerkellner committed Nov 16, 2023
1 parent 797a95d commit 8145eca
Showing 7 changed files with 170 additions and 111 deletions.
1 change: 1 addition & 0 deletions src/lmql/api/__init__.py
@@ -13,6 +13,7 @@
from .scoring import ScoringResult
from .serve import serve
from inspect import *
from .blobs import Blob

from lmql.runtime.tokenizer import tokenizer
from lmql.runtime.loop import run_in_loop
36 changes: 36 additions & 0 deletions src/lmql/api/blobs.py
@@ -0,0 +1,36 @@
from weakref import WeakValueDictionary
import re

blob_store = WeakValueDictionary()

class Blob:
    """
    Represents multi-modal data as part of an LMQL prompt.
    """
    def __init__(self, data):
        self.data = data
        self.id = str(hash(data))

        blob_store[self.id] = self

    @staticmethod
    def resolve(id):
        return blob_store.get(id)

    def __str__(self):
        return f"<Blob {[str(self.data)]}>"

    def __repr__(self):
        return str(self)

    @staticmethod
    def decode(text):
        # split text on <lmql:media .../> tags and replace each match
        # with the Blob instance resolved from the tag's id attribute
        pattern = r"<lmql:media([^>]*)>"
        parts = re.split(pattern, text)
        for i in range(1, len(parts), 2):
            id = parts[i].split("id='")[1].split("'")[0]
            parts[i] = Blob.resolve(id)
        return parts
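
For orientation, a minimal usage sketch of the blob store. The prompt string and data value below are made up for illustration; the exact media-tag text is produced by format()/tag() in lmql/runtime/formatting.py.

from lmql.api.blobs import Blob

# register a multi-modal payload; the id is derived from hash(data),
# so the data must be hashable
blob = Blob("image-bytes-or-handle")

# a prompt embedding the blob as a media tag, as format() would emit it
prompt = f"Describe this image: <lmql:media id='{blob.id}'/>"

# decode() splits the prompt on media tags and substitutes the resolved
# Blob instances for the matched tag attributes
parts = Blob.decode(prompt)
assert parts[1] is blob
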
2 changes: 1 addition & 1 deletion src/lmql/api/llm.py
@@ -135,7 +135,7 @@ async def score(self, prompt: str, values: Union[str, List[str]], **kwargs) -> S
try:
dcmodel = self.adapter.get_dclib_model()
with traced(str(self) + ".score"):
with dc.Context(self.adapter.get_tokenizer(), dcmodel.truncation_threshold):
with dc.Context(None, self.adapter.get_tokenizer(), dcmodel.truncation_threshold):
return await dc_score(dcmodel, prompt, values, **kwargs)
finally:
dcmodel.close()
3 changes: 2 additions & 1 deletion src/lmql/runtime/bopenai/openai_api.py
@@ -14,6 +14,7 @@
import asyncio

from lmql.runtime.stats import Stats
from lmql.api.blobs import Blob
from lmql.runtime.tracing import Tracer
from lmql.models.model_info import model_info

@@ -397,7 +398,7 @@ async def completion_api(**kwargs):
num_prompts = len(kwargs["prompt"])
timeout = kwargs.pop("timeout", 1.5)
tracer = kwargs.pop("tracer", None)

max_tokens = kwargs.get("max_tokens")
# if no token limit is set, use 1024 as a generous chunk size
# (completion models require max_tokens to be set)
12 changes: 10 additions & 2 deletions src/lmql/runtime/context.py
@@ -49,7 +49,8 @@ def pop_context():
_context.set(_context.get()[:-1])

class Context:
def __init__(self, tokenizer, truncation_threshold=-3e38):
def __init__(self, interpreter, tokenizer=None, truncation_threshold=-3e38):
self.interpreter = interpreter
self.tokenizer = tokenizer
self.truncation_threshold = truncation_threshold

@@ -59,4 +60,11 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
pop_context()
return False
return False

@classmethod
def get(cls):
ctx = _context.get()
if len(ctx) == 0:
return None
return ctx[-1]
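
A brief sketch of how the extended Context is intended to be used, mirroring the interpreter.py change further below; the Dummy* classes are placeholders, not part of this commit.

from lmql.runtime.context import Context

class DummyInterpreter: pass
class DummyTokenizer: pass

# enter a context that initially only carries the interpreter; tokenizer
# and truncation_threshold can be filled in later, once the model is known
with Context(DummyInterpreter()) as ctx:
    ctx.tokenizer = DummyTokenizer()
    ctx.truncation_threshold = -3e38

    # nested code can recover the innermost active context
    assert Context.get() is ctx

# outside the with-block, no context is active (assuming no outer one)
assert Context.get() is None
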
9 changes: 9 additions & 0 deletions src/lmql/runtime/formatting.py
@@ -1,6 +1,8 @@
"""
Formatting of values in a prompt context.
"""
from lmql.api.blobs import Blob
from lmql.runtime.context import Context

def unescape(s):
return str(s).replace("[", "[[").replace("]", "]]")
@@ -18,6 +20,10 @@ def is_chat_list(l):
return True

def format_chat(chat):
"""
Formats a list of dicts representing an OpenAI Chat model
input into an LMQL-compliant prompt string using <lmql:ROLE/> tags.
"""
qstring = ""

for m in chat:
@@ -34,5 +40,8 @@ def format(s):
"""
if is_chat_list(s):
return format_chat(s)

if isinstance(s, Blob):
return tag("media id='" + s.id + "'")

return unescape(s)
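
To illustrate the new branch, a small sketch of how format() now treats different value types. The plain-string output is exact; the media-tag and chat outputs depend on tag() and format_chat(), which are not fully shown in this diff.

from lmql.api.blobs import Blob
from lmql.runtime.formatting import format

# plain values are escaped: "[" and "]" are doubled
print(format("see [image] below"))   # -> see [[image]] below

# Blob values become a media tag carrying the blob id,
# e.g. something like <lmql:media id='...'/>
print(format(Blob("some-image-data")))

# chat-style lists still go through format_chat() and are rendered
# as an <lmql:ROLE/>-tagged prompt string
print(format([{"role": "user", "content": "Hi"}]))
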
218 changes: 111 additions & 107 deletions src/lmql/runtime/interpreter.py
@@ -467,7 +467,7 @@ def process_query_string(self, s: str, first=False):
# check if this is the first token in the prompt and it is a tag
first_tag = s.startswith("<lmql:") and first
# replace <lmql:ROLE/> with ((ROLE)):
s = re.sub(r"<lmql:(.*?)\/>", r"\n((\1)):", s)
s = re.sub(r"<lmql:([^ :]*)\/>", r"\n((\1)):", s)
# strip off leading newline if it was added due to a tag
if first_tag: s = s[1:]
s = unescape_qstring(s)
@@ -950,126 +950,130 @@ async def run(self, fct, *args, **kwargs):
if "input" in kwargs.keys() and kwargs["input"] == input:
kwargs["input"] = self.input

# prepare initial program state
context = LMQLContext(self, None, "")
query_head = InterpretationHead(fct, context, args, kwargs)
self.root_state = PromptState(interpreter=self, subinterpreters={},
variable=None, prompt="", stmt_buffer=[],
query_head=query_head, program_state=context.program_state,
query_args=None, variable_args=None,
recurring_variable_counter={}, variable_offset=0,
valid=None, final=None, mask=None,
stopping_phrases=None, where=None,
tail=None)
self.root_state = await self.advance(self.root_state)

async def debug_out(decoder_step):
if PromptInterpreter.main != self:
return
if _DCLibDebugPrinter.printer is not None and dc.DecoderSequence.graph is not None:
data = await dc.DecoderSequence.graph.json(diff=True)
data = replace_inf_nan_with_str(data)
_DCLibDebugPrinter.printer.add_decoder_state(data)
self.dcmodel.report_stats(_DCLibDebugPrinter.printer, decoder_step)

# handle queries w/o any TemplateVariables
if self.root_state.query_head.result is not None:
with Context(self.model.get_tokenizer(), self.dcmodel.truncation_threshold):
with Context(self) as ctx:
# prepare initial program state
context = LMQLContext(self, None, "")
query_head = InterpretationHead(fct, context, args, kwargs)
self.root_state = PromptState(interpreter=self, subinterpreters={},
variable=None, prompt="", stmt_buffer=[],
query_head=query_head, program_state=context.program_state,
query_args=None, variable_args=None,
recurring_variable_counter={}, variable_offset=0,
valid=None, final=None, mask=None,
stopping_phrases=None, where=None,
tail=None)
self.root_state = await self.advance(self.root_state)

# update context
ctx.tokenizer = self.model.get_tokenizer()
ctx.truncation_threshold = self.dcmodel.truncation_threshold

async def debug_out(decoder_step):
if PromptInterpreter.main != self:
return
if _DCLibDebugPrinter.printer is not None and dc.DecoderSequence.graph is not None:
data = await dc.DecoderSequence.graph.json(diff=True)
data = replace_inf_nan_with_str(data)
_DCLibDebugPrinter.printer.add_decoder_state(data)
self.dcmodel.report_stats(_DCLibDebugPrinter.printer, decoder_step)

# handle queries w/o any TemplateVariables
if self.root_state.query_head.result is not None:
# one last call to debug_out to get the final state
await debug_out(self.decoder_step)
return (await self.postprocess([self.root_state.query_head.result]))[0]

# prepare tokenizer
self.tokenizer = self.model.get_tokenizer()
# prepare tokenizer
self.tokenizer = self.model.get_tokenizer()

# again check for tracing (if specified as decoder arg)
self.enable_tracing_if_needed()
# again check for tracing (if specified as decoder arg)
self.enable_tracing_if_needed()

# alternative execution mode where we only extract the initial prompt string
return_prompt_string = self.extra_kwargs.pop("return_prompt_string", False)
if return_prompt_string:
return self.root_state.prompt
# alternative execution mode where we only extract the initial prompt string
return_prompt_string = self.extra_kwargs.pop("return_prompt_string", False)
if return_prompt_string:
return self.root_state.prompt

# tokenize initial prompt
prompt_ids = await self.tokenize(self.root_state.prompt)
if self.dcmodel.bos_token_id is not None:
prompt_ids = [self.dcmodel.bos_token_id] + prompt_ids
# tokenize initial prompt
prompt_ids = await self.tokenize(self.root_state.prompt)
if self.dcmodel.bos_token_id is not None:
prompt_ids = [self.dcmodel.bos_token_id] + prompt_ids

prompt = self.tokenizer.tokenize(self.root_state.prompt, asbytes=True)
n = len(prompt)

# make sure that the initial prompt is not considered part of a variable
self.root_state = self.root_state.updated(variable_offset=n)
prompt = self.tokenizer.tokenize(self.root_state.prompt, asbytes=True)
n = len(prompt)
# make sure that the initial prompt is not considered part of a variable
self.root_state = self.root_state.updated(variable_offset=n)

decoder_args = self.decoder_kwargs.copy()
decoder_args = self.decoder_kwargs.copy()

# pass processor as decoder argument
decoder_args["modern_logits_processor"] = self.where_processor

# pass rewriter as decoder argument
decoder_args["modern_rewriter"] = self.rewrite_processor

if "__get_where__" in decoder_args:
return self.where

if "output_writer" in decoder_args:
set_dclib_debug_printer(decoder_args["output_writer"])
elif self.output_writer is not None:
set_dclib_debug_printer(self.output_writer)

if _DCLibDebugPrinter.printer is not None:
if hasattr(_DCLibDebugPrinter.printer, "records_graph"):
if _DCLibDebugPrinter.printer.records_graph:
dc.set_record_graph()
self.decoder_graph = dc.DecoderSequence.graph

# get decoder function
mode = decoder_args["decoder"].lower()
# handle dynamically-set decoding (e.g. set via @lmql.query(decoder="beam", n=2))
derived_mode, extra_decoder_args = self.derive_decoder_args(self.extra_kwargs, decoder_args)
decoder_args = {**decoder_args, **extra_decoder_args}

# use derived decoder, if not set explicitly
if mode == "__dynamic__":
mode = derived_mode

decoder_fct = dc.get_decoder(mode)
self.validate_args(decoder_args, decoder_fct)

# alias max_length -> max_len
if "max_length" in decoder_args:
decoder_args["max_len"] = decoder_args["max_length"]
if not "max_len" in decoder_args.keys():
decoder_args["max_len"] = 2048

# parse show_speculative argument
if "show_speculative" in decoder_args.keys():
self.show_speculative = decoder_args.pop("show_speculative")
assert self.caching, "warning: show_speculative is only supported when caching is enabled."

# parse cache argument
if "cache" in decoder_args.keys():
cache_value = decoder_args.pop("cache")
if type(cache_value) is bool:
self.caching = cache_value
elif type(cache_value) is str:
self.caching = True
self.cache_file = cache_value
else:
assert False, "Invalid value for 'cache' parameter. Expected either a boolean (to enable/disable) or a string (to enable with a disk-based cache file)"
# pass processor as decoder argument
decoder_args["modern_logits_processor"] = self.where_processor
# pass rewriter as decoder argument
decoder_args["modern_rewriter"] = self.rewrite_processor

if "__get_where__" in decoder_args:
return self.where

if "output_writer" in decoder_args:
set_dclib_debug_printer(decoder_args["output_writer"])
elif self.output_writer is not None:
set_dclib_debug_printer(self.output_writer)

if _DCLibDebugPrinter.printer is not None:
if hasattr(_DCLibDebugPrinter.printer, "records_graph"):
if _DCLibDebugPrinter.printer.records_graph:
dc.set_record_graph()
self.decoder_graph = dc.DecoderSequence.graph

# get decoder function
mode = decoder_args["decoder"].lower()
# handle dynamically-set decoding (e.g. set via @lmql.query(decoder="beam", n=2))
derived_mode, extra_decoder_args = self.derive_decoder_args(self.extra_kwargs, decoder_args)
decoder_args = {**decoder_args, **extra_decoder_args}
# use derived decoder, if not set explicitly
if mode == "__dynamic__":
mode = derived_mode
decoder_fct = dc.get_decoder(mode)
self.validate_args(decoder_args, decoder_fct)

# alias max_length -> max_len
if "max_length" in decoder_args:
decoder_args["max_len"] = decoder_args["max_length"]
if not "max_len" in decoder_args.keys():
decoder_args["max_len"] = 2048
# parse show_speculative argument
if "show_speculative" in decoder_args.keys():
self.show_speculative = decoder_args.pop("show_speculative")
assert self.caching, "warning: show_speculative is only supported when caching is enabled."

# parse cache argument
if "cache" in decoder_args.keys():
cache_value = decoder_args.pop("cache")
if type(cache_value) is bool:
self.caching = cache_value
elif type(cache_value) is str:
self.caching = True
self.cache_file = cache_value
else:
assert False, "Invalid value for 'cache' parameter. Expected either a boolean (to enable/disable) or a string (to enable with a disk-based cache file)"

# setup dcmodel for use
self.dcmodel.model_args = decoder_args
if self.caching:
self.dcmodel = dc.CachedDcModel(self.dcmodel, prompt_ids, cache_file=self.cache_file, show_speculative=self.show_speculative)
decoder_args["dcmodel"] = self.dcmodel
# setup dcmodel for use
self.dcmodel.model_args = decoder_args
if self.caching:
self.dcmodel = dc.CachedDcModel(self.dcmodel, prompt_ids, cache_file=self.cache_file, show_speculative=self.show_speculative)
decoder_args["dcmodel"] = self.dcmodel

assert len(prompt_ids) < decoder_args["max_len"], "The initial prompt already exceeds the provided max_len. Please increase the max_len or reduce the initial prompt (Initial prompt: '{}', max_len: {})".format(len(prompt_ids), decoder_args["max_len"])
assert len(prompt_ids) < decoder_args["max_len"], "The initial prompt already exceeds the provided max_len. Please increase the max_len or reduce the initial prompt (Initial prompt: '{}', max_len: {})".format(len(prompt_ids), decoder_args["max_len"])

# set step budget at least to max_len
step_budget = decoder_args.get("step_budget", max(1024, decoder_args.get("max_len", 1024)))
# set step budget at least to max_len
step_budget = decoder_args.get("step_budget", max(1024, decoder_args.get("max_len", 1024)))

with Context(self.model.get_tokenizer(), self.dcmodel.truncation_threshold):

try:
import time

