diff --git a/docs/source/lib/generations.md b/docs/source/lib/generations.md
index 8bf76736..e6ac0657 100644
--- a/docs/source/lib/generations.md
+++ b/docs/source/lib/generations.md
@@ -1,6 +1,12 @@
-# Generations API
+# Generations API NEW
-The *Generations API* is a small library with the goal of providing high-level access to LMQL core features, such as inference backends, constrained generation, and advanced caching and scoring. The Generations API is designed to be easy to use, and does not require users to write any LMQL themselves.
+
A simple Python API for LMQL-based text generation and scoring.
+
+The *Generations API* is a lightweight library with the goal of providing high-level access to LMQL core features, such as various inference backends, constrained generation, and advanced caching and scoring. The API was designed to be easy to use and does not require users to write any LMQL themselves.
+
+
+
+## Overview
To illustrate the Generations API, let's look at a simple example of generating and scoring text using the `openai/gpt-3.5-turbo-instruct` model:
@@ -28,120 +34,129 @@ The snippet above demonstrates the different components of the Generations API:
- [**`lmql.LLM.generate(...)`**](#lmql-generate) is a simple function to generating text completions based on a given prompt. This can be helpful to quickly obtain single-step completions, or to generate a list of completions for a given prompt.
-- [**`lmql.LLM.score(...)`**](#lmql-score) allows you to directly access the scores, your model assigns to the tokenized representation of your input prompt and continuations. This can be helpful for tasks such as classification or ranking.
+- [**`lmql.LLM.score(...)`**](#lmql-score) allows you to directly access the scores, your model assigns to the tokenized representation of your input prompt and continuations. This can be helpful for tasks such as classification or ranking.
+
+ The result is an [`lmql.ScoringResult`](https://github.com/eth-sri/lmql/blob/main/src/lmql/api/scoring.py) object, which contains the scores for each continuation, as well as the prompt and continuations used for scoring. It provides a convenient interface for score aggregation, normalization and maximum selection.
**Compatibility** For more advanced use cases, the Generation API seamlessly blends with standard LMQL, allowing users to gradually adopt the full language runtime over time, if their use cases require it.
-**Implementation** Internally, the Generations API is implemented as a thin wrapper around LMQL, and thus benefits from all the features of LMQL, such as caching, parallelization, and more. The API is fully asynchronous, and can be used with any async framework, such as `asyncio`, `trio`, or `curio`. Alternatively, the API can also be used synchronously, using the `*_sync` variants of the API functions.
+**Implementation** Internally, the Generations API is implemented as a thin wrapper around LMQL, and thus benefits from all the features of LMQL, such as caching, parallelization, and more. The API is fully asynchronous, and should be used with `asyncio`. Alternatively, all API funcationality is also available synchronously, using the `*_sync` variants of the functions.
-## API Reference
+## `lmql.LLM` Objects
-The Generation API is available directly in the top-level namespace of the `lmql` module:
+At the core, `lmql.LLM` objects represent a specific language model and provide methods for generation and scoring. An `lmql.LLM` is instantiated using `lmql.model(...)` and can be passed [as-is to LMQL query programs](../language/models.rst#loading-models) or to the top-level [`lmql.generate`](#lmql-generate) and [`lmql.score`](#lmql-score) functions.
-### `lmql.generate(...)`
+### `LLM.generate(...)`
```python
-async def lmql.generate(
+async def generate(
+ self,
prompt: str,
max_tokens: Optional[int] = None,
- model: Optional[Union[LLM, str]] = None,
+ decoder: str,
**kwargs
) -> Union[str, List[str]]
```
-`lmql.generate` generates text based on a given prompt, using a given model as the generation backend.
+Generates a text completion based on a given prompt. Returns the full prompt + completion as one string.
**Arguments**
-- `prompt`: The prompt to generate from.
-- `max_tokens`: The maximum number of tokens to generate. If `None`, text is generated until the model returns an *end-of-sequence* token.
-- `model`: The model to use for generation. If `None`, the default model is used.
-- `**kwargs`: Additional keyword arguments are passed to the underlying LMQL query program. This can be used to specify the `decoder`, options like `chunksize` or `n`, or any other model or decoder arguments.
+- `prompt: str`: The prompt to generate from.
+- `max_tokens: Optional[int]`: The maximum number of tokens to generate. If `None`, text is generated until the model returns an *end-of-sequence* token.
+- `decoder: str`: The [decoding algorithm](../language/decoders.md) to use for generation. Defaults to `"argmax"`.
+- `**kwargs`: Additional keyword arguments that are passed to the underlying LMQL query program. These can be useful to specify options like `chunksize`, decoder arguments like `n`, or any other model or decoder-specific arguments.
+
+**Return Value** The function returns a string or a list of strings, depending on the decoder in use (`decoder=argmax` yields a single sequence, `decoder="sample", n=2` yields two sequences, etc.).
-**Return Value** The function returns a string or a list of strings, depending on the decoder in use (`argmax` yields a single sequence, `decoder="sample", n=2` yields two sequences, etc.).
+**Asynchronous** The function is asynchronous and should be used with [`asyncio`](https://docs.python.org/3/library/asyncio.html) and with `await`. When run in parallel, multiple generations will be batched and parallelized across multiple calls to the same model. For synchronous use, you can rely on [`LLM.generate_sync`](#llm-generate_sync), note however, that in this case, the function will block the current thread until generation is complete.
+### `LLM.generate_sync(...)`
-### `lmql.generate_sync(...)`
+```python
+def generate_sync(self, *args, **kwargs):
+```
-Synchronous version of `lmql.generate`.
+Synchronous version of [`lmql.LLM.generate`](#llm-generate).
-### `lmql.score`
+### `LLM.score(...)`
```python
async def score(
+ self,
prompt: str,
- values: Union[str, List[str]],
- model: Optional[Union[str, LLM]] = None,
- **kwargs
+ values: Union[str, List[str]]
) -> lmql.ScoringResult
```
-`lmql.score` scores different continuation `values` for a given `prompt`.
+Scores different continuation `values` for a given `prompt`.
-**Arguments**
+For instance `await m.score("Hello", ["World", "Apples", "Oranges"])` would score the continuations `"Hello World"`, `"Hello Apples"` and `"Hello Oranges"`.
-- `prompt`: The prompt to score from.
-- `values`: The values to score.
-- `model`: The model to use for scoring. If `None`, the default model is used.
+**Arguments**
-**Return Value** The function returns an `lmql.ScoringResult` object, which contains the scores for each value, as well as the prompt and values used for scoring.
+- `prompt`: The prompt to use as a common prefix for all continuations.
+- `values`: The continuation values to score. This can be a single string or a list of strings.
-### `lmql.score_sync(...)`
+**Return Value** The result is an [`lmql.ScoringResult`](https://github.com/eth-sri/lmql/blob/main/src/lmql/api/scoring.py) object, which contains the scores for each continuation, as well as the prompt and continuations used for scoring. It provides a convenient interface for score aggregation, normalization and maximum selection.
-Synchronous version of `lmql.score`.
+**Asynchronous** The function is asynchronous and should be used with [`asyncio`](https://docs.python.org/3/library/asyncio.html) and with `await`. When run in parallel, multiple generations will be batched and parallelized across multiple calls to the same model. For synchronous use, you can rely on [`LLM.score_sync`](#llm-score-sync).
-### `lmql.set_default_model(...)`
+### `LLM.score_sync(...)`
```python
-def set_default_model(model: Union[str, LLM])
+def score_sync(self, *args, **kwargs)
```
-Sets the model instance to be used when no 'from' clause or @lmql.query(model=) are specified.
-
-This applies globally in the current process.
+Synchronous version of [`lmql.LLM.score`](#llm-score).
-### `lmql.LLM` Objects
-`lmql.LLM` objects represent a specific model, and provide methods for generation and scoring. An `lmql.LLM` is instantiated using `lmql.model(...)` and can be passed as-is to LMQL query programs (in the `from` clause) or to the `lmql.generate` and `lmql.score` functions.
-
-```python
-def get_tokenizer(self) -> LMQLTokenizer
-```
+The Generation API is available directly in the top-level namespace of the `lmql` module:
-Returns the tokenizer used by the model.
+## `lmql.generate(...)`
```python
-async def generate(
- self,
+async def lmql.generate(
prompt: str,
max_tokens: Optional[int] = None,
+ model: Optional[Union[LLM, str]] = None,
**kwargs
) -> Union[str, List[str]]
```
-Model-bound version of [`lmql.generate`](#lmql-generate).
+`lmql.generate` generates text completions based on a given prompt and behaves just like [`LLM.generate`](#llm-generate),
+with the provided `model` instance or model name.
-```python
-def generate_sync(self, *args, **kwargs):
-```
+## `lmql.generate_sync(...)`
-Synchronous version of `lmql.LLM.generate`.
+Synchronous version of [`lmql.generate`](#lmql-generate).
+
+## `lmql.score(...)`
```python
async def score(
- self,
prompt: str,
values: Union[str, List[str]],
+ model: Optional[Union[str, LLM]] = None,
**kwargs
) -> lmql.ScoringResult
```
-Model-bound version of [`lmql.score`](#lmql-score).
+`lmql.score` scores different continuation `values` for a given `prompt` and behaves just like [`LLM.score`](#llm-score),
+with the provided `model` instance or model name.
+
+## `lmql.score_sync(...)`
+
+Synchronous version of [`lmql.score`](#lmql-score).
+
+## `lmql.set_default_model(...)`
```python
-def score_sync(self, *args, **kwargs)
+def set_default_model(model: Union[str, LLM])
```
-Synchronous version of `lmql.LLM.score`.
+Sets the model instance to be used when no 'from' clause or @lmql.query(model=) are specified.
+
+This applies globally in the current process.
diff --git a/src/lmql/api/scoring.py b/src/lmql/api/scoring.py
index 84f1b395..3ac00029 100644
--- a/src/lmql/api/scoring.py
+++ b/src/lmql/api/scoring.py
@@ -14,16 +14,23 @@ class ScoringResult:
Provides methods to aggregate scores and return the best continuation.
"""
- def __init__(self, prompt, continuations, seqs: List[dc.seq], model_identifier: str):
+ def __init__(self, prompt, continuations: List[str], num_value_tokens: List[int], seqs: List[dc.seq], model_identifier: str):
self.seqs = [s.expand() for s in seqs]
self.prompt = prompt
+ # the continuations that were scored
self.continuations = continuations
+ # per continuation, the number of tokens that originate from the appended continuation value
+ self.num_value_tokens = num_value_tokens
self.model_identifier = model_identifier
@property
- def token_scores(self):
+ def full_token_scores(self):
return [s.logprobs for s in self.seqs]
+ @property
+ def token_scores(self):
+ return [s.logprobs[-self.num_value_tokens[i]:] for i,s in enumerate(self.seqs)]
+
def scores(self, agg="sum", **kwargs):
"""
Returns the sequence scores per continuation.
@@ -83,10 +90,11 @@ async def dc_score(model: dc.DcModel, prompt, values, **kwargs):
prompt_seq = dc.seq(model.tokenizer.tokenize(prompt, asbytes=True))
value_ids = [model.tokenizer.tokenize(value, asbytes=True) for value in values]
+ num_value_ids = [len(ids) for ids in value_ids]
kwargs.pop("internal", None)
all_tokens = []
all_scores = []
kwargs["noscore"] = False
- return ScoringResult(prompt, values, await model.score([prompt_seq] * len(value_ids), value_ids, **kwargs), model.model_identifier)
\ No newline at end of file
+ return ScoringResult(prompt, values, num_value_ids, await model.score([prompt_seq] * len(value_ids), value_ids, **kwargs), model.model_identifier)
\ No newline at end of file
diff --git a/src/lmql/runtime/interpreter.py b/src/lmql/runtime/interpreter.py
index 271916ee..81563dae 100644
--- a/src/lmql/runtime/interpreter.py
+++ b/src/lmql/runtime/interpreter.py
@@ -190,11 +190,11 @@ async def get_return_value(self, *args):
return LMQLResult(self.state.prompt, await self.get_all_vars(),self.interpreter.distribution_variable, self.interpreter.distribution_values)
- async def score(self, *args, **kwargs):
+ async def score(self, values, **kwargs):
model = kwargs.get("model", None)
if model is not None:
- return await score(self.prompt, *args, **kwargs)
- return await dc_score(self.interpreter.dcmodel, self.prompt, *args, **kwargs)
+ return await score(self.prompt, values, **kwargs)
+ return await dc_score(self.interpreter.dcmodel, self.prompt, values, **kwargs)
@dataclass
class LMQLResult:
@@ -314,6 +314,11 @@ def set_model(self, model_handle: Union[str, LLM]):
self.model = model_handle
+ # prepare dcmodel
+ decoder_args = self.decoder_kwargs
+ self.model.adapter.decoder_args = {**decoder_args, **self.extra_kwargs}
+ self.dcmodel: dc.DcModel = self.model.adapter.get_dclib_model()
+
async def advance(self, state: PromptState):
if state.variable is not None:
return state
@@ -895,11 +900,6 @@ async def run(self, fct, *args, **kwargs):
tail=None)
self.root_state = await self.advance(self.root_state)
- # prepare dcmodel
- decoder_args = self.decoder_kwargs
- self.model.adapter.decoder_args = {**decoder_args, **self.extra_kwargs}
- self.dcmodel: dc.DcModel = self.model.adapter.get_dclib_model()
-
async def debug_out(decoder_step):
if PromptInterpreter.main != self:
return
@@ -934,7 +934,7 @@ async def debug_out(decoder_step):
# make sure that the initial prompt is not considered part of a variable
self.root_state = self.root_state.updated(variable_offset=n)
- decoder_args = decoder_args.copy()
+ decoder_args = self.decoder_kwargs.copy()
# pass processor as decoder argument
decoder_args["modern_logits_processor"] = self.where_processor
diff --git a/src/lmql/tests/expr_test_utils.py b/src/lmql/tests/expr_test_utils.py
index c09f3794..638c6277 100644
--- a/src/lmql/tests/expr_test_utils.py
+++ b/src/lmql/tests/expr_test_utils.py
@@ -166,6 +166,7 @@ def run_all_tests(g):
tb = traceback.extract_tb(e.__traceback__)
tb = "\n".join(traceback.format_list(tb))
termcolor.cprint("FAILED\n{}".format(tb), "red")
+ termcolor.cprint("AssertionError: {}".format(e), "red")
print(e)
# wait for all tasks to finish
diff --git a/src/lmql/tests/test_api.py b/src/lmql/tests/test_api.py
index b41e6dc9..7c95051b 100644
--- a/src/lmql/tests/test_api.py
+++ b/src/lmql/tests/test_api.py
@@ -1,4 +1,6 @@
import lmql
+import numpy as np
+
from lmql.tests.expr_test_utils import run_all_tests
def test_generate_sync():
@@ -32,7 +34,8 @@ def test_score_sync():
assert type(result) is lmql.ScoringResult
assert len(result.seqs) == 1
- assert len(result.token_scores) == 1 and result.token_scores[0].shape == (2,)
+ assert len(result.token_scores) == 1 and result.token_scores[0].shape == (1,)
+ assert len(result.full_token_scores) == 1 and np.array(result.full_token_scores[0]).shape == (2,)
assert result.logprobs().shape == (1,)
@@ -43,7 +46,8 @@ def test_llm_score_two():
assert type(result) is lmql.ScoringResult
assert len(result.seqs) == 2
- assert len(result.token_scores) == 2 and result.token_scores[0].shape == (2,)
+ assert len(result.token_scores) == 2 and result.token_scores[0].shape == (1,)
+ assert len(result.full_token_scores) == 2 and np.array(result.full_token_scores[0]).shape == (2,)
assert result.argmax() in ["World", "Test"]
@@ -58,7 +62,7 @@ def test_llm_openai():
print("Skipping test_api.test_llm_openai because no OpenAI API configuration could be found.")
return
- m = lmql.model("openai/gpt-3.5-turbo-instruct")
+ m = lmql.model("openai/gpt-3.5-turbo-instruct", silent=True)
assert m.score_sync("Hello", ["World", "Test"]).argmax() == "World"
if __name__ == "__main__":