diff --git a/src/lmql/api/__init__.py b/src/lmql/api/__init__.py
index 705e63a8..8313850f 100644
--- a/src/lmql/api/__init__.py
+++ b/src/lmql/api/__init__.py
@@ -13,6 +13,7 @@ from .scoring import ScoringResult
 from .serve import serve
 from inspect import *
+from .blobs import Blob
 
 from lmql.runtime.tokenizer import tokenizer
 from lmql.runtime.loop import run_in_loop
diff --git a/src/lmql/api/blobs.py b/src/lmql/api/blobs.py
new file mode 100644
index 00000000..aa378e01
--- /dev/null
+++ b/src/lmql/api/blobs.py
@@ -0,0 +1,36 @@
+from weakref import WeakValueDictionary
+import re
+
+blob_store = WeakValueDictionary()
+
+class Blob:
+    """
+    Represents multi-modal data as part of an LMQL prompt.
+    """
+    def __init__(self, data):
+        self.data = data
+        self.id = str(hash(data))
+
+        blob_store[self.id] = self
+
+    @staticmethod
+    def resolve(id):
+        return blob_store.get(id)
+
+    def __str__(self):
+        return f"<lmql:media id='{self.id}'/>"
+
+    def __repr__(self):
+        return str(self)
+
+    @staticmethod
+    def decode(text):
+        print("decode", [text])
+        pattern = r"<lmql:media ([^\>]*)\>"
+        # split text by pattern matches, and replace each match with the resolved blob
+        parts = re.split(pattern, text)
+        for i in range(1, len(parts), 2):
+            id = parts[i].split("id='")[1].split("'")[0]
+            print([id])
+            parts[i] = Blob.resolve(id)
+        return text
\ No newline at end of file
diff --git a/src/lmql/api/llm.py b/src/lmql/api/llm.py
index 922d0114..fa508e13 100644
--- a/src/lmql/api/llm.py
+++ b/src/lmql/api/llm.py
@@ -135,7 +135,7 @@ async def score(self, prompt: str, values: Union[str, List[str]], **kwargs) -> ScoringResult:
         try:
             dcmodel = self.adapter.get_dclib_model()
             with traced(str(self) + ".score"):
-                with dc.Context(self.adapter.get_tokenizer(), dcmodel.truncation_threshold):
+                with dc.Context(None, self.adapter.get_tokenizer(), dcmodel.truncation_threshold):
                     return await dc_score(dcmodel, prompt, values, **kwargs)
         finally:
             dcmodel.close()
diff --git a/src/lmql/runtime/bopenai/openai_api.py b/src/lmql/runtime/bopenai/openai_api.py
index 32dc0c26..b333fd9a 100644
--- a/src/lmql/runtime/bopenai/openai_api.py
+++ b/src/lmql/runtime/bopenai/openai_api.py
@@ -14,6 +14,7 @@ import asyncio
 
 from lmql.runtime.stats import Stats
+from lmql.api.blobs import Blob
 from lmql.runtime.tracing import Tracer
 
 from lmql.models.model_info import model_info
 
@@ -397,7 +398,7 @@ async def completion_api(**kwargs):
     num_prompts = len(kwargs["prompt"])
     timeout = kwargs.pop("timeout", 1.5)
     tracer = kwargs.pop("tracer", None)
-    
+
     max_tokens = kwargs.get("max_tokens")
     # if no token limit is set, use 1024 as a generous chunk size
    # (completion models require max_tokens to be set)
diff --git a/src/lmql/runtime/context.py b/src/lmql/runtime/context.py
index acc30ab5..d5f36fd4 100644
--- a/src/lmql/runtime/context.py
+++ b/src/lmql/runtime/context.py
@@ -49,7 +49,8 @@ def pop_context():
     _context.set(_context.get()[:-1])
 
 class Context:
-    def __init__(self, tokenizer, truncation_threshold=-3e38):
+    def __init__(self, interpreter, tokenizer=None, truncation_threshold=-3e38):
+        self.interpreter = interpreter
         self.tokenizer = tokenizer
         self.truncation_threshold = truncation_threshold
 
@@ -59,4 +60,11 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         pop_context()
-        return False
\ No newline at end of file
+        return False
+
+    @classmethod
+    def get(cls):
+        ctx = _context.get()
+        if len(ctx) == 0:
+            return None
+        return ctx[-1]
\ No newline at end of file
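
To illustrate the intended use of the new blob store above, here is a small usage sketch (not part of the patch); the exact text produced by str() is an assumption, reconstructed from the pattern that Blob.decode matches:

    from lmql.api.blobs import Blob

    image = Blob("base64-encoded-image-data")   # data must be hashable, since id = str(hash(data))
    print(str(image))                           # assumed to render an inline <lmql:media id='...'/> tag
    assert Blob.resolve(image.id) is image      # resolvable while a strong reference keeps the weak entry alive
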
diff --git a/src/lmql/runtime/formatting.py b/src/lmql/runtime/formatting.py
index 2422af76..93a14972 100644
--- a/src/lmql/runtime/formatting.py
+++ b/src/lmql/runtime/formatting.py
@@ -1,6 +1,8 @@
 """
 Formatting of values in a prompt context.
 """
+from lmql.api.blobs import Blob
+from lmql.runtime.context import Context
 
 def unescape(s):
     return str(s).replace("[", "[[").replace("]", "]]")
@@ -18,6 +20,10 @@ def is_chat_list(l):
     return True
 
 def format_chat(chat):
+    """
+    Formats a list of dicts representing an OpenAI Chat model
+    input into an LMQL-compliant prompt string using tags.
+    """
     qstring = ""
 
     for m in chat:
@@ -34,5 +40,8 @@ def format(s):
     """
     if is_chat_list(s):
         return format_chat(s)
+
+    if isinstance(s, Blob):
+        return tag("media id='" + s.id + "'")
 
     return unescape(s)
\ No newline at end of file
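
A companion sketch (not part of the patch) of how the extended format() dispatch above is expected to behave; it assumes that tag(...) wraps its argument as an <lmql:.../> prompt tag, as the runtime does for role tags:

    from lmql.api.blobs import Blob
    from lmql.runtime import formatting

    blob = Blob("image-bytes-or-base64")
    print(formatting.format(blob))    # assumed to yield something like <lmql:media id='...'/>

    chat = [{"role": "user", "content": "Describe this image."}]
    print(formatting.format(chat))    # dispatched to format_chat(), yielding a role-tagged prompt string
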
diff --git a/src/lmql/runtime/interpreter.py b/src/lmql/runtime/interpreter.py
index b4bf9d10..38a7d538 100644
--- a/src/lmql/runtime/interpreter.py
+++ b/src/lmql/runtime/interpreter.py
@@ -467,7 +467,7 @@ def process_query_string(self, s: str, first=False):
         # check if this is the first token in the prompt and it is a tag
         first_tag = s.startswith("<lmql:")
         # replace <lmql:ROLE/> with ((ROLE)):
-        s = re.sub(r"", r"\n((\1)):", s)
+        s = re.sub(r"", r"\n((\1)):", s)
         # strip off leading newline if it was added due to a tag
         if first_tag:
             s = s[1:]
         s = unescape_qstring(s)
@@ -950,126 +950,130 @@ async def run(self, fct, *args, **kwargs):
         if "input" in kwargs.keys() and kwargs["input"] == input:
             kwargs["input"] = self.input
 
-        # prepare initial program state
-        context = LMQLContext(self, None, "")
-        query_head = InterpretationHead(fct, context, args, kwargs)
-        self.root_state = PromptState(interpreter=self, subinterpreters={},
-                                      variable=None, prompt="", stmt_buffer=[],
-                                      query_head=query_head, program_state=context.program_state,
-                                      query_args=None, variable_args=None,
-                                      recurring_variable_counter={}, variable_offset=0,
-                                      valid=None, final=None, mask=None,
-                                      stopping_phrases=None, where=None,
-                                      tail=None)
-        self.root_state = await self.advance(self.root_state)
-
-        async def debug_out(decoder_step):
-            if PromptInterpreter.main != self:
-                return
-            if _DCLibDebugPrinter.printer is not None and dc.DecoderSequence.graph is not None:
-                data = await dc.DecoderSequence.graph.json(diff=True)
-                data = replace_inf_nan_with_str(data)
-                _DCLibDebugPrinter.printer.add_decoder_state(data)
-            self.dcmodel.report_stats(_DCLibDebugPrinter.printer, decoder_step)
-
-        # handle queries w/o any TemplateVariables
-        if self.root_state.query_head.result is not None:
-            with Context(self.model.get_tokenizer(), self.dcmodel.truncation_threshold):
+        with Context(self) as ctx:
+            # prepare initial program state
+            context = LMQLContext(self, None, "")
+            query_head = InterpretationHead(fct, context, args, kwargs)
+            self.root_state = PromptState(interpreter=self, subinterpreters={},
+                                          variable=None, prompt="", stmt_buffer=[],
+                                          query_head=query_head, program_state=context.program_state,
+                                          query_args=None, variable_args=None,
+                                          recurring_variable_counter={}, variable_offset=0,
+                                          valid=None, final=None, mask=None,
+                                          stopping_phrases=None, where=None,
+                                          tail=None)
+            self.root_state = await self.advance(self.root_state)
+
+            # update context
+            ctx.tokenizer = self.model.get_tokenizer()
+            ctx.truncation_threshold = self.dcmodel.truncation_threshold
+
+            async def debug_out(decoder_step):
+                if PromptInterpreter.main != self:
+                    return
+                if _DCLibDebugPrinter.printer is not None and dc.DecoderSequence.graph is not None:
+                    data = await dc.DecoderSequence.graph.json(diff=True)
+                    data = replace_inf_nan_with_str(data)
+                    _DCLibDebugPrinter.printer.add_decoder_state(data)
+                self.dcmodel.report_stats(_DCLibDebugPrinter.printer, decoder_step)
+
+            # handle queries w/o any TemplateVariables
+            if self.root_state.query_head.result is not None:
                 # one last call to debug_out to get the final state
                 await debug_out(self.decoder_step)
                 return (await self.postprocess([self.root_state.query_head.result]))[0]
 
-        # prepare tokenizer
-        self.tokenizer = self.model.get_tokenizer()
+            # prepare tokenizer
+            self.tokenizer = self.model.get_tokenizer()
 
-        # again check for tracing (if specified as decoder arg)
-        self.enable_tracing_if_needed()
+            # again check for tracing (if specified as decoder arg)
+            self.enable_tracing_if_needed()
 
-        # alternative execution mode where we only extract the initial prompt string
-        return_prompt_string = self.extra_kwargs.pop("return_prompt_string", False)
-        if return_prompt_string:
-            return self.root_state.prompt
+            # alternative execution mode where we only extract the initial prompt string
+            return_prompt_string = self.extra_kwargs.pop("return_prompt_string", False)
+            if return_prompt_string:
+                return self.root_state.prompt
 
-        # tokenize initial prompt
-        prompt_ids = await self.tokenize(self.root_state.prompt)
-        if self.dcmodel.bos_token_id is not None:
-            prompt_ids = [self.dcmodel.bos_token_id] + prompt_ids
+            # tokenize initial prompt
+            prompt_ids = await self.tokenize(self.root_state.prompt)
+            if self.dcmodel.bos_token_id is not None:
+                prompt_ids = [self.dcmodel.bos_token_id] + prompt_ids
 
-        prompt = self.tokenizer.tokenize(self.root_state.prompt, asbytes=True)
-        n = len(prompt)
-
-        # make sure that the initial prompt is not considered part of a variable
-        self.root_state = self.root_state.updated(variable_offset=n)
+            prompt = self.tokenizer.tokenize(self.root_state.prompt, asbytes=True)
+            n = len(prompt)
+
+            # make sure that the initial prompt is not considered part of a variable
+            self.root_state = self.root_state.updated(variable_offset=n)
 
-        decoder_args = self.decoder_kwargs.copy()
+            decoder_args = self.decoder_kwargs.copy()
 
-        # pass processor as decoder argument
-        decoder_args["modern_logits_processor"] = self.where_processor
-
-        # pass rewriter as decoder argument
-        decoder_args["modern_rewriter"] = self.rewrite_processor
-
-        if "__get_where__" in decoder_args:
-            return self.where
-
-        if "output_writer" in decoder_args:
-            set_dclib_debug_printer(decoder_args["output_writer"])
-        elif self.output_writer is not None:
-            set_dclib_debug_printer(self.output_writer)
-
-        if _DCLibDebugPrinter.printer is not None:
-            if hasattr(_DCLibDebugPrinter.printer, "records_graph"):
-                if _DCLibDebugPrinter.printer.records_graph:
-                    dc.set_record_graph()
-                    self.decoder_graph = dc.DecoderSequence.graph
-
-        # get decoder function
-        mode = decoder_args["decoder"].lower()
-        # handle dynamically-set decoding (e.g. set via @lmql.query(decoder="beam", n=2))
-        derived_mode, extra_decoder_args = self.derive_decoder_args(self.extra_kwargs, decoder_args)
-        decoder_args = {**decoder_args, **extra_decoder_args}
-
-        # use derived decoder, if not set explicitly
-        if mode == "__dynamic__":
-            mode = derived_mode
-
-        decoder_fct = dc.get_decoder(mode)
-        self.validate_args(decoder_args, decoder_fct)
-
-        # alias max_length -> max_len
-        if "max_length" in decoder_args:
-            decoder_args["max_len"] = decoder_args["max_length"]
-        if not "max_len" in decoder_args.keys():
-            decoder_args["max_len"] = 2048
-
-        # parse show_speculative argument
-        if "show_speculative" in decoder_args.keys():
-            self.show_speculative = decoder_args.pop("show_speculative")
-            assert self.caching, "warning: show_speculative is only supported when caching is enabled."
-
-        # parse cache argument
-        if "cache" in decoder_args.keys():
-            cache_value = decoder_args.pop("cache")
-            if type(cache_value) is bool:
-                self.caching = cache_value
-            elif type(cache_value) is str:
-                self.caching = True
-                self.cache_file = cache_value
-            else:
-                assert False, "Invalid value for 'cache' parameter. Expected either a boolean (to enable/disable) or a string (to enable with a disk-based cache file)"
+            # pass processor as decoder argument
+            decoder_args["modern_logits_processor"] = self.where_processor
+
+            # pass rewriter as decoder argument
+            decoder_args["modern_rewriter"] = self.rewrite_processor
+
+            if "__get_where__" in decoder_args:
+                return self.where
+
+            if "output_writer" in decoder_args:
+                set_dclib_debug_printer(decoder_args["output_writer"])
+            elif self.output_writer is not None:
+                set_dclib_debug_printer(self.output_writer)
+
+            if _DCLibDebugPrinter.printer is not None:
+                if hasattr(_DCLibDebugPrinter.printer, "records_graph"):
+                    if _DCLibDebugPrinter.printer.records_graph:
+                        dc.set_record_graph()
+                        self.decoder_graph = dc.DecoderSequence.graph
+
+            # get decoder function
+            mode = decoder_args["decoder"].lower()
+            # handle dynamically-set decoding (e.g. set via @lmql.query(decoder="beam", n=2))
+            derived_mode, extra_decoder_args = self.derive_decoder_args(self.extra_kwargs, decoder_args)
+            decoder_args = {**decoder_args, **extra_decoder_args}
+
+            # use derived decoder, if not set explicitly
+            if mode == "__dynamic__":
+                mode = derived_mode
+
+            decoder_fct = dc.get_decoder(mode)
+            self.validate_args(decoder_args, decoder_fct)
+
+            # alias max_length -> max_len
+            if "max_length" in decoder_args:
+                decoder_args["max_len"] = decoder_args["max_length"]
+            if not "max_len" in decoder_args.keys():
+                decoder_args["max_len"] = 2048
+
+            # parse show_speculative argument
+            if "show_speculative" in decoder_args.keys():
+                self.show_speculative = decoder_args.pop("show_speculative")
+                assert self.caching, "warning: show_speculative is only supported when caching is enabled."
+
+            # parse cache argument
+            if "cache" in decoder_args.keys():
+                cache_value = decoder_args.pop("cache")
+                if type(cache_value) is bool:
+                    self.caching = cache_value
+                elif type(cache_value) is str:
+                    self.caching = True
+                    self.cache_file = cache_value
+                else:
+                    assert False, "Invalid value for 'cache' parameter. Expected either a boolean (to enable/disable) or a string (to enable with a disk-based cache file)"
 
-        # setup dcmodel for use
-        self.dcmodel.model_args = decoder_args
-        if self.caching:
-            self.dcmodel = dc.CachedDcModel(self.dcmodel, prompt_ids, cache_file=self.cache_file, show_speculative=self.show_speculative)
-        decoder_args["dcmodel"] = self.dcmodel
+            # setup dcmodel for use
+            self.dcmodel.model_args = decoder_args
+            if self.caching:
+                self.dcmodel = dc.CachedDcModel(self.dcmodel, prompt_ids, cache_file=self.cache_file, show_speculative=self.show_speculative)
+            decoder_args["dcmodel"] = self.dcmodel
 
-        assert len(prompt_ids) < decoder_args["max_len"], "The initial prompt already exceeds the provided max_len. Please increase the max_len or reduce the initial prompt (Initial prompt: '{}', max_len: {})".format(len(prompt_ids), decoder_args["max_len"])
+            assert len(prompt_ids) < decoder_args["max_len"], "The initial prompt already exceeds the provided max_len. Please increase the max_len or reduce the initial prompt (Initial prompt: '{}', max_len: {})".format(len(prompt_ids), decoder_args["max_len"])
 
-        # set step budget at least to max_len
-        step_budget = decoder_args.get("step_budget", max(1024, decoder_args.get("max_len", 1024)))
+            # set step budget at least to max_len
+            step_budget = decoder_args.get("step_budget", max(1024, decoder_args.get("max_len", 1024)))
 
-        with Context(self.model.get_tokenizer(), self.dcmodel.truncation_threshold):
+            try:
                 import time
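
To show what the Context changes in this diff enable, here is a minimal sketch (not part of the patch): since PromptInterpreter.run() now executes inside "with Context(self) as ctx", code running during a query (for example the bopenai backend, which now imports Blob) can look up the active interpreter through the new Context.get() classmethod:

    from lmql.runtime.context import Context

    def current_interpreter():
        ctx = Context.get()          # innermost active Context, or None outside of a running query
        return ctx.interpreter if ctx is not None else None
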