diff --git a/src/lmql/api/__init__.py b/src/lmql/api/__init__.py index a642a3e8..3fe23512 100644 --- a/src/lmql/api/__init__.py +++ b/src/lmql/api/__init__.py @@ -13,6 +13,7 @@ from .scoring import ScoringResult from .serve import serve from inspect import * +from lmql.runtime.tokenizer import tokenizer async def generate(prompt: str, max_tokens: Optional[int] = None, model: Optional[Union[LLM, str]] = None, **kwargs): """ diff --git a/src/lmql/api/scoring.py b/src/lmql/api/scoring.py index 3ac00029..5a5619bb 100644 --- a/src/lmql/api/scoring.py +++ b/src/lmql/api/scoring.py @@ -78,7 +78,7 @@ def argmax(self, agg="sum") -> str: def __str__(self): return "lmql.ScoringResult(model='{}')\n".format(self.model_identifier) + \ - "\n".join([f"-{c}: {score}" for c,score in zip(self.continuations, self.scores(agg="sum"))]) + "\n".join([f"-{str([c])[1:-1]}: {score}" for c,score in zip(self.continuations, self.scores(agg="sum"))]) async def dc_score(model: dc.DcModel, prompt, values, **kwargs): """ diff --git a/src/lmql/language/compiler.py b/src/lmql/language/compiler.py index 990e9296..94466b67 100644 --- a/src/lmql/language/compiler.py +++ b/src/lmql/language/compiler.py @@ -11,6 +11,7 @@ import lmql.runtime.lmql_runtime as lmql_runtime from lmql.language.fragment_parser import (FragmentParserError, LanguageFragmentParser, + LMQLDistributionClause, double_unescape_str, LMQLDecoderConfiguration, LMQLQuery) from lmql.language.qstrings import (DistributionVariable, FExpression, @@ -90,9 +91,10 @@ def __init__(self): self.prologue_vars = set() self.free_vars = set() self.written_vars = set() - self.defined_constraints = set() + self.query = None + def scope_prologue(self, query: LMQLQuery): if query.prologue is None: return @@ -111,6 +113,8 @@ def scope(self, query: LMQLQuery): # collect defined vars in prologue self.scope_prologue(query) + self.query = query + # collect defined vars in prompt for p in query.prompt: self.visit(p) @@ -144,6 +148,16 @@ def visit_BoolOp(self, node: ast.BoolOp) -> Any: self.scope_Constant(node.values[0]) for constraint in node.values[1:]: self.visit_where(constraint) + elif is_query_string_with_distribution(node): + assert len(node.values) == 2, "compiler error: distribution clause must be an expression of shape 'distribution VAR in [val1, val2, ...]'" + distribution_in_clause = node.values[1] + assert isinstance(distribution_in_clause, ast.Compare), "compiler error: distribution clause must be an expression of shape 'distribution VAR in [val1, val2, ...]'" + var = distribution_in_clause.left + assert isinstance(var, ast.Name), "compiler error: distribution clause must be an expression of shape 'distribution VAR in [val1, val2, ...]'" + self.distribution_vars = set([var.id]) + assert len(distribution_in_clause.comparators) == 1, "compiler error: distribution clause must be an expression of shape 'distribution VAR in [val1, val2, ...]'" + self.query.distribution = LMQLDistributionClause(var.id, distribution_in_clause.comparators[0]) + self.scope_Constant(node.values[0]) else: super().generic_visit(node) @@ -276,6 +290,12 @@ def is_query_string_with_constraints(node: ast.BoolOp): left_most_operand = node.values[0] return type(left_most_operand) is ast.Constant and type(left_most_operand.value) is str and isinstance(node.op, ast.And) +def is_query_string_with_distribution(node: ast.BoolOp): + if len(node.values) < 1: + return False + left_most_operand = node.values[0] + return type(left_most_operand) is ast.Constant and type(left_most_operand.value) is str and 
isinstance(node.op, ast.Or) + def attr(s): names = s.split(".") element = ast.Name(names[0], ast.Load()) @@ -357,6 +377,9 @@ def visit_BoolOp(self, node: ast.BoolOp) -> Any: elif len(node.values[1:]) == 0: constraints_expression = None return self.transform_Constant(left_most_operand, constraints = constraints_expression) + elif is_query_string_with_distribution(node): + left_most_operand = node.values[0] + return self.transform_Constant(left_most_operand) return self.generic_visit(node) def visit_FunctionDef(self, node: FunctionDef) -> Any: @@ -934,6 +957,4 @@ def compile(self, filepath): return LMQLModule(output_file, lmql_code=lmql_code, output_variables=[v for v in scope.defined_vars]) except FragmentParserError as e: - sys.stderr.write("error: " + str(e) + "\n") - sys.exit(1) - + raise RuntimeError("parsing error: {}.\nFailed when parsing:\n {}".format(e, lmql_code)) \ No newline at end of file diff --git a/src/lmql/language/fragment_parser.py b/src/lmql/language/fragment_parser.py index 6c8c8055..4f7fb9e8 100644 --- a/src/lmql/language/fragment_parser.py +++ b/src/lmql/language/fragment_parser.py @@ -164,6 +164,7 @@ def parse(self, readline): self.prologue_transform() self.inline_where_transform() + self.inline_distribution_transform() self.ast_parse() self.syntax_validation() self.ast_transform() @@ -177,7 +178,14 @@ def inline_where_transform(self): lookahead = prompt_tokens[i+1] if tok.type == tokenize.STRING and lookahead.type == tokenize.NAME and lookahead.string == "where": prompt_tokens[i+1] = tokenize.TokenInfo(type=tokenize.OP, string="and", start=lookahead.start, end=lookahead.end, line=lookahead.line) - + + def inline_distribution_transform(self): + prompt_tokens = self.query.prompt_str + for i in range(len(prompt_tokens) - 1): + tok = prompt_tokens[i] + lookahead = prompt_tokens[i+1] + if tok.type == tokenize.STRING and lookahead.type == tokenize.NAME and lookahead.string == "distribution": + prompt_tokens[i+1] = tokenize.TokenInfo(type=tokenize.OP, string="or", start=lookahead.start, end=lookahead.end, line=lookahead.line) def prologue_transform(self): # translate prologue tokens into str @@ -251,7 +259,7 @@ def digest(self, tok): self.state = "decode" return - if is_keyword(tok, "where"): + if is_keyword(tok, "where") or is_keyword(tok, "distribution"): self.query.prompt_str = self.query.prologue + [tok] self.query.prologue = [] self.state = "prompt" @@ -279,7 +287,7 @@ def digest(self, tok): if self.query.prompt_str[-1].type != tokenize.STRING: self.state = "where" return - + if is_keyword(tok, "FROM"): self.state = "from" return @@ -287,8 +295,9 @@ def digest(self, tok): self.state = "scoring" return if is_keyword(tok, "DISTRIBUTION"): - self.state = "distribution" - return + if self.query.prompt_str[-1].type != tokenize.STRING: + self.state = "distribution" + return # if last token is NAME and current is str if len(self.query.prompt_str) > 0 and self.query.prompt_str[-1].type == tokenize.NAME and \ diff --git a/src/lmql/models/lmtp/backends/__main__.py b/src/lmql/models/lmtp/backends/__main__.py index 11d44403..eff4660c 100644 --- a/src/lmql/models/lmtp/backends/__main__.py +++ b/src/lmql/models/lmtp/backends/__main__.py @@ -3,7 +3,7 @@ from lmql.models.lmtp.backends import LMTPModel from lmql.models.lmtp.lmtp_scheduler import TokenStreamer -from lmql.runtime.tokenizer import load_tokenizer +import lmql import transformers @@ -11,7 +11,7 @@ if __name__ == "__main__": backend = sys.argv[1] model: LMTPModel = LMTPModel.load(backend) - t = 
load_tokenizer("huggyllama/llama-7b") + t = lmql.tokenizer("huggyllama/llama-7b") s = sys.argv[2] input_ids = [t.bos_token_id] + t(s)["input_ids"] diff --git a/src/lmql/models/lmtp/lmtp_client.py b/src/lmql/models/lmtp/lmtp_client.py index 2caab26a..ab775120 100644 --- a/src/lmql/models/lmtp/lmtp_client.py +++ b/src/lmql/models/lmtp/lmtp_client.py @@ -112,10 +112,10 @@ async def interactive_client(): async with aiohttp.ClientSession() as session: async with session.ws_connect('http://workstation:8888') as ws: - from lmql.runtime.tokenizer import load_tokenizer + from lmql.runtime.tokenizer import tokenizer model = sys.argv[1] - tokenizer = load_tokenizer(model) + tokenizer = tokenizer(model) client = LMTPWebSocketClient(model, ws) client.connect() diff --git a/src/lmql/models/lmtp/lmtp_dcmodel.py b/src/lmql/models/lmtp/lmtp_dcmodel.py index 3aa73654..d3825d0d 100644 --- a/src/lmql/models/lmtp/lmtp_dcmodel.py +++ b/src/lmql/models/lmtp/lmtp_dcmodel.py @@ -4,7 +4,7 @@ """ from lmql.runtime.dclib.dclib_model import DcModel -from lmql.runtime.tokenizer import load_tokenizer +from lmql.runtime.tokenizer import tokenizer from .lmtp_async import LMTPAsyncClient import lmql.runtime.dclib as dc import asyncio @@ -496,7 +496,7 @@ def __init__(self) -> None: def get_tokenizer(self): if self._tokenizer is None: - self._tokenizer = load_tokenizer(this.tokenizer_identifier, **this.kwargs) + self._tokenizer = tokenizer(this.tokenizer_identifier, **this.kwargs) self.served_model = self return self._tokenizer diff --git a/src/lmql/models/lmtp/lmtp_langchain.py b/src/lmql/models/lmtp/lmtp_langchain.py index 71c1933c..fec33d5a 100644 --- a/src/lmql/models/lmtp/lmtp_langchain.py +++ b/src/lmql/models/lmtp/lmtp_langchain.py @@ -18,8 +18,8 @@ from langchain.llms.utils import enforce_stop_tokens from langchain.schema import LLMResult -from lmql.runtime.tokenizer import LMQLTokenizer, load_tokenizer -from lmql.runtime.model_registry import LMQLModelRegistry +import lmql +from lmql.runtime.tokenizer import LMQLTokenizer if TYPE_CHECKING: from tenacity import RetryCallState @@ -232,7 +232,7 @@ def _get_params(self, _kwarg_dict: Dict[str, Any]) -> Any: def _get_tokenizer(self) -> LMQLTokenizer: if self.tokenizer is None: - self.tokenizer = LMQLModelRegistry.get(self.model).get_tokenizer() + self.tokenizer = lmql.model(self.model).get_tokenizer() return self.tokenizer def _call( diff --git a/src/lmql/runtime/bopenai/openai_api.py b/src/lmql/runtime/bopenai/openai_api.py index 808acd4b..e5c2f336 100644 --- a/src/lmql/runtime/bopenai/openai_api.py +++ b/src/lmql/runtime/bopenai/openai_api.py @@ -13,7 +13,6 @@ import time import asyncio -from lmql.runtime.tokenizer import load_tokenizer from lmql.runtime.stats import Stats from lmql.models.model_info import model_info @@ -72,13 +71,7 @@ async def complete(**kwargs): global tokenizers tokenizers = {} -def tokenize(text, model, openai_byte_encoding=False): - global tokenizers - if not model in tokenizers: - tokenizer = load_tokenizer("tiktoken:" + model) - tokenizers[model] = tokenizer - else: - tokenizer = tokenizers[model] +def tokenize(text, tokenizer, openai_byte_encoding=False): ids = tokenizer(text)["input_ids"] raw = tokenizer.decode_bytes(ids) if openai_byte_encoding: @@ -175,9 +168,12 @@ async def chat_api(**kwargs): num_prompts = len(kwargs["prompt"]) max_tokens = kwargs.get("max_tokens", 0) model = kwargs["model"] + api_config = kwargs.get("api_config", {}) + tokenizer = api_config.get("tokenizer") + assert tokenizer is not None, "internal error: chat_api 
expects an 'api_config' with a 'tokenizer: LMQLTokenizer' mapping in your API payload" assert "logit_bias" not in kwargs.keys(), f"Chat API models do not support advanced constraining of the output, please use no or less complicated constraints." - prompt_tokens = tokenize(kwargs["prompt"][0], model=model, openai_byte_encoding=True) + prompt_tokens = tokenize(kwargs["prompt"][0], tokenizer=tokenizer, openai_byte_encoding=True) timeout = kwargs.pop("timeout", 1.5) @@ -229,6 +225,8 @@ async def chat_api(**kwargs): del kwargs["prompt"] kwargs["messages"] = messages + needs_space = True # messages[-1]["content"][-1] != " " + del kwargs["logprobs"] async with CapacitySemaphore(num_prompts * max_tokens): @@ -237,7 +235,6 @@ async def chat_api(**kwargs): stream_start = time.time() async with aiohttp.ClientSession() as session: - api_config = kwargs.get("api_config", {}) endpoint, headers = get_endpoint_and_headers(kwargs) if api_config.get("verbose", False) or os.environ.get("LMQL_VERBOSE", "0") == "1" or api_config.get("chatty_openai", False): @@ -327,7 +324,10 @@ async def chunk_timer(): }) continue text = delta["content"] - tokens = tokenize((" " if received_text == "" else "") + text, model=model, openai_byte_encoding=True) + if len(text) == 0: + continue + + tokens = tokenize((" " if received_text == "" and needs_space else "") + text, tokenizer=tokenizer, openai_byte_encoding=True) received_text += text # convert tokens to OpenAI format diff --git a/src/lmql/runtime/dclib/dclib_cache.py b/src/lmql/runtime/dclib/dclib_cache.py index 74b4d1d5..a7727308 100644 --- a/src/lmql/runtime/dclib/dclib_cache.py +++ b/src/lmql/runtime/dclib/dclib_cache.py @@ -143,14 +143,18 @@ async def get_mask(self, s: DecoderSequence, **kwargs): async def get_keys(self, s: DecoderSequence, edge_type: str, **kwargs): kwargs = {**self.delegate.model_args, **kwargs} + reuse_context = kwargs.get("cache_reuse_context", set()) keys = [] # check for sample-id - if s.data("dc-edge-type"): + if s.data("dc-edge-type") and edge_type is not None: + dc_edge_type = s.data("dc-edge-type") # if the edge type aligns with dc-edge-type, use that instead (includes a unique sample id if available) - if s.data("dc-edge-type").startswith(edge_type): - edge_type = s.data("dc-edge-type") + if dc_edge_type.startswith(edge_type): + if not dc_edge_type in reuse_context: + reuse_context.add(dc_edge_type) + edge_type = dc_edge_type # compute logits mask mask = (await self.get_mask(s, **kwargs)).logits_mask[0] @@ -269,7 +273,11 @@ async def op_sample(seqs): temperature = kwargs.get('temperature', 1.0) sampling_mode = "top-1" if temperature == 0.0 else "sample-{}".format(temperature) - cache_entries = [await self.get_cache(s, sampling_mode, user_data=True, **kwargs) for s in seqs] + # make sure that each uniquely sampled trajectory in the cache, cannot be used + # twice as a result of sampling (e.g. 
when sampling multiple times from the same sequence) + cache_reuse_context = set() + + cache_entries = [await self.get_cache(s, sampling_mode, user_data=True, cache_reuse_context=cache_reuse_context, **kwargs) for s in seqs] cached_cont = [e[1] for e in cache_entries] cache_keys = [e[0] for e in cache_entries] diff --git a/src/lmql/runtime/dclib/decoders.py b/src/lmql/runtime/dclib/decoders.py index b48530a3..a7cc6ca6 100644 --- a/src/lmql/runtime/dclib/decoders.py +++ b/src/lmql/runtime/dclib/decoders.py @@ -2,7 +2,6 @@ import numpy as np from typing import List, Any, Union, Optional, Dict -from lmql.runtime.tokenizer import load_tokenizer from lmql.runtime.dclib.dclib_array import DataArray, sum_scorer, alpha_length_normalized, alpha_length_normalized_det from lmql.runtime.dclib.dclib_seq import next_is_deterministic import lmql.runtime.dclib as dc diff --git a/src/lmql/runtime/interpreter.py b/src/lmql/runtime/interpreter.py index d0792112..a5a97f38 100644 --- a/src/lmql/runtime/interpreter.py +++ b/src/lmql/runtime/interpreter.py @@ -334,6 +334,8 @@ async def advance(self, state: PromptState): query_args_after_last_continue = query_args program_variables_after_last_continue = None prompt = state.prompt + recurring_variable_counter = state.recurring_variable_counter.copy() + distribution_reached = False query_head = state.query_head @@ -359,12 +361,16 @@ async def continue_for_more_prompt_stmts(): # return context used for last continue_ return query_head.context - # disable DecoderSequence.graph for the duration of executing the prompt + def format_buffer(): + return [s if type(s) is str else s.name for s in stmt_buffer if s is not advance] + # disable DecoderSequence.graph for the duration of executing the prompt try: while variable is None and query_head.result is None: if len(stmt_buffer) == 0 and variable is None: await continue_for_more_prompt_stmts() + if distribution_reached: + assert len(stmt_buffer) == 0, "error: distribution variable must be the last statement in a prompt, but found {}".format(format_buffer()) s = stmt_buffer[0] @@ -394,9 +400,9 @@ async def continue_for_more_prompt_stmts(): variable_args["constraints"] = ops.FixedValueOp([ops.Var(variable)], variable_value, prompt_value) # keep track of number of times a variable with this name has been decoded - if variable not in state.recurring_variable_counter.keys(): - state.recurring_variable_counter[s.name] = -1 - state.recurring_variable_counter[s.name] += 1 + if variable not in recurring_variable_counter.keys(): + recurring_variable_counter[s.name] = -1 + recurring_variable_counter[s.name] += 1 stmt_buffer = stmt_buffer[1:] break @@ -404,7 +410,8 @@ async def continue_for_more_prompt_stmts(): # distribution variables are skipped here, as they are handled in a postprocessing step after returning an LMQL result # self.query_head must terminate after this part of the prompt (ensure by validation) stmt_buffer = stmt_buffer[1:] - assert len([s for s in stmt_buffer if s is not advance]) == 0, "Distribution variables must be the last statement in a prompt, but is {}".format(stmt_buffer) + assert len([s for s in stmt_buffer if s is not advance]) == 0, "error: distribution variable must be the last statement in a prompt, but found {}".format(format_buffer()) + distribution_reached = True # this will consume the set_distribution call elif s is advance: query_head: InterpretationHead = query_head.copy() @@ -433,7 +440,8 @@ async def continue_for_more_prompt_stmts(): variable_args=variable_args, stmt_buffer=stmt_buffer, 
query_head=query_head, - program_state=program_state + program_state=program_state, + recurring_variable_counter=recurring_variable_counter ) def process_query_string(self, s: str, first=False): @@ -632,6 +640,8 @@ def node_data(op): async def debugger_output(self, state: PromptState, s: dc.DecoderSequence, valid, is_final, mask, stopping_phrases, program_variables, trace, text, where): if self.output_writer is not None: await self.output_writer.add_interpreter_head_state(state.variable, 0, state.prompt + text, where, trace, valid, is_final, mask, len(s.input_ids), program_variables) + if hasattr(self.output_writer, "add_sequence_state"): + await self.output_writer.add_sequence_state(s) async def where_processor(self, seqs, additional_logits_processor_mask, **kwargs): zipped_task_inputs = zip(seqs, additional_logits_processor_mask, range(len(seqs))) @@ -959,9 +969,13 @@ async def debug_out(decoder_step): # get decoder function mode = decoder_args["decoder"].lower() # handle dynamically-set decoding (e.g. set via @lmql.query(decoder="beam", n=2)) + derived_mode, extra_decoder_args = self.derive_decoder_args(self.extra_kwargs, decoder_args) + decoder_args = {**decoder_args, **extra_decoder_args} + + # use derived decoder, if not set explicitly if mode == "__dynamic__": - mode, extra_decoder_args = self.derive_decoder_args(self.extra_kwargs) - decoder_args = {**decoder_args, **extra_decoder_args} + mode = derived_mode + decoder_fct = dc.get_decoder(mode) self.validate_args(decoder_args, decoder_fct) @@ -1053,6 +1067,10 @@ async def debug_out(decoder_step): for i,s in enumerate(result_sequences): state = self.interpreter_state_from_user_data(s) + + if hasattr(self.output_writer, "add_sequence_state"): + await self.output_writer.add_sequence_state(s) + if state.query_head.result is not None: results.append(state.query_head.result) else: @@ -1072,7 +1090,10 @@ async def debug_out(decoder_step): "openai_chunksize", "step_budget", "stats", "performance_stats", "cache", "show_speculative", "openai_nonstop", "chunksize", "alpha", "verbose"] - def derive_decoder_args(self, extra_kwargs): + def derive_decoder_args(self, extra_kwargs, existing_args=None): + # if no existing args are provided, use no args + existing_args = existing_args or {} + # default is argmax decoder = extra_kwargs.get("decoder", "argmax") # if temperature is != 0, use 'sample' @@ -1092,6 +1113,10 @@ def derive_decoder_args(self, extra_kwargs): for eda in PromptInterpreter.EXTRA_DECODER_ARGS: if eda in extra_kwargs.keys(): decoder_args[eda] = extra_kwargs[eda] + + # underscore prefixed args are only used if existing_args does not already contain the arg + if "_" + eda in extra_kwargs.keys() and not eda in existing_args.keys(): + decoder_args[eda] = extra_kwargs["_" + eda] return decoder, decoder_args diff --git a/src/lmql/runtime/openai_integration.py b/src/lmql/runtime/openai_integration.py index f76dedec..37cab95e 100644 --- a/src/lmql/runtime/openai_integration.py +++ b/src/lmql/runtime/openai_integration.py @@ -6,6 +6,7 @@ from typing import Any, Callable, List, Optional, Union import numpy as np +import random import lmql.runtime.masks as masks import lmql.runtime.bopenai as openai @@ -15,7 +16,7 @@ from lmql.runtime.dclib.dclib_seq import (DecoderSequence, deepcopy, deepmerge, detseq, is_deterministic) from lmql.runtime.stats import Stats -from lmql.runtime.tokenizer import load_tokenizer +from lmql.runtime.tokenizer import tokenizer from lmql.runtime.tokenizers.tiktoken_tokenizer import TiktokenTokenizer from lmql.utils 
import nputil from lmql.runtime.token_distribution import TokenDistribution @@ -64,6 +65,7 @@ class CompletionCall: input_ids: np.ndarray kwargs: Any stopping_phrases: List[str] = None + sampling_mode: str = None # true iff inverting the api_mask leads to a smaller mask invert: bool = False @@ -92,11 +94,15 @@ def continuation_type(self): mask_key_segment = "-".join(mask_key_segment) else: mask_key_segment = "*" + + if self.sampling_mode is not None: + mask_key_segment += "-" + self.sampling_mode + return f"{parameter_values_key_segment}-{mask_key_segment}" class DclibOpenAiModel(DcModel): - def __init__(self, *args, endpoint=None, **kwargs): - super().__init__(*args, truncation_threshold=-12000, init_workers=False, **kwargs) + def __init__(self, model, tokenizer, endpoint=None, **kwargs): + super().__init__(model, tokenizer, truncation_threshold=-12000, init_workers=False, **kwargs) self.mock = kwargs.get("mock", False) @@ -117,7 +123,10 @@ def __init__(self, *args, endpoint=None, **kwargs): self.num_billed_tokens = {} self.num_requests = 0 + # prepare API config args for OpenAI API calling layer + kwargs["tokenizer"] = tokenizer self.api_config = {**({"endpoint": endpoint} if endpoint is not None else {}), **kwargs} + self.timeout = kwargs.get("chunk_timeout", 1.5 if not self.mock else 4.5) self.stats = Stats("openai") @@ -131,7 +140,7 @@ def log_billable_tokens(self, n: int): def log_queries(self, n: int): pass # openai keeps track of queries via bopenai - def prepare_completion_call(self, s, mask, **kwargs): + def prepare_completion_call(self, s, mask, sampling_mode, **kwargs): """ Computes an API compatible mask from the provided logit mask, as well as the required mode of completion. @@ -144,7 +153,7 @@ def prepare_completion_call(self, s, mask, **kwargs): stopping_phrases = s.data("head").stopping_phrases["text"] if mask is None: - return CompletionCall("*", None, s.input_ids, kwargs, stopping_phrases=stopping_phrases) + return CompletionCall("*", None, s.input_ids, kwargs, stopping_phrases=stopping_phrases, sampling_mode=sampling_mode) invert = False num_allowed = masks.mask_num_allowed(mask) @@ -156,19 +165,19 @@ def prepare_completion_call(self, s, mask, **kwargs): # check for case if masks.mask_is_allowed(mask, self.eos_token_id): - return CompletionCall("fixed", token, s.input_ids, kwargs, stopping_phrases=stopping_phrases) + return CompletionCall("fixed", token, s.input_ids, kwargs, stopping_phrases=stopping_phrases, sampling_mode=sampling_mode) else: # otherwise we can treat this as a score call - return CompletionCall("fixed", token, s.input_ids, kwargs, stopping_phrases=stopping_phrases) + return CompletionCall("fixed", token, s.input_ids, kwargs, stopping_phrases=stopping_phrases, sampling_mode=sampling_mode) elif num_allowed < self.tokenizer.model_vocab_size: if self.tokenizer.model_vocab_size - num_allowed > num_allowed: # if we have to mask more than half of the tokens, we should just invert the masking invert = True else: # num_allowed == mask.shape[-1] (full vocabulary) - return CompletionCall("*", None, s.input_ids, kwargs, stopping_phrases=stopping_phrases) + return CompletionCall("*", None, s.input_ids, kwargs, stopping_phrases=stopping_phrases, sampling_mode=sampling_mode) # num_allowed < mask.shape[-1] and num_allowed > 1 (needs mask) - return CompletionCall("complete", mask, s.input_ids, kwargs, invert=invert, stopping_phrases=stopping_phrases) + return CompletionCall("complete", mask, s.input_ids, kwargs, invert=invert, stopping_phrases=stopping_phrases, 
sampling_mode=sampling_mode) async def api_score(self, input_ids, offset): if len(input_ids) > 0 and input_ids[0] == self.tokenizer.bos_token_id: @@ -360,8 +369,11 @@ def count_billed_tokens(self, n, model): self.num_billed_tokens[model] += n self.num_requests += 1 - async def completion_buffer(self, seqs, temperature=1, **kwargs): + async def completion_buffer(self, seqs, temperature=1, sampling_modes=None, **kwargs): kwargs.update({"temperature": temperature}) + + if sampling_modes is None: + sampling_modes = ["top-1" for _ in range(len(seqs))] async def get_buffer(i, s): with self.stats.timer("logit_masks"): @@ -376,7 +388,7 @@ async def get_buffer(i, s): s.user_data = deepmerge(deepcopy(s.user_data), logits_mask_result.user_data[0]) s.user_data["set_by"] = "where" - completion_call = self.prepare_completion_call(s, logits_mask, **kwargs) + completion_call = self.prepare_completion_call(s, logits_mask, sampling_modes[i], **kwargs) # if no masking is required, we can use cached continuations if available if s.data("openai-continuations") is not None: @@ -401,10 +413,11 @@ async def get_buffer(i, s): ) completion_result = await self.async_complete(completion_call) + # eagerly expand and cache full completion if a cache_delegate is available if self.cache_delegate is not None: await self.expand_and_cache(s, completion_result, - "top-1" if temperature == 0.0 else f"sample-{temperature}", + sampling_modes[i], logprobs=kwargs.get("logprobs", 1)) assert not await completion_result.buffer.empty(), "Completion result is empty on arrival: {}".format(str([await self.detokenize(completion_call.input_ids)])) @@ -471,6 +484,10 @@ async def token_stream(): continuation.continuation_type: continuation } } + + if "sample-id" in sampling_mode: + user_data["dc-edge-type"] = sampling_mode + # print("token stream gives", result_id, tokens, scores, edge_type, flush=True) scores = [0.0 if str(s) == "[]" else s for s in scores] @@ -513,7 +530,15 @@ async def sample(self, sequences, num_samples=1, **kwargs): kwargs = {**self.model_args, **kwargs} async def op_sample(seqs): - completions: List[CompletionResult] = await self.completion_buffer(seqs, logprobs=num_samples, **kwargs) + temperature = kwargs.get("temperature", 1.0) + if temperature == 0.0: + sampling_modes = ["top-1" for _ in range(len(seqs))] + edge_type_populated_user_data = [{} for _ in range(len(seqs))] + else: + sampling_modes = [f"sample-{temperature}-sample-id-{random.randint(0, 2**32-1)}" for _ in range(len(seqs))] + edge_type_populated_user_data = [{"dc-edge-type": sm} for sm in sampling_modes] + + completions: List[CompletionResult] = await self.completion_buffer(seqs, logprobs=num_samples, sampling_modes=sampling_modes, **kwargs) next_token_ids = [] next_token_scores = [] @@ -583,8 +608,8 @@ async def op_sample(seqs): next_token_ids = token_ids next_token_scores = next_token_scores - def successor_user_data(continuation_buffer: SequenceResult, num_successors): - default_user_data = {} + def successor_user_data(continuation_buffer: SequenceResult, num_successors, user_data): + default_user_data = {**user_data} if continuation_buffer.continuation_type is None: return [default_user_data.copy()] * num_successors continuation_as_user_data = { @@ -596,7 +621,7 @@ def successor_user_data(continuation_buffer: SequenceResult, num_successors): return [continuation_as_user_data] + [default_user_data.copy()] * (num_successors - 1) return [s.make_successors(next_token_ids[i], next_token_scores[i], logits=logits[i], - 
user_data=successor_user_data(continuation_buffers[i], len(next_token_ids[i]))) for i,s in enumerate(seqs)] + user_data=successor_user_data(continuation_buffers[i], len(next_token_ids[i]), edge_type_user_data)) for i, s, edge_type_user_data in zip(range(len(seqs)), seqs, edge_type_populated_user_data)] with self.stats.timer("sample"): return await sequences.aelement_wise(op_sample) @@ -988,20 +1013,24 @@ def cost_estimate(self, model): return openai.AsyncConfiguration.get_stats().cost_estimate(model) def openai_model(model_identifier, endpoint=None, mock=False, **kwargs) -> ModelAPIAdapter: - class OpenAIModel(ModelAPIAdapter): + class OpenAIAPIAdapter(ModelAPIAdapter): def __init__(self) -> None: self.model_identifier = model_identifier self.served_model = None self._tokenizer = None + + self.tokenizer_identifier = kwargs.pop("tokenizer", model_identifier) + if self.tokenizer_identifier.startswith("openai/"): + self.tokenizer_identifier = self.tokenizer_identifier.split("openai/",1)[1] self.decoder_args = {} def get_tokenizer(self): if self._tokenizer is None: if not mock: - self._tokenizer = load_tokenizer("tiktoken:" + self.model_identifier) + self._tokenizer = tokenizer("tiktoken:" + self.tokenizer_identifier) else: - self._tokenizer = load_tokenizer(self.model_identifier) + self._tokenizer = tokenizer(self.tokenizer_identifier) self.served_model = self return self._tokenizer @@ -1020,4 +1049,11 @@ async def detokenize(self, input_ids): def sync_tokenize(self, text): return self.get_tokenizer()(text)["input_ids"] - return OpenAIModel() \ No newline at end of file + + def __repr__(self) -> str: + return str(self) + + def __str__(self) -> str: + return f"" + + return OpenAIAPIAdapter() \ No newline at end of file diff --git a/src/lmql/runtime/tokenizer.py b/src/lmql/runtime/tokenizer.py index 9daa650c..11db895a 100644 --- a/src/lmql/runtime/tokenizer.py +++ b/src/lmql/runtime/tokenizer.py @@ -55,6 +55,12 @@ def load(): if "FORCE_TIKTOKEN" in os.environ: assert type(self.tokenizer_impl) is TiktokenTokenizer + def __str__(self): + return "".format(self.model_identifier) + + def __repr__(self): + return str(self) + @property def model_vocab_size(self): """ @@ -70,7 +76,7 @@ def tokenizer_impl(self): if self._tokenizer_impl is None: self.loader_thread.join() if self._tokenizer_impl is None: - raise TokenizerNotAvailableError("Failed to load suitable tokenizer for model '{}'".format(self.model_identifier)) + raise TokenizerNotAvailableError("Failed to derive a suitable tokenizer from the provided model name '{}'. If your model requires a specific (well-known) tokenizer, make sure to specify it via lmql.model(..., tokenizer='...').".format(self.model_identifier)) return self._tokenizer_impl @property @@ -262,7 +268,12 @@ def chunk_out_by_tags(self, s, tokenize=True): return segments -def load_tokenizer(model_identifier, type="auto", **kwargs): +def tokenizer(model_identifier, type="auto", **kwargs) -> LMQLTokenizer: + """ + Loads a LMQLTokenizer for the given model identifier. + + If type is 'auto', the tokenizer will be loaded from the most suitable available backend. Otherwise, the type can be one of 'hf' (huggingface transformers), 'tiktoken' (tiktoken) or 'auto' (default). 
+ """ cache_identifier = model_identifier.replace("/", "-").replace(":", "__") cache_path = f"tokenizer-{cache_identifier}.pkl" @@ -272,7 +283,7 @@ def load_tokenizer(model_identifier, type="auto", **kwargs): model_identifier = model_identifier[len("tiktoken:"):] # map gpt-3.5-turbo* to gpt-3.5-turbo - if "turbo" in model_identifier: + if "3.5" in model_identifier and "turbo" in model_identifier: tiktoken_identifier = "gpt-3.5-turbo" else: tiktoken_identifier = model_identifier @@ -350,7 +361,7 @@ def get_vocab(tokenizer): import torch model_identifier = sys.argv[1] - t = load_tokenizer(model_identifier) + t = tokenizer(model_identifier) to_tokenize = sys.argv[2] diff --git a/src/lmql/tests/tail_token_set.py b/src/lmql/tests/tail_token_set.py index 06b3448c..36b8c500 100644 --- a/src/lmql/tests/tail_token_set.py +++ b/src/lmql/tests/tail_token_set.py @@ -1,8 +1,8 @@ +import lmql from lmql.ops.token_set import * -from lmql.runtime.tokenizer import load_tokenizer from lmql.tests.expr_test_utils import run_all_tests -VocabularyMatcher.init(load_tokenizer("gpt2")) +VocabularyMatcher.init(lmql.tokenizer("gpt2")) def test_simple(): p1 = ntset("eos") diff --git a/src/lmql/tests/test_eq.py b/src/lmql/tests/test_eq.py index 190d3120..55311bed 100644 --- a/src/lmql/tests/test_eq.py +++ b/src/lmql/tests/test_eq.py @@ -1,8 +1,7 @@ import lmql from lmql.tests.expr_test_utils import run_all_tests -from lmql.runtime.tokenizer import load_tokenizer -t = load_tokenizer("gpt2") +t = lmql.tokenizer("gpt2") def token_diff(s1, s2): ids1 = t(s1)["input_ids"] diff --git a/src/lmql/tests/test_multi_tokenizer.py b/src/lmql/tests/test_multi_tokenizer.py index 1295fb8f..5cca3d4a 100644 --- a/src/lmql/tests/test_multi_tokenizer.py +++ b/src/lmql/tests/test_multi_tokenizer.py @@ -37,7 +37,7 @@ async def test_llama_from_gpt(): @lmql.query(model="chatgpt") async def cg(): '''lmql - "Hello[WORLD]" where len(TOKENS(WORLD)) < 4 + "Hello[WORLD]" where len(TOKENS(WORLD)) < 3 return WORLD ''' @@ -46,8 +46,8 @@ async def test_gpt35(): '''lmql "Hello[WORLD]" where len(TOKENS(WORLD)) == 4 r = [WORLD, cg()] - assert r == [", I am a", " Hello!"], "Expected {}, got {}".format( - [", I am a", "Hello!"], + assert r == [", I am a", " Hello!"], "Expected {}, got {}".format( + [", I am a", " Hello!"], r ) return WORLD diff --git a/src/lmql/tests/test_noprompt.py b/src/lmql/tests/test_noprompt.py index b16cc105..f89867d9 100644 --- a/src/lmql/tests/test_noprompt.py +++ b/src/lmql/tests/test_noprompt.py @@ -1,8 +1,7 @@ import lmql from expr_test_utils import run_all_tests -from lmql.runtime.tokenizer import load_tokenizer -t = load_tokenizer("gpt2") +t = lmql.tokenizer("gpt2") def token_diff(s1, s2): ids1 = t(s1)["input_ids"] diff --git a/src/lmql/tests/test_sample_queries.py b/src/lmql/tests/test_sample_queries.py index fdcf707e..e8413042 100644 --- a/src/lmql/tests/test_sample_queries.py +++ b/src/lmql/tests/test_sample_queries.py @@ -9,7 +9,6 @@ import io import termcolor -from lmql.runtime.tokenizer import load_tokenizer from lmql.runtime.stats import Stats # load queries by executing ../ui/playground/src/queries.js via node and getting the object of module.exports @@ -39,7 +38,7 @@ async def main(): queries = load_queries() stderr = sys.stderr - print("\nTokenizer Backend: ", type(load_tokenizer("text-davinci-003").tokenizer_impl).__name__, "\n") + print("\nTokenizer Backend: ", type(lmql.tokenizer("text-davinci-003").tokenizer_impl).__name__, "\n") api_stats = Stats("openai-api") diff --git a/src/lmql/tests/tiktoken_tsets.py 
b/src/lmql/tests/tiktoken_tsets.py index fc20dd90..610aaa63 100644 --- a/src/lmql/tests/tiktoken_tsets.py +++ b/src/lmql/tests/tiktoken_tsets.py @@ -1,14 +1,14 @@ +import lmql from lmql.ops.token_set import * -from lmql.runtime.tokenizer import load_tokenizer from lmql.runtime.tokenizers.tiktoken_tokenizer import TiktokenTokenizer from lmql.tests.expr_test_utils import run_all_tests -t = load_tokenizer("text-davinci-003") +t = lmql.tokenizer("text-davinci-003") assert type(t.tokenizer_impl) is TiktokenTokenizer VocabularyMatcher.init(t) def test_simple(): - t = load_tokenizer("text-davinci-003") + t = lmql.tokenizer("text-davinci-003") # simple eos # s = tset("eos")
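
Usage note (not part of the patch): this change promotes the tokenizer loader to the public API as lmql.tokenizer(...), replacing direct imports of lmql.runtime.tokenizer.load_tokenizer throughout the codebase and tests. A minimal sketch of the new call pattern, using only calls that appear in the updated tests and backends; the model identifier "gpt2" and the input text are illustrative placeholders.

import lmql

# load a tokenizer by model identifier via the new public API
t = lmql.tokenizer("gpt2")

# the tokenizer is callable and returns a dict with "input_ids",
# as used in src/lmql/tests/test_eq.py
ids = t("Hello, world!")["input_ids"]

# prepend the BOS token, mirroring src/lmql/models/lmtp/backends/__main__.py
input_ids = [t.bos_token_id] + ids

# inspect which backend implementation was selected (e.g. HuggingFace or tiktoken),
# as printed in src/lmql/tests/test_sample_queries.py
print(type(t.tokenizer_impl).__name__)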
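
For the compiler and fragment-parser changes above, a hedged sketch of the inline distribution clause they enable. The clause shape follows the compiler's own assertion ('distribution VAR in [val1, val2, ...]'); the model name, function name, prompt text and candidate values are illustrative placeholders only.

import lmql

@lmql.query(model="openai/text-davinci-003")
async def sentiment(review):
    '''lmql
    "Review: {review}\n"
    "Q: Is this review positive or negative?\n"
    "A: [SENTIMENT]" distribution SENTIMENT in ["positive", "negative"]
    '''

As with an inline where clause, the parser rewrites the trailing distribution keyword into an operator token ("or" instead of "and"), so the clause must follow a query string on the same line, and the interpreter requires the distribution variable to be the last statement in the prompt.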