diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py
index d0340a1ad..4a7fce8c9 100644
--- a/outlines/fsm/fsm.py
+++ b/outlines/fsm/fsm.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import TYPE_CHECKING, List, NewType
+from typing import TYPE_CHECKING, Iterable, NewType, Optional
 
 from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide
 
@@ -20,7 +20,7 @@ def __init__(self, tokenizer: "Tokenizer"):
         )
         super().__init__(tokenizer)
 
-    def allowed_token_ids(self, state: FSMState) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
         next_instruction = self.get_next_instruction(state)
         return next_instruction.tokens
 
@@ -39,7 +39,7 @@ def __init__(self, regex_string: str, tokenizer):
         )
         super().__init__(regex_string, tokenizer)
 
-    def allowed_token_ids(self, state: FSMState) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
         next_instruction = self.get_next_instruction(state)
         return next_instruction.tokens
 
@@ -58,7 +58,7 @@ def __init__(self, cfg_string: str, tokenizer):
         )
         super().__init__(cfg_string, tokenizer)
 
-    def allowed_token_ids(self, state: FSMState) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
         return self.get_next_instruction(state).tokens
 
     def next_state(self, state: FSMState, token_id: int) -> FSMState:
diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py
index cf5be8d50..5c7b56326 100644
--- a/outlines/fsm/guide.py
+++ b/outlines/fsm/guide.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Protocol, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union
 
 import interegular
 from lark import Lark
@@ -38,10 +38,11 @@ class Generate:
     Attributes
     ----------
     tokens
-        The tokens that lead to a valid completion if generated.
+        The tokens that lead to a valid completion if generated. A value
+        of ``None`` indicates that all tokens are allowed.
""" - tokens: List[int] + tokens: Optional[List[int]] Instruction = Union[Write, Generate] @@ -89,7 +90,7 @@ def __init__(self, tokenizer: "Tokenizer"): def get_next_instruction(self, state: int) -> Instruction: if self.is_final_state(state): return Write([self.eos_token_id]) - return Generate(list(self.vocabulary)) + return Generate(None) def get_next_state(self, state: int, token_id: int) -> int: if token_id == self.eos_token_id or state == self.final_state: @@ -330,6 +331,9 @@ def get_next_instruction(self, state: int) -> Instruction: proposer = self.regex_fsm instruction = proposer.get_next_instruction(state) + + assert instruction.tokens is not None + if isinstance(instruction, Write): proposal += instruction.tokens else: @@ -365,6 +369,9 @@ def get_next_instruction(self, state: int) -> Instruction: self.reset_state = True instruction = self.regex_fsm.get_next_instruction(self.start_state) + + assert instruction.tokens is not None + if isinstance(instruction, Write): proposal += instruction.tokens else: diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index ca5fa395f..edda617d0 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Tuple if TYPE_CHECKING: import torch @@ -134,7 +134,9 @@ def get_next_fsm_states( ] -def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> "torch.Tensor": +def get_allowed_tokens( + fsms: List["Guide"], fsm_states: List[int] +) -> List[Optional[Iterable[int]]]: """Get the new instructions for each sequence from the finite-state machine. Parameters @@ -302,5 +304,8 @@ def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tenso biased_logits = torch.full_like(logits, -math.inf, device=logits.device) for i, ids in enumerate(allowed_token_ids): - biased_logits[i, ids] = logits[i, ids] + if ids is not None: + biased_logits[i, ids] = logits[i, ids] + else: + biased_logits[i] = logits[i] return biased_logits diff --git a/outlines/integrations/transformers.py b/outlines/integrations/transformers.py index f8f1af945..7c1bafd22 100644 --- a/outlines/integrations/transformers.py +++ b/outlines/integrations/transformers.py @@ -26,7 +26,7 @@ """ from collections import defaultdict -from typing import DefaultDict, List, Optional, Type, Union +from typing import DefaultDict, Iterable, Optional, Type, Union import torch from pydantic import BaseModel @@ -84,7 +84,7 @@ def __init__( # apply the FSM to the generated tokens. self._prefix = [-1] - def __call__(self, batch_id: int, sent: torch.Tensor) -> List[int]: + def __call__(self, batch_id: int, sent: torch.Tensor) -> Optional[Iterable[int]]: """Use the FSM to bias the logits before sampling the next token. 
 
         Parameters
diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py
index 8ce17c6eb..30047b83d 100644
--- a/tests/fsm/test_fsm.py
+++ b/tests/fsm/test_fsm.py
@@ -11,7 +11,7 @@ class MockTokenizer:
     with pytest.warns(UserWarning):
         fsm = StopAtEosFSM(MockTokenizer())
 
-    assert fsm.allowed_token_ids(fsm.start_state) == [1, 2]
+    assert fsm.allowed_token_ids(fsm.start_state) is None
     assert fsm.allowed_token_ids(fsm.final_state) == [2]
     assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
     assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py
index aabf0446c..28645f012 100644
--- a/tests/fsm/test_guide.py
+++ b/tests/fsm/test_guide.py
@@ -12,7 +12,7 @@ class MockTokenizer:
 
     instruction = fsm.get_next_instruction(fsm.start_state)
     assert isinstance(instruction, Generate)
-    assert instruction.tokens == [1, 2]
+    assert instruction.tokens is None
 
     instruction = fsm.get_next_instruction(fsm.final_state)
     assert isinstance(instruction, Write)
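
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the convention this diff introduces, where an allowed-token value of None means "every token in the vocabulary is allowed", so unconstrained steps no longer have to materialize a full-vocabulary index list. The helper name mask_logits and the tensor shapes are hypothetical; only the None handling mirrors the updated bias_logits above.

import math
from typing import Iterable, List, Optional

import torch


def mask_logits(
    logits: torch.Tensor, allowed_token_ids: List[Optional[Iterable[int]]]
) -> torch.Tensor:
    """Keep only the allowed logits in each row; None keeps the whole row."""
    biased = torch.full_like(logits, -math.inf)
    for i, ids in enumerate(allowed_token_ids):
        if ids is not None:
            ids = list(ids)
            biased[i, ids] = logits[i, ids]
        else:
            # Unconstrained step: copy the row instead of indexing every token id.
            biased[i] = logits[i]
    return biased


logits = torch.randn(2, 5)
# Row 0 may only produce tokens 1 or 3; row 1 is unconstrained (None).
print(mask_logits(logits, [[1, 3], None]))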