Commit

api: scoring tweaks
lbeurerkellner committed Sep 22, 2023
1 parent 5f454f0 commit 2e69efe
Showing 5 changed files with 95 additions and 67 deletions.
119 changes: 67 additions & 52 deletions docs/source/lib/generations.md
@@ -1,6 +1,12 @@
# Generations API
# Generations API <span class="tag" data-tag-name="functions">NEW</span>

The *Generations API* is a small library with the goal of providing high-level access to LMQL core features, such as inference backends, constrained generation, and advanced caching and scoring. The Generations API is designed to be easy to use, and does not require users to write any LMQL themselves.
<div class="subtitle">A simple Python API for LMQL-based text generation and scoring.</div>

The *Generations API* is a lightweight library with the goal of providing high-level access to LMQL core features, such as various inference backends, constrained generation, and advanced caching and scoring. The API was designed to be easy to use and does not require users to write any LMQL themselves.

<div style="margin-bottom: -10pt"></div>

## Overview

To illustrate the Generations API, let's look at a simple example of generating and scoring text using the `openai/gpt-3.5-turbo-instruct` model:

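A minimal sketch of this kind of usage, based on the API described below (prompt strings are illustrative, not the original example):

```python
import lmql

# instantiate a model handle
m = lmql.model("openai/gpt-3.5-turbo-instruct")

# generate a completion; the full prompt + completion is returned as one string
completion = m.generate_sync("Say 'this is a test':", max_tokens=10)

# score several continuations against a shared prompt
result = m.score_sync("Hello", ["World", "Apples", "Oranges"])
print(result.argmax())  # continuation with the highest aggregate score
```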
@@ -28,120 +34,129 @@ The snippet above demonstrates the different components of the Generations API:

- [**`lmql.LLM.generate(...)`**](#lmql-generate) is a simple function for generating text completions based on a given prompt. This can be helpful to quickly obtain single-step completions, or to generate a list of completions for a given prompt.

- [**`lmql.LLM.score(...)`**](#lmql-score) allows you to directly access the scores your model assigns to the tokenized representation of your input prompt and continuations. This can be helpful for tasks such as classification or ranking.

The result is an [`lmql.ScoringResult`](https://github.com/eth-sri/lmql/blob/main/src/lmql/api/scoring.py) object, which contains the scores for each continuation, as well as the prompt and continuations used for scoring. It provides a convenient interface for score aggregation, normalization and maximum selection.


**Compatibility** For more advanced use cases, the Generations API seamlessly blends with standard LMQL, allowing users to gradually adopt the full language runtime over time if their use case requires it.

**Implementation** Internally, the Generations API is implemented as a thin wrapper around LMQL, and thus benefits from all the features of LMQL, such as caching, parallelization, and more. The API is fully asynchronous, and can be used with any async framework, such as `asyncio`, `trio`, or `curio`. Alternatively, the API can also be used synchronously, using the `*_sync` variants of the API functions.
**Implementation** Internally, the Generations API is implemented as a thin wrapper around LMQL, and thus benefits from all the features of LMQL, such as caching, parallelization, and more. The API is fully asynchronous and should be used with `asyncio`. Alternatively, all API functionality is also available synchronously, using the `*_sync` variants of the functions.

## API Reference
## `lmql.LLM` Objects

The Generation API is available directly in the top-level namespace of the `lmql` module:
At the core, `lmql.LLM` objects represent a specific language model and provide methods for generation and scoring. An `lmql.LLM` is instantiated using `lmql.model(...)` and can be passed [as-is to LMQL query programs](../language/models.rst#loading-models) or to the top-level [`lmql.generate`](#lmql-generate) and [`lmql.score`](#lmql-score) functions.

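For illustration, a minimal sketch of obtaining an `LLM` handle and passing it to a top-level function (the prompt string is illustrative):

```python
import lmql

# instantiate a model handle
m = lmql.model("openai/gpt-3.5-turbo-instruct")

# the same handle can be passed to the top-level convenience functions
text = lmql.generate_sync("Greet the user:", max_tokens=10, model=m)
```
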
### `lmql.generate(...)`
### `LLM.generate(...)`

```python
async def lmql.generate(
async def generate(
self,
prompt: str,
max_tokens: Optional[int] = None,
model: Optional[Union[LLM, str]] = None,
decoder: str = "argmax",
**kwargs
) -> Union[str, List[str]]
```

`lmql.generate` generates text based on a given prompt, using a given model as the generation backend.
Generates a text completion based on a given prompt. Returns the full prompt + completion as one string.

**Arguments**

- `prompt`: The prompt to generate from.
- `max_tokens`: The maximum number of tokens to generate. If `None`, text is generated until the model returns an *end-of-sequence* token.
- `model`: The model to use for generation. If `None`, the default model is used.
- `**kwargs`: Additional keyword arguments are passed to the underlying LMQL query program. This can be used to specify the `decoder`, options like `chunksize` or `n`, or any other model or decoder arguments.
- `prompt: str`: The prompt to generate from.
- `max_tokens: Optional[int]`: The maximum number of tokens to generate. If `None`, text is generated until the model returns an *end-of-sequence* token.
- `decoder: str`: The [decoding algorithm](../language/decoders.md) to use for generation. Defaults to `"argmax"`.
- `**kwargs`: Additional keyword arguments that are passed to the underlying LMQL query program. These can be useful to specify options like `chunksize`, decoder arguments like `n`, or any other model or decoder-specific arguments.

**Return Value** The function returns a string or a list of strings, depending on the decoder in use (`decoder=argmax` yields a single sequence, `decoder="sample", n=2` yields two sequences, etc.).

**Return Value** The function returns a string or a list of strings, depending on the decoder in use (`argmax` yields a single sequence, `decoder="sample", n=2` yields two sequences, etc.).
**Asynchronous** The function is asynchronous and should be used with [`asyncio`](https://docs.python.org/3/library/asyncio.html) and `await`. When run in parallel, multiple generations are batched and parallelized across calls to the same model. For synchronous use, you can rely on [`LLM.generate_sync`](#llm-generate_sync); note, however, that in this case the function blocks the current thread until generation is complete.

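For illustration, a minimal asynchronous sketch (prompts are illustrative); issuing several generations concurrently allows them to be batched as described above:

```python
import asyncio
import lmql

async def main():
    m = lmql.model("openai/gpt-3.5-turbo-instruct")
    # concurrent calls to the same model are batched and parallelized
    results = await asyncio.gather(
        m.generate("A fun fact about ducks:", max_tokens=20),
        m.generate("A fun fact about geese:", max_tokens=20),
    )
    print(results)

asyncio.run(main())
```
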
### `LLM.generate_sync(...)`

### `lmql.generate_sync(...)`
```python
def generate_sync(self, *args, **kwargs):
```

Synchronous version of `lmql.generate`.
Synchronous version of [`lmql.LLM.generate`](#llm-generate).

### `lmql.score`
### `LLM.score(...)`

```python
async def score(
self,
prompt: str,
values: Union[str, List[str]],
model: Optional[Union[str, LLM]] = None,
**kwargs
values: Union[str, List[str]]
) -> lmql.ScoringResult
```

`lmql.score` scores different continuation `values` for a given `prompt`.
Scores different continuation `values` for a given `prompt`.

**Arguments**
For instance, `await m.score("Hello", ["World", "Apples", "Oranges"])` would score the continuations `"Hello World"`, `"Hello Apples"` and `"Hello Oranges"`.

- `prompt`: The prompt to score from.
- `values`: The values to score.
- `model`: The model to use for scoring. If `None`, the default model is used.
**Arguments**

**Return Value** The function returns an `lmql.ScoringResult` object, which contains the scores for each value, as well as the prompt and values used for scoring.
- `prompt`: The prompt to use as a common prefix for all continuations.
- `values`: The continuation values to score. This can be a single string or a list of strings.

### `lmql.score_sync(...)`
**Return Value** The result is an [`lmql.ScoringResult`](https://github.com/eth-sri/lmql/blob/main/src/lmql/api/scoring.py) object, which contains the scores for each continuation, as well as the prompt and continuations used for scoring. It provides a convenient interface for score aggregation, normalization and maximum selection.

Synchronous version of `lmql.score`.
**Asynchronous** The function is asynchronous and should be used with [`asyncio`](https://docs.python.org/3/library/asyncio.html) and with `await`. When run in parallel, multiple generations will be batched and parallelized across multiple calls to the same model. For synchronous use, you can rely on [`LLM.score_sync`](#llm-score-sync).

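For illustration, a minimal synchronous sketch of scoring and inspecting the result (continuation values are illustrative):

```python
import lmql

m = lmql.model("openai/gpt-3.5-turbo-instruct")
result = m.score_sync("Hello", ["World", "Apples", "Oranges"])

print(result.argmax())           # continuation with the highest aggregate score
print(result.scores(agg="sum"))  # summed per-continuation sequence scores
```
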
### `lmql.set_default_model(...)`
### `LLM.score_sync(...)`

```python
def set_default_model(model: Union[str, LLM])
def score_sync(self, *args, **kwargs)
```

Sets the model instance to be used when no 'from' clause or `@lmql.query(model=<model>)` is specified.

This applies globally in the current process.
Synchronous version of [`lmql.LLM.score`](#llm-score).


### `lmql.LLM` Objects

`lmql.LLM` objects represent a specific model, and provide methods for generation and scoring. An `lmql.LLM` is instantiated using `lmql.model(...)` and can be passed as-is to LMQL query programs (in the `from` clause) or to the `lmql.generate` and `lmql.score` functions.

```python
def get_tokenizer(self) -> LMQLTokenizer
```
The Generation API is available directly in the top-level namespace of the `lmql` module:

Returns the tokenizer used by the model.
## `lmql.generate(...)`

```python
async def generate(
self,
async def lmql.generate(
prompt: str,
max_tokens: Optional[int] = None,
model: Optional[Union[LLM, str]] = None,
**kwargs
) -> Union[str, List[str]]
```

Model-bound version of [`lmql.generate`](#lmql-generate).
`lmql.generate` generates text completions based on a given prompt and behaves just like [`LLM.generate`](#llm-generate),
with the provided `model` instance or model name.

```python
def generate_sync(self, *args, **kwargs):
```
## `lmql.generate_sync(...)`

Synchronous version of `lmql.LLM.generate`.
Synchronous version of [`lmql.generate`](#lmql-generate).

## `lmql.score(...)`

```python
async def score(
self,
prompt: str,
values: Union[str, List[str]],
model: Optional[Union[str, LLM]] = None,
**kwargs
) -> lmql.ScoringResult
```

Model-bound version of [`lmql.score`](#lmql-score).
`lmql.score` scores different continuation `values` for a given `prompt` and behaves just like [`LLM.score`](#llm-score),
with the provided `model` instance or model name.

## `lmql.score_sync(...)`

Synchronous version of [`lmql.score`](#lmql-score).

## `lmql.set_default_model(...)`

```python
def score_sync(self, *args, **kwargs)
def set_default_model(model: Union[str, LLM])
```

Synchronous version of `lmql.LLM.score`.
Sets the model instance to be used when no 'from' clause or `@lmql.query(model=<model>)` is specified.

This applies globally in the current process.
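A minimal sketch of setting a process-wide default (the model name is illustrative):

```python
import lmql

# applies globally in the current process
lmql.set_default_model("openai/gpt-3.5-turbo-instruct")

# subsequent calls that do not specify a model use the default
print(lmql.generate_sync("Hello", max_tokens=5))
```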
14 changes: 11 additions & 3 deletions src/lmql/api/scoring.py
@@ -14,16 +14,23 @@ class ScoringResult:
Provides methods to aggregate scores and return the best continuation.
"""

def __init__(self, prompt, continuations, seqs: List[dc.seq], model_identifier: str):
def __init__(self, prompt, continuations: List[str], num_value_tokens: List[int], seqs: List[dc.seq], model_identifier: str):
self.seqs = [s.expand() for s in seqs]
self.prompt = prompt
# the continuations that were scored
self.continuations = continuations
# per continuation, the number of tokens that originate from the appended continuation value
self.num_value_tokens = num_value_tokens
self.model_identifier = model_identifier

@property
def token_scores(self):
def full_token_scores(self):
return [s.logprobs for s in self.seqs]

@property
def token_scores(self):
return [s.logprobs[-self.num_value_tokens[i]:] for i,s in enumerate(self.seqs)]

def scores(self, agg="sum", **kwargs):
"""
Returns the sequence scores per continuation.
@@ -83,10 +90,11 @@ async def dc_score(model: dc.DcModel, prompt, values, **kwargs):

prompt_seq = dc.seq(model.tokenizer.tokenize(prompt, asbytes=True))
value_ids = [model.tokenizer.tokenize(value, asbytes=True) for value in values]
num_value_ids = [len(ids) for ids in value_ids]

kwargs.pop("internal", None)
all_tokens = []
all_scores = []

kwargs["noscore"] = False
return ScoringResult(prompt, values, await model.score([prompt_seq] * len(value_ids), value_ids, **kwargs), model.model_identifier)
return ScoringResult(prompt, values, num_value_ids, await model.score([prompt_seq] * len(value_ids), value_ids, **kwargs), model.model_identifier)
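For illustration, a small sketch of how the new `token_scores`/`full_token_scores` split behaves; the shapes follow the updated tests below, which assume that the prompt and the continuation each tokenize to a single token:

```python
import lmql

m = lmql.model("openai/gpt-3.5-turbo-instruct")
result = m.score_sync("Hello", ["World"])

# scores restricted to the appended continuation tokens
print(result.token_scores[0])       # shape (1,): only the "World" token

# scores over the full prompt + continuation token sequence
print(result.full_token_scores[0])  # length 2: prompt token + value token
```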
18 changes: 9 additions & 9 deletions src/lmql/runtime/interpreter.py
@@ -190,11 +190,11 @@ async def get_return_value(self, *args):

return LMQLResult(self.state.prompt, await self.get_all_vars(),self.interpreter.distribution_variable, self.interpreter.distribution_values)

async def score(self, *args, **kwargs):
async def score(self, values, **kwargs):
model = kwargs.get("model", None)
if model is not None:
return await score(self.prompt, *args, **kwargs)
return await dc_score(self.interpreter.dcmodel, self.prompt, *args, **kwargs)
return await score(self.prompt, values, **kwargs)
return await dc_score(self.interpreter.dcmodel, self.prompt, values, **kwargs)

@dataclass
class LMQLResult:
@@ -314,6 +314,11 @@ def set_model(self, model_handle: Union[str, LLM]):

self.model = model_handle

# prepare dcmodel
decoder_args = self.decoder_kwargs
self.model.adapter.decoder_args = {**decoder_args, **self.extra_kwargs}
self.dcmodel: dc.DcModel = self.model.adapter.get_dclib_model()

async def advance(self, state: PromptState):
if state.variable is not None:
return state
@@ -895,11 +900,6 @@ async def run(self, fct, *args, **kwargs):
tail=None)
self.root_state = await self.advance(self.root_state)

# prepare dcmodel
decoder_args = self.decoder_kwargs
self.model.adapter.decoder_args = {**decoder_args, **self.extra_kwargs}
self.dcmodel: dc.DcModel = self.model.adapter.get_dclib_model()

async def debug_out(decoder_step):
if PromptInterpreter.main != self:
return
@@ -934,7 +934,7 @@ async def debug_out(decoder_step):
# make sure that the initial prompt is not considered part of a variable
self.root_state = self.root_state.updated(variable_offset=n)

decoder_args = decoder_args.copy()
decoder_args = self.decoder_kwargs.copy()

# pass processor as decoder argument
decoder_args["modern_logits_processor"] = self.where_processor
1 change: 1 addition & 0 deletions src/lmql/tests/expr_test_utils.py
@@ -166,6 +166,7 @@ def run_all_tests(g):
tb = traceback.extract_tb(e.__traceback__)
tb = "\n".join(traceback.format_list(tb))
termcolor.cprint("FAILED\n{}".format(tb), "red")
termcolor.cprint("AssertionError: {}".format(e), "red")
print(e)

# wait for all tasks to finish
10 changes: 7 additions & 3 deletions src/lmql/tests/test_api.py
@@ -1,4 +1,6 @@
import lmql
import numpy as np

from lmql.tests.expr_test_utils import run_all_tests

def test_generate_sync():
@@ -32,7 +34,8 @@ def test_score_sync():

assert type(result) is lmql.ScoringResult
assert len(result.seqs) == 1
assert len(result.token_scores) == 1 and result.token_scores[0].shape == (2,)
assert len(result.token_scores) == 1 and result.token_scores[0].shape == (1,)
assert len(result.full_token_scores) == 1 and np.array(result.full_token_scores[0]).shape == (2,)

assert result.logprobs().shape == (1,)

@@ -43,7 +46,8 @@ def test_llm_score_two():

assert type(result) is lmql.ScoringResult
assert len(result.seqs) == 2
assert len(result.token_scores) == 2 and result.token_scores[0].shape == (2,)
assert len(result.token_scores) == 2 and result.token_scores[0].shape == (1,)
assert len(result.full_token_scores) == 2 and np.array(result.full_token_scores[0]).shape == (2,)

assert result.argmax() in ["World", "Test"]

@@ -58,7 +62,7 @@ def test_llm_openai():
print("Skipping test_api.test_llm_openai because no OpenAI API configuration could be found.")
return

m = lmql.model("openai/gpt-3.5-turbo-instruct")
m = lmql.model("openai/gpt-3.5-turbo-instruct", silent=True)
assert m.score_sync("Hello", ["World", "Test"]).argmax() == "World"

if __name__ == "__main__":
