* more liberal distribution parsing
* improved sample-consistent OpenAI caching
* extended output writer
* lmql.tokenizer
* Chat API: use correct tokenizer
* allow OpenAI tokenizer configuration
lbeurerkellner committed Oct 1, 2023
1 parent 9b23d0f commit defd690
Showing 20 changed files with 189 additions and 82 deletions.
1 change: 1 addition & 0 deletions src/lmql/api/__init__.py
@@ -13,6 +13,7 @@
from .scoring import ScoringResult
from .serve import serve
from inspect import *
from lmql.runtime.tokenizer import tokenizer

async def generate(prompt: str, max_tokens: Optional[int] = None, model: Optional[Union[LLM, str]] = None, **kwargs):
"""
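The `lmql.runtime.tokenizer.tokenizer` export added here becomes available as `lmql.tokenizer`, as used elsewhere in this commit (see `backends/__main__.py` below). A minimal usage sketch, assuming the new export behaves like the `load_tokenizer` calls it replaces; the model identifier is only an example:

```python
import lmql

# Minimal sketch of the new top-level tokenizer export; the model identifier is an
# example taken from the backend test script further down in this commit.
t = lmql.tokenizer("huggyllama/llama-7b")
ids = [t.bos_token_id] + t("Hello world")["input_ids"]
print(ids)
```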
2 changes: 1 addition & 1 deletion src/lmql/api/scoring.py
@@ -78,7 +78,7 @@ def argmax(self, agg="sum") -> str:

def __str__(self):
return "lmql.ScoringResult(model='{}')\n".format(self.model_identifier) + \
"\n".join([f"-{c}: {score}" for c,score in zip(self.continuations, self.scores(agg="sum"))])
"\n".join([f"-{str([c])[1:-1]}: {score}" for c,score in zip(self.continuations, self.scores(agg="sum"))])

async def dc_score(model: dc.DcModel, prompt, values, **kwargs):
"""
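The `str([c])[1:-1]` change above prints each continuation in its repr form, so multi-line continuations no longer break the listing. A small standalone illustration:

```python
# Wrapping the continuation in a one-element list and stripping the brackets yields
# its repr, so newlines and quotes show up escaped on a single output line.
c, score = "first line\nsecond line", 0.5
print(f"-{c}: {score}")               # old rendering: spills across two lines
print(f"-{str([c])[1:-1]}: {score}")  # new rendering: -'first line\nsecond line': 0.5
```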
29 changes: 25 additions & 4 deletions src/lmql/language/compiler.py
@@ -11,6 +11,7 @@
import lmql.runtime.lmql_runtime as lmql_runtime
from lmql.language.fragment_parser import (FragmentParserError,
LanguageFragmentParser,
LMQLDistributionClause,
double_unescape_str,
LMQLDecoderConfiguration, LMQLQuery)
from lmql.language.qstrings import (DistributionVariable, FExpression,
@@ -90,9 +91,10 @@ def __init__(self):
self.prologue_vars = set()
self.free_vars = set()
self.written_vars = set()

self.defined_constraints = set()

self.query = None

def scope_prologue(self, query: LMQLQuery):
if query.prologue is None: return

@@ -111,6 +113,8 @@ def scope(self, query: LMQLQuery):
# collect defined vars in prologue
self.scope_prologue(query)

self.query = query

# collect defined vars in prompt
for p in query.prompt:
self.visit(p)
@@ -144,6 +148,16 @@ def visit_BoolOp(self, node: ast.BoolOp) -> Any:
self.scope_Constant(node.values[0])
for constraint in node.values[1:]:
self.visit_where(constraint)
elif is_query_string_with_distribution(node):
assert len(node.values) == 2, "compiler error: distribution clause must be an expression of shape 'distribution VAR in [val1, val2, ...]'"
distribution_in_clause = node.values[1]
assert isinstance(distribution_in_clause, ast.Compare), "compiler error: distribution clause must be an expression of shape 'distribution VAR in [val1, val2, ...]'"
var = distribution_in_clause.left
assert isinstance(var, ast.Name), "compiler error: distribution clause must be an expression of shape 'distribution VAR in [val1, val2, ...]'"
self.distribution_vars = set([var.id])
assert len(distribution_in_clause.comparators) == 1, "compiler error: distribution clause must be an expression of shape 'distribution VAR in [val1, val2, ...]'"
self.query.distribution = LMQLDistributionClause(var.id, distribution_in_clause.comparators[0])
self.scope_Constant(node.values[0])
else:
super().generic_visit(node)

@@ -276,6 +290,12 @@ def is_query_string_with_constraints(node: ast.BoolOp):
left_most_operand = node.values[0]
return type(left_most_operand) is ast.Constant and type(left_most_operand.value) is str and isinstance(node.op, ast.And)

def is_query_string_with_distribution(node: ast.BoolOp):
if len(node.values) < 1:
return False
left_most_operand = node.values[0]
return type(left_most_operand) is ast.Constant and type(left_most_operand.value) is str and isinstance(node.op, ast.Or)

def attr(s):
names = s.split(".")
element = ast.Name(names[0], ast.Load())
@@ -357,6 +377,9 @@ def visit_BoolOp(self, node: ast.BoolOp) -> Any:
elif len(node.values[1:]) == 0:
constraints_expression = None
return self.transform_Constant(left_most_operand, constraints = constraints_expression)
elif is_query_string_with_distribution(node):
left_most_operand = node.values[0]
return self.transform_Constant(left_most_operand)
return self.generic_visit(node)

def visit_FunctionDef(self, node: FunctionDef) -> Any:
@@ -934,6 +957,4 @@ def compile(self, filepath):

return LMQLModule(output_file, lmql_code=lmql_code, output_variables=[v for v in scope.defined_vars])
except FragmentParserError as e:
sys.stderr.write("error: " + str(e) + "\n")
sys.exit(1)

raise RuntimeError("parsing error: {}.\nFailed when parsing:\n {}".format(e, lmql_code))
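The new `is_query_string_with_distribution` branch handles query strings whose inline `distribution` keyword has already been rewritten to `or` by the fragment parser (next file). A standalone sketch of the expression shape the scoping pass then destructures; the variable name and candidate values are illustrative:

```python
import ast

# After the keyword rewrite, an inline distribution clause parses as a BoolOp(Or)
# whose left operand is the query string and whose right operand is a Compare of
# shape 'VAR in [val1, val2, ...]'.
node = ast.parse('"Sentiment: [CLS]" or CLS in ["positive", "negative"]', mode="eval").body
assert isinstance(node, ast.BoolOp) and isinstance(node.op, ast.Or)
query_string, in_clause = node.values
assert isinstance(query_string, ast.Constant) and isinstance(query_string.value, str)
assert isinstance(in_clause, ast.Compare) and isinstance(in_clause.left, ast.Name)
print(in_clause.left.id)                   # distribution variable: CLS
print(ast.dump(in_clause.comparators[0]))  # list of candidate values
```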
19 changes: 14 additions & 5 deletions src/lmql/language/fragment_parser.py
@@ -164,6 +164,7 @@ def parse(self, readline):

self.prologue_transform()
self.inline_where_transform()
self.inline_distribution_transform()
self.ast_parse()
self.syntax_validation()
self.ast_transform()
@@ -177,7 +178,14 @@ def inline_where_transform(self):
lookahead = prompt_tokens[i+1]
if tok.type == tokenize.STRING and lookahead.type == tokenize.NAME and lookahead.string == "where":
prompt_tokens[i+1] = tokenize.TokenInfo(type=tokenize.OP, string="and", start=lookahead.start, end=lookahead.end, line=lookahead.line)


def inline_distribution_transform(self):
prompt_tokens = self.query.prompt_str
for i in range(len(prompt_tokens) - 1):
tok = prompt_tokens[i]
lookahead = prompt_tokens[i+1]
if tok.type == tokenize.STRING and lookahead.type == tokenize.NAME and lookahead.string == "distribution":
prompt_tokens[i+1] = tokenize.TokenInfo(type=tokenize.OP, string="or", start=lookahead.start, end=lookahead.end, line=lookahead.line)

def prologue_transform(self):
# translate prologue tokens into str
@@ -251,7 +259,7 @@ def digest(self, tok):
self.state = "decode"
return

if is_keyword(tok, "where"):
if is_keyword(tok, "where") or is_keyword(tok, "distribution"):
self.query.prompt_str = self.query.prologue + [tok]
self.query.prologue = []
self.state = "prompt"
@@ -279,16 +287,17 @@ def digest(self, tok):
if self.query.prompt_str[-1].type != tokenize.STRING:
self.state = "where"
return

if is_keyword(tok, "FROM"):
self.state = "from"
return
if is_keyword(tok, "SCORING"):
self.state = "scoring"
return
if is_keyword(tok, "DISTRIBUTION"):
self.state = "distribution"
return
if self.query.prompt_str[-1].type != tokenize.STRING:
self.state = "distribution"
return

# if last token is NAME and current is str
if len(self.query.prompt_str) > 0 and self.query.prompt_str[-1].type == tokenize.NAME and \
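Taken together with the compiler change above, this transform appears to allow an inline `distribution` clause directly after a query string, mirroring inline `where`. A hypothetical query sketch; the decoder, prompt, and model identifier are illustrative assumptions, not taken from this commit:

```python
import lmql

# Hypothetical sketch of the inline distribution syntax enabled by this transform:
# 'distribution VAR in [...]' may now follow the query string directly.
@lmql.query
async def sentiment(review):
    '''lmql
    argmax
        "Review: {review}\n"
        "Sentiment: [CLS]" distribution CLS in ["positive", "neutral", "negative"]
    from
        "openai/text-davinci-003"
    '''
```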
4 changes: 2 additions & 2 deletions src/lmql/models/lmtp/backends/__main__.py
@@ -3,15 +3,15 @@

from lmql.models.lmtp.backends import LMTPModel
from lmql.models.lmtp.lmtp_scheduler import TokenStreamer
from lmql.runtime.tokenizer import load_tokenizer
import lmql

import transformers

# simple 'main' for testing backends
if __name__ == "__main__":
backend = sys.argv[1]
model: LMTPModel = LMTPModel.load(backend)
t = load_tokenizer("huggyllama/llama-7b")
t = lmql.tokenizer("huggyllama/llama-7b")

s = sys.argv[2]
input_ids = [t.bos_token_id] + t(s)["input_ids"]
4 changes: 2 additions & 2 deletions src/lmql/models/lmtp/lmtp_client.py
@@ -112,10 +112,10 @@ async def interactive_client():

async with aiohttp.ClientSession() as session:
async with session.ws_connect('http://workstation:8888') as ws:
from lmql.runtime.tokenizer import load_tokenizer
from lmql.runtime.tokenizer import tokenizer

model = sys.argv[1]
tokenizer = load_tokenizer(model)
tokenizer = tokenizer(model)

client = LMTPWebSocketClient(model, ws)
client.connect()
4 changes: 2 additions & 2 deletions src/lmql/models/lmtp/lmtp_dcmodel.py
@@ -4,7 +4,7 @@
"""

from lmql.runtime.dclib.dclib_model import DcModel
from lmql.runtime.tokenizer import load_tokenizer
from lmql.runtime.tokenizer import tokenizer
from .lmtp_async import LMTPAsyncClient
import lmql.runtime.dclib as dc
import asyncio
@@ -496,7 +496,7 @@ def __init__(self) -> None:

def get_tokenizer(self):
if self._tokenizer is None:
self._tokenizer = load_tokenizer(this.tokenizer_identifier, **this.kwargs)
self._tokenizer = tokenizer(this.tokenizer_identifier, **this.kwargs)
self.served_model = self
return self._tokenizer

6 changes: 3 additions & 3 deletions src/lmql/models/lmtp/lmtp_langchain.py
@@ -18,8 +18,8 @@
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema import LLMResult

from lmql.runtime.tokenizer import LMQLTokenizer, load_tokenizer
from lmql.runtime.model_registry import LMQLModelRegistry
import lmql
from lmql.runtime.tokenizer import LMQLTokenizer

if TYPE_CHECKING:
from tenacity import RetryCallState
@@ -232,7 +232,7 @@ def _get_params(self, _kwarg_dict: Dict[str, Any]) -> Any:

def _get_tokenizer(self) -> LMQLTokenizer:
if self.tokenizer is None:
self.tokenizer = LMQLModelRegistry.get(self.model).get_tokenizer()
self.tokenizer = lmql.model(self.model).get_tokenizer()
return self.tokenizer

def _call(
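The langchain adapter now resolves its tokenizer through the public `lmql.model(...)` handle instead of the internal `LMQLModelRegistry`. A minimal sketch of that call path outside langchain; the model identifier is only an example:

```python
import lmql

# Resolve a tokenizer via the public model handle, as _get_tokenizer now does.
tok = lmql.model("openai/gpt-3.5-turbo").get_tokenizer()
ids = tok("Hello world")["input_ids"]
print(len(ids), ids)
```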
22 changes: 11 additions & 11 deletions src/lmql/runtime/bopenai/openai_api.py
@@ -13,7 +13,6 @@
import time
import asyncio

from lmql.runtime.tokenizer import load_tokenizer
from lmql.runtime.stats import Stats
from lmql.models.model_info import model_info

@@ -72,13 +71,7 @@ async def complete(**kwargs):
global tokenizers
tokenizers = {}

def tokenize(text, model, openai_byte_encoding=False):
global tokenizers
if not model in tokenizers:
tokenizer = load_tokenizer("tiktoken:" + model)
tokenizers[model] = tokenizer
else:
tokenizer = tokenizers[model]
def tokenize(text, tokenizer, openai_byte_encoding=False):
ids = tokenizer(text)["input_ids"]
raw = tokenizer.decode_bytes(ids)
if openai_byte_encoding:
@@ -175,9 +168,12 @@ async def chat_api(**kwargs):
num_prompts = len(kwargs["prompt"])
max_tokens = kwargs.get("max_tokens", 0)
model = kwargs["model"]
api_config = kwargs.get("api_config", {})
tokenizer = api_config.get("tokenizer")
assert tokenizer is not None, "internal error: chat_api expects an 'api_config' with a 'tokenizer: LMQLTokenizer' mapping in your API payload"

assert "logit_bias" not in kwargs.keys(), f"Chat API models do not support advanced constraining of the output, please use no or less complicated constraints."
prompt_tokens = tokenize(kwargs["prompt"][0], model=model, openai_byte_encoding=True)
prompt_tokens = tokenize(kwargs["prompt"][0], tokenizer=tokenizer, openai_byte_encoding=True)

timeout = kwargs.pop("timeout", 1.5)

Expand Down Expand Up @@ -229,6 +225,8 @@ async def chat_api(**kwargs):
del kwargs["prompt"]
kwargs["messages"] = messages

needs_space = True # messages[-1]["content"][-1] != " "

del kwargs["logprobs"]

async with CapacitySemaphore(num_prompts * max_tokens):
@@ -237,7 +235,6 @@
stream_start = time.time()

async with aiohttp.ClientSession() as session:
api_config = kwargs.get("api_config", {})
endpoint, headers = get_endpoint_and_headers(kwargs)

if api_config.get("verbose", False) or os.environ.get("LMQL_VERBOSE", "0") == "1" or api_config.get("chatty_openai", False):
@@ -327,7 +324,10 @@ async def chunk_timer():
})
continue
text = delta["content"]
tokens = tokenize((" " if received_text == "" else "") + text, model=model, openai_byte_encoding=True)
if len(text) == 0:
continue

tokens = tokenize((" " if received_text == "" and needs_space else "") + text, tokenizer=tokenizer, openai_byte_encoding=True)
received_text += text

# convert tokens to OpenAI format
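`chat_api` now receives its tokenizer through `api_config` instead of loading one from the model name inside `tokenize`. A sketch of the tiktoken-backed tokenizer it would typically be handed; the `tiktoken:gpt-3.5-turbo` identifier mirrors the removed load-by-model-name code and is an assumption here:

```python
import lmql

# Tiktoken-backed tokenizer of the kind chat_api now expects in api_config["tokenizer"].
tok = lmql.tokenizer("tiktoken:gpt-3.5-turbo")
ids = tok("Hello world")["input_ids"]
print(ids)
print(tok.decode_bytes(ids))  # raw byte pieces, as used for OpenAI byte-level encoding
```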
16 changes: 12 additions & 4 deletions src/lmql/runtime/dclib/dclib_cache.py
@@ -143,14 +143,18 @@ async def get_mask(self, s: DecoderSequence, **kwargs):

async def get_keys(self, s: DecoderSequence, edge_type: str, **kwargs):
kwargs = {**self.delegate.model_args, **kwargs}
reuse_context = kwargs.get("cache_reuse_context", set())

keys = []

# check for sample-id
if s.data("dc-edge-type"):
if s.data("dc-edge-type") and edge_type is not None:
dc_edge_type = s.data("dc-edge-type")
# if the edge type aligns with dc-edge-type, use that instead (includes a unique sample id if available)
if s.data("dc-edge-type").startswith(edge_type):
edge_type = s.data("dc-edge-type")
if dc_edge_type.startswith(edge_type):
if not dc_edge_type in reuse_context:
reuse_context.add(dc_edge_type)
edge_type = dc_edge_type

# compute logits mask
mask = (await self.get_mask(s, **kwargs)).logits_mask[0]
@@ -269,7 +273,11 @@ async def op_sample(seqs):
temperature = kwargs.get('temperature', 1.0)
sampling_mode = "top-1" if temperature == 0.0 else "sample-{}".format(temperature)

cache_entries = [await self.get_cache(s, sampling_mode, user_data=True, **kwargs) for s in seqs]
# make sure that each uniquely sampled trajectory in the cache, cannot be used
# twice as a result of sampling (e.g. when sampling multiple times from the same sequence)
cache_reuse_context = set()

cache_entries = [await self.get_cache(s, sampling_mode, user_data=True, cache_reuse_context=cache_reuse_context, **kwargs) for s in seqs]
cached_cont = [e[1] for e in cache_entries]
cache_keys = [e[0] for e in cache_entries]

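The `cache_reuse_context` set is shared across all sequences of one `op_sample` call, so a sample-specific cache edge is handed out at most once per sampling operation. A minimal standalone sketch of that guard; the edge-type strings (including the id suffix) are illustrative:

```python
# Within one sampling operation, a shared set tracks which sample-specific cache edges
# have already been used; a repeated request falls back to the generic edge type and
# therefore to a fresh sample instead of replaying the cached one.
def resolve_edge_type(edge_type, dc_edge_type, reuse_context):
    if dc_edge_type and dc_edge_type.startswith(edge_type):
        if dc_edge_type not in reuse_context:
            reuse_context.add(dc_edge_type)
            return dc_edge_type
    return edge_type

ctx = set()
print(resolve_edge_type("sample-0.7", "sample-0.7-id42", ctx))  # sample-0.7-id42
print(resolve_edge_type("sample-0.7", "sample-0.7-id42", ctx))  # sample-0.7 (already used)
```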
1 change: 0 additions & 1 deletion src/lmql/runtime/dclib/decoders.py
@@ -2,7 +2,6 @@
import numpy as np
from typing import List, Any, Union, Optional, Dict

from lmql.runtime.tokenizer import load_tokenizer
from lmql.runtime.dclib.dclib_array import DataArray, sum_scorer, alpha_length_normalized, alpha_length_normalized_det
from lmql.runtime.dclib.dclib_seq import next_is_deterministic
import lmql.runtime.dclib as dc
(diffs for the remaining changed files were not loaded)
