Skip to content

Commit

Permalink
Remove the need to copy all tokens during basic generation
Browse files: browse the repository at this point in the history
  • Loading branch information
brandonwillard authored and rlouf committed May 1, 2024
1 parent 4934425 commit 164d1f0
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 15 deletions.
8 changes: 4 additions & 4 deletions outlines/fsm/fsm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import TYPE_CHECKING, List, NewType
from typing import TYPE_CHECKING, Iterable, NewType, Optional

from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide

Expand All @@ -20,7 +20,7 @@ def __init__(self, tokenizer: "Tokenizer"):
)
super().__init__(tokenizer)

def allowed_token_ids(self, state: FSMState) -> List[int]:
def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
next_instruction = self.get_next_instruction(state)
return next_instruction.tokens

Expand All @@ -39,7 +39,7 @@ def __init__(self, regex_string: str, tokenizer):
)
super().__init__(regex_string, tokenizer)

def allowed_token_ids(self, state: FSMState) -> List[int]:
def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
next_instruction = self.get_next_instruction(state)
return next_instruction.tokens

Expand All @@ -58,7 +58,7 @@ def __init__(self, cfg_string: str, tokenizer):
)
super().__init__(cfg_string, tokenizer)

def allowed_token_ids(self, state: FSMState) -> List[int]:
def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
return self.get_next_instruction(state).tokens

def next_state(self, state: FSMState, token_id: int) -> FSMState:
Expand Down
15 changes: 11 additions & 4 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Protocol, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union

import interegular
from lark import Lark
Expand Down Expand Up @@ -38,10 +38,11 @@ class Generate:
Attributes
----------
tokens
The tokens that lead to a valid completion if generated.
The tokens that lead to a valid completion if generated. A value
of ``None`` indicates that all tokens are allowed.
"""

tokens: List[int]
tokens: Optional[List[int]]


Instruction = Union[Write, Generate]
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(self, tokenizer: "Tokenizer"):
def get_next_instruction(self, state: int) -> Instruction:
if self.is_final_state(state):
return Write([self.eos_token_id])
return Generate(list(self.vocabulary))
return Generate(None)

def get_next_state(self, state: int, token_id: int) -> int:
if token_id == self.eos_token_id or state == self.final_state:
Expand Down Expand Up @@ -330,6 +331,9 @@ def get_next_instruction(self, state: int) -> Instruction:
proposer = self.regex_fsm

instruction = proposer.get_next_instruction(state)

assert instruction.tokens is not None

if isinstance(instruction, Write):
proposal += instruction.tokens
else:
Expand Down Expand Up @@ -365,6 +369,9 @@ def get_next_instruction(self, state: int) -> Instruction:
self.reset_state = True

instruction = self.regex_fsm.get_next_instruction(self.start_state)

assert instruction.tokens is not None

if isinstance(instruction, Write):
proposal += instruction.tokens
else:
Expand Down
11 changes: 8 additions & 3 deletions outlines/generate/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import math
from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Tuple

if TYPE_CHECKING:
import torch
Expand Down Expand Up @@ -134,7 +134,9 @@ def get_next_fsm_states(
]


def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> "torch.Tensor":
def get_allowed_tokens(
fsms: List["Guide"], fsm_states: List[int]
) -> List[Optional[Iterable[int]]]:
"""Get the new instructions for each sequence from the finite-state machine.
Parameters
Expand Down Expand Up @@ -302,5 +304,8 @@ def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tenso

biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
for i, ids in enumerate(allowed_token_ids):
biased_logits[i, ids] = logits[i, ids]
if ids is not None:
biased_logits[i, ids] = logits[i, ids]
else:
biased_logits[i] = logits[i]
return biased_logits
4 changes: 2 additions & 2 deletions outlines/integrations/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""

from collections import defaultdict
from typing import DefaultDict, List, Optional, Type, Union
from typing import DefaultDict, Iterable, Optional, Type, Union

import torch
from pydantic import BaseModel
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
# apply the FSM to the generated tokens.
self._prefix = [-1]

def __call__(self, batch_id: int, sent: torch.Tensor) -> List[int]:
def __call__(self, batch_id: int, sent: torch.Tensor) -> Optional[Iterable[int]]:
"""Use the FSM to bias the logits before sampling the next token.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion tests/fsm/test_fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MockTokenizer:
with pytest.warns(UserWarning):
fsm = StopAtEosFSM(MockTokenizer())

assert fsm.allowed_token_ids(fsm.start_state) == [1, 2]
assert fsm.allowed_token_ids(fsm.start_state) is None
assert fsm.allowed_token_ids(fsm.final_state) == [2]
assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
Expand Down
2 changes: 1 addition & 1 deletion tests/fsm/test_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MockTokenizer:

instruction = fsm.get_next_instruction(fsm.start_state)
assert isinstance(instruction, Generate)
assert instruction.tokens == [1, 2]
assert instruction.tokens is None

instruction = fsm.get_next_instruction(fsm.final_state)
assert isinstance(instruction, Write)
Expand Down

0 comments on commit 164d1f0

Please sign in to comment.