From 69c4ce32a447e67ae5e962a09b70b877335a51c5 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 24 Apr 2024 21:08:43 +0000
Subject: [PATCH] first draft of DynamicLexer

---
 controllers/aici_abi/src/dlex.rs     | 266 +++++++++++++++++++++++++++
 controllers/aici_abi/src/lib.rs      |   2 +
 controllers/pyctrl/samples/idents.py |  29 +++
 controllers/pyctrl/src/pyctrl.rs     |  88 ++++++++-
 py/pyaici/server.py                  |   1 +
 py/pyaici/server_native.py           |  26 +++
 6 files changed, 405 insertions(+), 7 deletions(-)
 create mode 100644 controllers/aici_abi/src/dlex.rs
 create mode 100644 controllers/pyctrl/samples/idents.py

diff --git a/controllers/aici_abi/src/dlex.rs b/controllers/aici_abi/src/dlex.rs
new file mode 100644
index 00000000..02f04313
--- /dev/null
+++ b/controllers/aici_abi/src/dlex.rs
@@ -0,0 +1,266 @@
+use crate::{
+    recognizer::{FunctionalRecognizer, StackRecognizer},
+    svob::SimpleVob,
+    toktree::SpecialToken,
+};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct NodeId(u32);
+
+impl NodeId {
+    const NULL: NodeId = NodeId(0);
+    const ROOT: NodeId = NodeId(1);
+}
+
+#[derive(Debug, Default, Clone)]
+pub struct NodeData {
+    pub is_terminal: bool,
+}
+
+enum TrieNode {
+    Sparse {
+        data: NodeData,
+        children: Vec<(u8, NodeId)>,
+    },
+    Dense {
+        data: NodeData,
+        children: Vec<NodeId>,
+    },
+}
+
+impl TrieNode {
+    fn new_dense(data: NodeData, children: &Vec<(u8, NodeId)>) -> Self {
+        let mut dense_children = vec![NodeId::NULL; 256];
+        for (byte, node_id) in children {
+            dense_children[*byte as usize] = *node_id;
+        }
+        TrieNode::Dense {
+            data,
+            children: dense_children,
+        }
+    }
+
+    fn new_leaf() -> Self {
+        TrieNode::Sparse {
+            data: NodeData::default(),
+            children: vec![],
+        }
+    }
+
+    fn data(&self) -> &NodeData {
+        match self {
+            TrieNode::Sparse { data, .. } => data,
+            TrieNode::Dense { data, .. } => data,
+        }
+    }
+
+    fn data_mut(&mut self) -> &mut NodeData {
+        match self {
+            TrieNode::Sparse { data, .. } => data,
+            TrieNode::Dense { data, .. } => data,
+        }
+    }
+}
+
+pub struct Trie {
+    nodes: Vec<TrieNode>,
+}
+
+impl Trie {
+    const MAX_SPARSE: usize = 8;
+
+    pub fn new() -> Self {
+        Trie {
+            nodes: vec![
+                TrieNode::new_leaf(),
+                TrieNode::new_dense(NodeData::default(), &vec![]),
+            ],
+        }
+    }
+
+    fn node(&self, node_id: NodeId) -> &TrieNode {
+        &self.nodes[node_id.0 as usize]
+    }
+
+    fn node_mut(&mut self, node_id: NodeId) -> &mut TrieNode {
+        &mut self.nodes[node_id.0 as usize]
+    }
+
+    pub fn node_data(&self, node_id: NodeId) -> &NodeData {
+        self.node(node_id).data()
+    }
+
+    pub fn root(&self) -> NodeId {
+        NodeId::ROOT
+    }
+
+    pub fn child_at(&self, start: NodeId, b: u8) -> Option<NodeId> {
+        match self.node(start) {
+            TrieNode::Sparse { children, .. } => {
+                children.iter().find_map(
+                    |&(byte, node_id)| {
+                        if byte == b {
+                            Some(node_id)
+                        } else {
+                            None
+                        }
+                    },
+                )
+            }
+            TrieNode::Dense { children, .. } => {
+                let node_id = children[b as usize];
+                if node_id == NodeId::NULL {
+                    None
+                } else {
+                    Some(node_id)
+                }
+            }
+        }
+    }
+
+    pub fn lookup(&self, start: NodeId, word: &[u8]) -> Option<NodeId> {
+        let mut node_id = start;
+        for &byte in word {
+            match self.child_at(node_id, byte) {
+                Some(child_id) => {
+                    node_id = child_id;
+                }
+                None => {
+                    return None;
+                }
+            }
+        }
+        Some(node_id)
+    }
+
+    pub fn add(&mut self, word: &[u8]) {
+        let mut node_id = NodeId::ROOT;
+        for &byte in word {
+            let new_node_id = NodeId(self.nodes.len() as u32);
+            let node = self.node_mut(node_id);
+            match node {
+                TrieNode::Sparse { data, children } => {
+                    match children.iter().find(|&&(b, _)| b == byte) {
+                        Some(&(_, child_id)) => {
+                            node_id = child_id;
+                        }
+                        None => {
+                            children.push((byte, new_node_id));
+                            if children.len() > Trie::MAX_SPARSE {
+                                self.nodes[node_id.0 as usize] =
+                                    TrieNode::new_dense(data.clone(), children);
+                            }
+                            self.nodes.push(TrieNode::new_leaf());
+                            node_id = new_node_id;
+                        }
+                    }
+                }
+                TrieNode::Dense { children, .. } => {
+                    node_id = children[byte as usize];
+                    if node_id == NodeId::NULL {
+                        children[byte as usize] = new_node_id;
+                        self.nodes.push(TrieNode::new_leaf());
+                        node_id = new_node_id;
+                    }
+                }
+            }
+        }
+
+        self.node_mut(node_id).data_mut().is_terminal = true;
+    }
+}
+
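+/// Byte-level recognizer: identifier-shaped spans must exactly match one of
+/// the added words, while all other bytes pass through unconstrained.
+/// A minimal usage sketch, using the same calls pyctrl makes further below:
+///
+/// ```ignore
+/// let mut rec = DynamicLexer::new(&vec![]).to_stack_recognizer();
+/// rec.recognizer_mut().add(b"fibo");
+/// ```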
+pub struct DynamicLexer {
+    trie: Trie,
+    id_start: SimpleVob,
+    id_body: SimpleVob,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DState {
+    node_id: NodeId,
+}
+
+impl DState {
+    const ROOT: DState = DState {
+        node_id: NodeId::ROOT,
+    };
+}
+
+pub type DynamicLexerRec = StackRecognizer<DState, DynamicLexer>;
+
+impl DynamicLexer {
+    pub fn new(additional_id_chars: &Vec<char>) -> Self {
+        let mut id_start = SimpleVob::alloc(0x100);
+        let mut id_body = SimpleVob::alloc(0x100);
+        for i in 0..=255u8 {
+            match i as char {
+                'a'..='z' | 'A'..='Z' | '_' => {
+                    id_start.allow_token(i as u32);
+                    id_body.allow_token(i as u32);
+                }
+                '0'..='9' => {
+                    id_body.allow_token(i as u32);
+                }
+                _ => {}
+            }
+        }
+        for &c in additional_id_chars {
+            id_start.allow_token(c as u32);
+            id_body.allow_token(c as u32);
+        }
+        DynamicLexer {
+            trie: Trie::new(),
+            id_start,
+            id_body,
+        }
+    }
+
+    pub fn to_stack_recognizer(self) -> StackRecognizer<DState, DynamicLexer> {
+        StackRecognizer::from(self)
+    }
+
+    pub fn add(&mut self, word: &[u8]) {
+        self.trie.add(word);
+    }
+}
+
+impl FunctionalRecognizer<DState> for DynamicLexer {
+    fn initial(&self) -> DState {
+        DState::ROOT
+    }
+
+    fn try_append(&self, state: DState, byte: u8) -> Option<DState> {
+        if state.node_id == NodeId::ROOT {
+            if self.id_start.is_allowed(byte as u32) {
+                match self.trie.child_at(state.node_id, byte) {
+                    Some(node_id) => Some(DState { node_id }),
+                    None => None,
+                }
+            } else {
+                Some(state)
+            }
+        } else {
+            if self.id_body.is_allowed(byte as u32) {
+                match self.trie.child_at(state.node_id, byte) {
+                    Some(node_id) => Some(DState { node_id }),
+                    None => None,
+                }
+            } else {
+                if self.trie.node_data(state.node_id).is_terminal {
+                    Some(DState::ROOT)
+                } else {
+                    None
+                }
+            }
+        }
+    }
+
+    fn special_allowed(&self, state: DState, tok: SpecialToken) -> bool {
+        if tok == SpecialToken::EndOfSentence {
+            self.trie.node_data(state.node_id).is_terminal
+        } else {
+            false
+        }
+    }
+}
diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index 5a6e66e3..e4ccb2fd 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -16,6 +16,8 @@ mod lex;
 #[cfg(feature = "rx")]
 pub mod rx;
 
+pub mod dlex;
+
 pub mod substring;
 
 pub type TokenId = bytes::TokenId;
diff --git a/controllers/pyctrl/samples/idents.py b/controllers/pyctrl/samples/idents.py
new file mode 100644
index 00000000..0eba84dd
--- /dev/null
+++ b/controllers/pyctrl/samples/idents.py
@@ -0,0 +1,29 @@
+import pyaici.server as aici
+import re
+
+# asserts for microsoft/Orca-2-13b
+
+aici.log_level = 1
+
+async def test_id():
+    await aici.FixedTokens("Here's a fib function\n```python\n")
+
+    max_tokens = 60
+    dyn_lex = aici.DynamicLexer("")
+    for id in ["def", "fibo", "n", "return", "if"]:
+        dyn_lex.add(id)
+    next_token = aici.ConstrainedToken(lambda: dyn_lex.constraint())
+    res = []
+    text = ""
+    for _ in range(max_tokens):
+        tokens = await next_token
+        if tokens:
+            res += tokens
+            print("GEN-STEP:", aici.tokens_repr(tokens))
+            text = aici.detokenize(res).decode(errors="replace")
+        if next_token.finished:
+            break
+    print("RESULT:", text)
+
+
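+# Note: the constraint only restricts identifier-shaped spans; any byte that
+# cannot start an identifier (whitespace, punctuation, digits) passes through
+# unconstrained, so everything outside identifiers is left to the model.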
+aici.test(test_id())
diff --git a/controllers/pyctrl/src/pyctrl.rs b/controllers/pyctrl/src/pyctrl.rs
index d42a5fb7..3f2284d8 100644
--- a/controllers/pyctrl/src/pyctrl.rs
+++ b/controllers/pyctrl/src/pyctrl.rs
@@ -12,7 +12,11 @@ use rustpython_vm::{
     builtins::*, compiler::parser::ast::bigint::BigInt, AsObject, PyObjectRef, PyRef, PyResult,
     VirtualMachine,
 };
-use std::{ops::Deref, sync::Mutex, vec};
+use std::{
+    ops::Deref,
+    sync::{Arc, Mutex},
+    vec,
+};
 
 struct ModuleState {
     cb_obj: Option<PyObjectRef>,
@@ -42,9 +46,10 @@ fn get_cb_obj() -> PyObjectRef {
 
 #[rustpython_derive::pymodule]
 mod _aici {
-    use crate::{PyConstraint, VmExt, GLOBAL_STATE};
+    use crate::{ConstraintWrapper, PyConstraint, VmExt, GLOBAL_STATE};
     use aici_abi::{
         cfg::CfgParser,
+        dlex::{self, DynamicLexerRec},
         recognizer::{AnythingGoes, StackRecognizer},
         rx::RecRx,
         substring::SubStrMatcher,
@@ -62,7 +67,10 @@ mod _aici {
         types::{AsSequence, Constructor, Representable},
         Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
     };
-    use std::{fmt::Debug, sync::Mutex};
+    use std::{
+        fmt::Debug,
+        sync::{Arc, Mutex},
+    };
 
     #[pyfunction]
     fn register(obj: PyObjectRef, _vm: &VirtualMachine) -> PyResult<()> {
@@ -157,6 +165,10 @@ mod _aici {
 
     #[pyclass(flags(BASETYPE), with(Constructor))]
     impl Constraint {
+        fn new(obj: impl PyConstraint + 'static) -> Self {
+            Constraint(Mutex::new(Box::new(obj)))
+        }
+
         #[pymethod]
         fn eos_allowed(&self) -> bool {
             let mut s = self.0.lock().unwrap();
@@ -189,16 +201,55 @@ mod _aici {
         }
     }
 
+    #[pyattr]
+    #[pyclass(name)]
+    #[derive(PyPayload)]
+    pub struct DynamicLexer(pub Arc<Mutex<DynamicLexerRec>>);
+
+    impl Debug for DynamicLexer {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            f.debug_struct("DynamicLexer").finish()
+        }
+    }
+
+    #[pyclass(with(Constructor))]
+    impl DynamicLexer {
+        #[pymethod]
+        fn add(&self, word: PyStrRef) {
+            let mut lexer = self.0.lock().unwrap();
+            lexer.recognizer_mut().add(word.as_str().as_bytes());
+        }
+
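+        // Note: each call below clones the shared Arc, so every Constraint
+        // produced by this lexer wraps the same DynamicLexerRec; identifiers
+        // added later are visible to constraints created earlier.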
= "CfgConstraint")] fn cfg_constraint(cfg: PyStrRef, vm: &VirtualMachine) -> PyResult { match CfgParser::from_yacc(cfg.as_str()) { - Ok(cfg) => Ok(Constraint(Mutex::new(Box::new(cfg)))), + Ok(cfg) => Ok(Constraint::new(cfg)), Err(e) => Err(vm.new_runtime_error(format!("{}", e))), } } @@ -206,14 +257,14 @@ mod _aici { #[pyfunction(name = "SubStrConstraint")] fn substr_constraint(templ: PyStrRef, end_str: PyStrRef) -> PyResult { let rx = SubStrMatcher::new(templ.as_str(), end_str.as_str()).to_stack_recognizer(); - Ok(Constraint(Mutex::new(Box::new(rx)))) + Ok(Constraint::new(rx)) } impl Constructor for Constraint { type Args = FuncArgs; fn py_new(cls: PyTypeRef, _arg: Self::Args, vm: &VirtualMachine) -> PyResult { let anything = StackRecognizer::from(AnythingGoes {}); - Constraint(Mutex::new(Box::new(anything))) + Constraint::new(anything) .into_ref_with_type(vm, cls) .map(Into::into) } @@ -433,6 +484,29 @@ impl PyConstraint for T { } } +struct ConstraintWrapper(Arc>); +impl PyConstraint for ConstraintWrapper { + fn eos_allowed(&mut self) -> bool { + self.0.lock().unwrap().eos_allowed() + } + + fn eos_forced(&mut self) -> bool { + self.0.lock().unwrap().eos_forced() + } + + fn token_allowed(&mut self, t: TokenId) -> bool { + self.0.lock().unwrap().token_allowed(t) + } + + fn append_token(&mut self, t: TokenId) { + self.0.lock().unwrap().append_token(t) + } + + fn allow_tokens(&mut self, logits: &mut SimpleVob) { + self.0.lock().unwrap().allow_tokens(logits) + } +} + trait VmExt { fn get_vm(&self) -> &VirtualMachine; diff --git a/py/pyaici/server.py b/py/pyaici/server.py index 7b95e84b..c10f447d 100644 --- a/py/pyaici/server.py +++ b/py/pyaici/server.py @@ -14,6 +14,7 @@ RegexConstraint, CfgConstraint, SubStrConstraint, + DynamicLexer, Constraint, get_config, get_var, diff --git a/py/pyaici/server_native.py b/py/pyaici/server_native.py index 0bfb44bf..e0e0b102 100644 --- a/py/pyaici/server_native.py +++ b/py/pyaici/server_native.py @@ -187,6 +187,32 @@ class SubStrConstraint(Constraint): def __init__(self, template: str, stop_at: str): ... +class DynamicLexer: + """ + A lexer with a set of valid identifiers, that can be used as a Constraint. + """ + + def __init__(self, additional_id_chars: str): + """ + Normally, identifiers match /[a-zA-Z_][a-zA-Z0-9_]*/. + If additional_id_chars is not empty, the chars are additionally allowed anywhere in the identifier. + For example, use "$" for JavaScript, or "'" for ML-like languages. + You can add "." but it will interfere with floats. + """ + ... + + def add(self, identifier: str): + """ + Allow given identifier. + """ + ... + + def constraint(self) -> Constraint: + """ + This always returns the same constraint. + """ + ... + def is_server_side(): """