adding hacks for tokenizers
mmoskal committed Jun 14, 2024
1 parent 70d27fe commit cf3984b
Showing 6 changed files with 110 additions and 28 deletions.
11 changes: 9 additions & 2 deletions controllers/llguidance_ctrl/src/tokenparser.rs
@@ -124,6 +124,12 @@ impl TokenParser {
if chop_bytes <= grm_bytes.len() {
self.llm_bytes = grm_bytes[0..grm_bytes.len() - chop_bytes].to_vec();
self.llm_tokens = self.token_env.tokenize_bytes(&self.llm_bytes);
let decoded = self.token_env.tok_trie().decode(&self.llm_tokens);
if self.llm_bytes.len() > 0 && &decoded[1..] == &self.llm_bytes && decoded[0] == b' ' {
infoln!(self, "applying <s>space hack");
self.grm_prefix = decoded[0..1].to_vec();
self.llm_bytes = decoded;
}
infoln!(self, "ini_tokens: {}", trie.tokens_dbg(&self.llm_tokens));
} else {
// pretend the final bit of prompt was the prefix of the grammar
@@ -191,7 +197,7 @@ impl TokenParser {
// TODO maybe remove in future
if self.llm_bytes != trie.decode(&self.llm_tokens) {
panic!(
"llm_bytes mismatch: {:?} {:?}",
"llm_bytes mismatch:\n {:?}\n {:?}",
String::from_utf8_lossy(&self.llm_bytes),
String::from_utf8_lossy(&trie.decode(&self.llm_tokens))
);
@@ -271,12 +277,13 @@ impl TokenParser {
let mut grm_tokens = self.token_env.tokenize_bytes(&new_forced);
infoln!(
self,
"forced: {} {:?} {:?}",
"forced: {} bytes:{:?} tokens:{:?}",
trie.tokens_dbg(&grm_tokens),
new_forced,
grm_tokens
);
let (chop_tokens, chop_bytes) = trie.chop_tokens(&mut self.parser, &grm_tokens);
infoln!(self, "chop: {} tokens, {} bytes", chop_tokens, chop_bytes);
token_prefix = new_forced[new_forced.len() - chop_bytes..].to_vec();
// here we remove a suffix from grm_tokens that could possibly be tokenized differently
grm_tokens.truncate(grm_tokens.len() - chop_tokens);
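The new branch above fires when re-tokenizing the prompt suffix grows it by exactly one leading space, as happens with SentencePiece-style vocabularies whose word-initial tokens carry a leading space. Below is a minimal, self-contained Python sketch of that check, with toy tokenize/decode helpers standing in for the real token trie; the token IDs and byte strings are invented for illustration only.

# Toy stand-ins for tokenize_bytes() and tok_trie.decode(); invented vocabulary.
TOKEN_BYTES = {1: b" Hello", 2: b" world"}

def tokenize(data: bytes) -> list:
    # a SentencePiece-style tokenizer re-inserts the leading space on the first word
    return [1, 2] if data == b"Hello world" else []

def decode(tokens: list) -> bytes:
    return b"".join(TOKEN_BYTES[t] for t in tokens)

llm_bytes = b"Hello world"
decoded = decode(tokenize(llm_bytes))             # b" Hello world" -- one extra space
if llm_bytes and decoded[1:] == llm_bytes and decoded[:1] == b" ":
    grm_prefix = decoded[:1]                      # the space becomes a grammar prefix
    llm_bytes = decoded                           # keep the decoded bytes as-is
print(grm_prefix, llm_bytes)                      # b' ' b' Hello world'
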
2 changes: 2 additions & 0 deletions py/llguidance/python/llguidance/__init__.py
@@ -1,6 +1,8 @@
from ._lib import LLTokenizer, LLInterpreter
from ._tokenizer import TokenizerWrapper

__all__ = [
"LLTokenizer",
"LLInterpreter",
"TokenizerWrapper",
]
25 changes: 18 additions & 7 deletions py/llguidance/python/llguidance/_lib.pyi
@@ -1,21 +1,17 @@
from typing import List, Tuple, Mapping, Optional, Sequence, Union
from ._util import TokenId
from ._tokenizer import TokenizerWrapper

class LLTokenizer:
vocab_size: int
eos_token: TokenId

def __new__(
cls,
eos_token: TokenId,
tokens: Sequence[bytes],
tokenizer: TokenizerWrapper,
) -> "LLTokenizer":
"""
Create a new tokenizer.
Args:
eos_token: TokenId - the identifier of the end-of-sequence/end-of-text token
tokens: Sequence[bytes] - maps regular tokens to their byte-string representation
special tokens need to map to empty bytes
"""

def greedy_tokenize(self, text: str) -> List[int]:
@@ -24,6 +20,19 @@ class LLTokenizer:
This will not necessarily match BPE.
"""

def tokenize_bytes(self, utf8bytes: bytes) -> List[int]:
"""
Tokenize the text as bytes.
This will use the underlying Python tokenizer to tokenize the valid UTF-8
prefix of the text, and then fall back to greedy_tokenize() for the last
few bytes.
"""

def tokenize_str(self, text: str) -> List[int]:
"""
Same as tokenize_bytes, but for strings.
"""

def dbg_tokens(self, tokens: List[int]) -> str:
"""
Return a debug string representation of the tokens.
@@ -63,7 +72,9 @@ class LLInterpreter:
Returns the adjusted prompt.
"""

def mid_process(self, backtrack: int, tokens: List[TokenId]) -> Tuple[Optional[bytes], str]:
def mid_process(
self, backtrack: int, tokens: List[TokenId]
) -> Tuple[Optional[bytes], str]:
"""
Perform the next parsing step.
Returns: optional token mask and a JSON string.
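For reference, a usage sketch that mirrors the updated t1.py further down: LLTokenizer is now built from a TokenizerWrapper instead of (eos_token, tokens), and LLInterpreter takes the serialized grammar plus an optional log_level. The model path is a placeholder and the decode loop is elided; see run_constraint() in t1.py for the full loop.

import json
import guidance
import llguidance

m = guidance.models.LlamaCpp(model="path/to/model.gguf")   # placeholder path
wrapped = llguidance.TokenizerWrapper(m.engine.tokenizer)  # Python-side adapter
tok = llguidance.LLTokenizer(wrapped)                      # builds the Rust-side trie

grm = "Here's a joke: " + guidance.gen(regex="[a-z ]+", stop="\n")
serialized = grm.ll_serialize()
serialized["max_tokens"] = 100
interp = llguidance.LLInterpreter(tok, json.dumps(serialized), log_level=2)
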
34 changes: 34 additions & 0 deletions py/llguidance/python/llguidance/_tokenizer.py
@@ -0,0 +1,34 @@
from typing import List, Optional, Sequence
from ._util import TokenId


class TokenizerWrapper:
eos_token_id: TokenId
bos_token_id: Optional[TokenId]
tokens: Sequence[bytes]

def __init__(self, gtokenizer) -> None:
self.gtokenizer = gtokenizer
self.eos_token_id = gtokenizer.eos_token_id
self.bos_token_id = gtokenizer.bos_token_id
self.tokens = gtokenizer.tokens
self.accepts_bytes = True
try:
gtokenizer(b"test")
except:
self.accepts_bytes = False
# If the tokenizer accepted bytes, then b"\xff" would be better (since it's invalid UTF-8).
# For now, we settle for "\x02" and assume it doesn't start any other token.
self.prefix_string = "\x02"
self.prefix_tokens = self._encode_string(self.prefix_string)

def _encode_string(self, s: str) -> List[TokenId]:
if self.accepts_bytes:
return self.gtokenizer(s.encode("utf-8"))
else:
return self.gtokenizer(s)

def __call__(self, s: str):
tokens = self._encode_string(self.prefix_string + s)
assert tokens[: len(self.prefix_tokens)] == self.prefix_tokens
return tokens[len(self.prefix_tokens) :]
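The wrapper's __call__ relies on two assumptions: the "\x02" prefix tokenizes to a stable token sequence of its own, and it absorbs the spurious leading space that SentencePiece-style tokenizers attach to the first word. Here is a self-contained toy illustrating the trick; the splitting rule below is invented for illustration and is not the real guidance tokenizer.

def toy_encode(s: str) -> list:
    # hypothetical SentencePiece-like tokenizer: it glues a space onto the front of
    # the input, so "Hello ..." alone yields a " Hello"-style first token; it never
    # merges anything after "\x02"
    s = " " + s
    out, cur = [], ""
    for ch in s:
        if cur and (ch == " " or cur.endswith("\x02")):
            out.append(cur)
            cur = ""
        cur += ch
    out.append(cur)
    return out

print(toy_encode("Hello world"))           # [' Hello', ' world'] -- spurious space
prefix = "\x02"
prefix_toks = toy_encode(prefix)           # [' \x02']
toks = toy_encode(prefix + "Hello world")  # [' \x02', 'Hello', ' world']
assert toks[: len(prefix_toks)] == prefix_toks
print(toks[len(prefix_toks):])             # ['Hello', ' world'] -- no spurious space
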
46 changes: 32 additions & 14 deletions py/llguidance/rust/py.rs
@@ -18,9 +18,13 @@ struct LLInterpreter {
log_level: isize,
}

#[derive(Clone)]
#[pyclass]
struct LLTokenizer {
tok_trie: Arc<toktree::TokTrie>,
tokenizer_fun: Py<PyAny>,
#[allow(dead_code)]
tok_bos: Option<u32>,
}

#[pymethods]
@@ -31,9 +35,7 @@ impl LLInterpreter {
llguidance_json: &str,
log_level: Option<isize>,
) -> PyResult<Self> {
let env = PyTokenizer {
inner: tokenizer.tok_trie.clone(),
};
let env = tokenizer.clone();
let arg: TopLevelGrammar = serde_json::from_str(llguidance_json)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
let log_level = log_level.unwrap_or(1);
@@ -104,18 +106,39 @@ struct PyMidProcessResult {
#[pymethods]
impl LLTokenizer {
#[new]
fn py_new(eos_token: u32, tokens: Vec<Vec<u8>>) -> PyResult<Self> {
fn py_new(gtokenizer: Bound<'_, PyAny>) -> PyResult<Self> {
let tok_eos = gtokenizer.getattr("eos_token_id")?.extract::<u32>()?;
let tok_bos = gtokenizer
.getattr("bos_token_id")?
.extract::<u32>()
.map_or(None, |v| Some(v));
let tokens = gtokenizer.getattr("tokens")?.extract::<Vec<Vec<u8>>>()?;
let info = TokRxInfo {
vocab_size: tokens.len() as u32,
tok_eos: eos_token,
tok_eos,
};

let tok_trie = TokTrie::from(&info, &tokens);
Ok(LLTokenizer {
tok_trie: Arc::new(tok_trie),
tokenizer_fun: gtokenizer.into(),
tok_bos,
})
}

fn tokenize_bytes(&self, utf8bytes: &[u8]) -> Vec<TokenId> {
self.tok_trie.tokenize_with_greedy_fallback(utf8bytes, |s| {
Python::with_gil(|py| {
let r = self.tokenizer_fun.call1(py, (s,)).unwrap();
r.extract::<Vec<TokenId>>(py).unwrap()
})
})
}

fn tokenize_str(&self, text: &str) -> Vec<TokenId> {
self.tokenize_bytes(text.as_bytes())
}

fn greedy_tokenize(&self, text: &str) -> Vec<u32> {
self.tok_trie.greedy_tokenize(text.as_bytes())
}
@@ -144,22 +167,17 @@ impl LLTokenizer {
}
}

struct PyTokenizer {
inner: Arc<toktree::TokTrie>,
}

impl TokenizerEnv for PyTokenizer {
impl TokenizerEnv for LLTokenizer {
fn stop(&self) -> ! {
panic!("STOP"); // TODO
panic!("STOP"); // TODO?
}

fn tok_trie(&self) -> &toktree::TokTrie {
&self.inner
&self.tok_trie
}

fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
// TODO this should call out to the Python tokenizer
self.inner.greedy_tokenize(s)
self.tokenize_bytes(s)
}
}

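The new tokenize_bytes above delegates to tok_trie.tokenize_with_greedy_fallback(), with the Python tokenizer as the callback. A rough Python sketch of that split, based only on the docstring added to _lib.pyi; the helper arguments are stand-ins, not the actual Rust implementation.

def tokenize_with_fallback(data: bytes, python_tokenize, greedy_tokenize) -> list:
    # find the longest prefix of `data` that is valid UTF-8
    cut = len(data)
    while cut > 0:
        try:
            text = data[:cut].decode("utf-8")
            break
        except UnicodeDecodeError:
            cut -= 1
    else:
        text = ""
    tokens = list(python_tokenize(text)) if text else []
    if cut < len(data):
        # byte-level fallback for the trailing invalid bytes
        tokens += list(greedy_tokenize(data[cut:]))
    return tokens

# With the Phi-3 tokenizer this is what the asserts in t1.py check, e.g.
#   tokenize_bytes(b"Hello world\xff") == [10994, 3186, 258]
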
20 changes: 15 additions & 5 deletions py/llguidance/t1.py
@@ -9,6 +9,7 @@

from typing import List

log_level = 2

def softmax(logits: np.ndarray, temperature=1.0) -> np.ndarray:
# Adjust logits by temperature
@@ -31,7 +32,7 @@ def run_constraint(tok: llguidance.LLTokenizer, e: LlamaCppEngine, grm: guidance
max_tokens = 100
serialized = grm.ll_serialize()
serialized["max_tokens"] = max_tokens
interp = llguidance.LLInterpreter(tok, json.dumps(serialized))
interp = llguidance.LLInterpreter(tok, json.dumps(serialized), log_level=log_level)
tokens = []
if e.tokenizer.bos_token_id is not None:
tokens.append(e.tokenizer.bos_token_id)
@@ -65,11 +66,20 @@ def run_constraint(tok: llguidance.LLTokenizer, e: LlamaCppEngine, grm: guidance


def main():
# m = guidance.models.Transformers(model="../../tmp/Phi-3-mini-128k-instruct/", trust_remote_code=True)
#m = guidance.models.Transformers(model="../../tmp/Phi-3-mini-128k-instruct/", trust_remote_code=True)
m = guidance.models.LlamaCpp(model="../../tmp/Phi-3-mini-4k-instruct-q4.gguf")
t: Tokenizer = m.engine.tokenizer
tok = llguidance.LLTokenizer(t.eos_token_id, t.tokens)
run_constraint(tok, m.engine, "Here's a joke: " + guidance.gen(regex="[a-z ]+", stop="\n"))
t = llguidance.TokenizerWrapper(m.engine.tokenizer)
t = llguidance.LLTokenizer(t)
assert t.tokenize_str("") == []
assert t.tokenize_str(" ") == [29871]
assert t.tokenize_str("x") == [29916]
assert t.tokenize_str("Hello world") == [10994, 3186]

assert t.tokenize_bytes(b"Hello world") == [10994, 3186]
assert t.tokenize_bytes(b"Hello world\xff") == [10994, 3186, 258]
assert t.tokenize_bytes(b"Hello world\xc0\xff") == [10994, 3186, 195, 258]

run_constraint(t, m.engine, "Here's a joke: " + guidance.gen(regex="[a-z ]+", stop="\n"))


if __name__ == "__main__":