From 8a287f7bdb96dff5b42936fe7762125e3fa52763 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 14 Jun 2024 18:33:40 +0000 Subject: [PATCH] support save_stop_text (stop_capture_name) --- controllers/llguidance_ctrl/README.md | 6 +- controllers/llguidance_ctrl/run_g.py | 6 ++ controllers/llguidance_ctrl/src/api.rs | 4 + .../src/earley/from_guidance.rs | 46 +++------ .../llguidance_ctrl/src/earley/grammar.rs | 97 ++++++++++++++----- .../llguidance_ctrl/src/earley/parser.rs | 12 +++ py/guidance | 2 +- 7 files changed, 107 insertions(+), 66 deletions(-) diff --git a/controllers/llguidance_ctrl/README.md b/controllers/llguidance_ctrl/README.md index 9387b708..69861ed6 100644 --- a/controllers/llguidance_ctrl/README.md +++ b/controllers/llguidance_ctrl/README.md @@ -15,9 +15,7 @@ Guidance branch: https://github.com/hudson-ai/guidance/tree/lazy_grammars ## Status in Guidance -- [ ] `gen_mode` in `_gen.py` needs to become a flag on the model/engine -- [ ] `gen_json()` needs to be re-implemented -- [ ] `save_stop_text=` doesn't work on `gen()` +- [x] `save_stop_text=` doesn't work on `gen()` - [ ] `substring()` needs to be re-implemented (translate to RegexAst) - [ ] translate `commit_point(grammar)` into a RegexAst if (the grammar is non-recursive; @@ -27,7 +25,5 @@ Guidance branch: https://github.com/hudson-ai/guidance/tree/lazy_grammars ## TODO - [ ] `to_regex_vec()` in lexerspec.rs - non-contextual keywords -- [x] handle stop tokens in `gen_grammar()` - stop tokens removed -- [x] use `RegexAst::Byte(0xff)` for `EOS_MARKER` - [ ] fix derivative computation to be non-recursive (critical for `substring()`) - [ ] add stats about how many parser transitions are made in a token trie traversal diff --git a/controllers/llguidance_ctrl/run_g.py b/controllers/llguidance_ctrl/run_g.py index a48f5570..b284dea1 100644 --- a/controllers/llguidance_ctrl/run_g.py +++ b/controllers/llguidance_ctrl/run_g.py @@ -224,6 +224,12 @@ def character_maker2(lm, id, description, valid_weapons): + gen("score", regex=r"[0-9]") ) + prompt = "" + grm = "Name: " + \ + gen('name', regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"], save_stop_text="saved_name_stop") + \ + "\nName: " + \ + gen('name2', regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"], save_stop_text="saved_name_stop2") + # grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=20) + "\n" diff --git a/controllers/llguidance_ctrl/src/api.rs b/controllers/llguidance_ctrl/src/api.rs index 2640b05d..e3518e55 100644 --- a/controllers/llguidance_ctrl/src/api.rs +++ b/controllers/llguidance_ctrl/src/api.rs @@ -111,6 +111,10 @@ pub struct GenOptions { /// If `stop_rx` is empty, it's assumed to be EOS. pub stop_rx: RegexSpec, + /// When set, the string matching `stop_rx` will be output as a capture + /// with the given name. + pub stop_capture_name: Option, + /// Override sampling temperature. pub temperature: Option, } diff --git a/controllers/llguidance_ctrl/src/earley/from_guidance.rs b/controllers/llguidance_ctrl/src/earley/from_guidance.rs index 91064ac7..a2079631 100644 --- a/controllers/llguidance_ctrl/src/earley/from_guidance.rs +++ b/controllers/llguidance_ctrl/src/earley/from_guidance.rs @@ -7,39 +7,6 @@ use crate::api::{ use anyhow::{bail, ensure, Result}; use derivre::{ExprRef, RegexAst, RegexBuilder}; -#[derive(Debug)] -pub struct NodeProps { - pub nullable: bool, - pub name: String, - pub hidden: bool, - pub commit_point: bool, - pub capture_name: String, - pub max_tokens: i32, - pub temperature: f32, -} - -impl NodeProps { - #[allow(dead_code)] - pub fn to_symbol_props(&self) -> SymbolProps { - SymbolProps { - commit_point: self.commit_point, - hidden: self.hidden && self.commit_point, - max_tokens: if self.max_tokens == i32::MAX { - usize::MAX - } else { - self.max_tokens.try_into().unwrap() - }, - model_variable: None, - capture_name: if self.capture_name.is_empty() { - None - } else { - Some(self.capture_name.clone()) - }, - temperature: self.temperature, - } - } -} - fn resolve_rx(rx_refs: &[ExprRef], node: &RegexSpec) -> Result { match node { RegexSpec::Regex(rx) => Ok(RegexAst::Regex(rx.clone())), @@ -128,6 +95,7 @@ fn grammar_from_json(input: GrammarWithLexer) -> Result<(LexerSpec, Grammar)> { model_variable: None, capture_name: props.capture_name.clone(), temperature: 0.0, + stop_capture_name: None, }; grm.fresh_symbol_ext(&name, symprops) }) @@ -165,11 +133,21 @@ fn grammar_from_json(input: GrammarWithLexer) -> Result<(LexerSpec, Grammar)> { body_rx, stop_rx, )?; - grm.make_terminal(lhs, idx)?; + let symprops = grm.sym_props_mut(lhs); if let Some(t) = data.temperature { symprops.temperature = t; } + if data.stop_capture_name.is_some() { + symprops.stop_capture_name = data.stop_capture_name.clone(); + let wrap_props = symprops.for_wrapper(); + let wrap_name = format!("stop_wrap_{}", grm.sym_name(lhs)); + let wrap_sym = grm.fresh_symbol_ext(&wrap_name, wrap_props); + grm.make_terminal(wrap_sym, idx)?; + grm.add_rule(lhs, vec![wrap_sym])?; + } else { + grm.make_terminal(lhs, idx)?; + } } Node::Lexeme { rx, contextual, .. } => { ensure!(is_greedy, "lexeme() only allowed in greedy grammars"); diff --git a/controllers/llguidance_ctrl/src/earley/grammar.rs b/controllers/llguidance_ctrl/src/earley/grammar.rs index 15a4b1c5..31b76b4c 100644 --- a/controllers/llguidance_ctrl/src/earley/grammar.rs +++ b/controllers/llguidance_ctrl/src/earley/grammar.rs @@ -66,6 +66,7 @@ pub struct SymbolProps { pub max_tokens: usize, pub commit_point: bool, pub capture_name: Option, + pub stop_capture_name: Option, pub hidden: bool, pub model_variable: Option, pub temperature: f32, @@ -79,6 +80,7 @@ impl Default for SymbolProps { max_tokens: usize::MAX, model_variable: None, capture_name: None, + stop_capture_name: None, temperature: 0.0, } } @@ -91,6 +93,51 @@ impl SymbolProps { || self.hidden || self.max_tokens < usize::MAX || self.capture_name.is_some() + || self.stop_capture_name.is_some() + } + + pub fn for_wrapper(&self) -> Self { + SymbolProps { + commit_point: false, + hidden: self.hidden, + max_tokens: self.max_tokens, + model_variable: None, + capture_name: None, + stop_capture_name: None, + temperature: self.temperature, + } + } + + pub fn to_string(&self) -> String { + let props = self; + let mut outp = String::new(); + + if props.commit_point { + if props.hidden { + outp.push_str(" HIDDEN-COMMIT"); + } else { + outp.push_str(" COMMIT"); + } + } + if props.capture_name.is_some() { + outp.push_str(" CAPTURE"); + } + + if props.stop_capture_name.is_some() { + outp.push_str( + format!( + " STOP-CAPTURE={}", + props.stop_capture_name.as_ref().unwrap() + ) + .as_str(), + ); + } + + if props.max_tokens < 10000 { + outp.push_str(format!(" max_tokens={}", props.max_tokens).as_str()); + } + + outp } } @@ -411,8 +458,20 @@ impl Debug for Grammar { num_non_term += 1; num_rules += sym.rules.len(); } - for rule in &sym.rules { - writeln!(f, "{}", self.rule_to_string(rule, None))?; + if sym.rules.is_empty() { + if sym.props.is_special() { + writeln!( + f, + "{:15} ⇦ {:?} {}", + sym.name, + sym.lexeme, + sym.props.to_string() + )?; + } + } else { + for rule in &sym.rules { + writeln!(f, "{}", self.rule_to_string(rule, None))?; + } } } writeln!( @@ -495,6 +554,7 @@ impl SymFlags { const HIDDEN: u8 = 1 << 2; const CAPTURE: u8 = 1 << 3; const GEN_GRAMMAR: u8 = 1 << 4; + const STOP_CAPTURE: u8 = 1 << 5; fn from_csymbol(sym: &CSymbol) -> Self { let mut flags = 0; @@ -510,6 +570,9 @@ impl SymFlags { if sym.gen_grammar.is_some() { flags |= Self::GEN_GRAMMAR; } + if sym.props.stop_capture_name.is_some() { + flags |= Self::STOP_CAPTURE; + } SymFlags(flags) } @@ -529,6 +592,11 @@ impl SymFlags { self.0 & Self::CAPTURE != 0 } + #[inline(always)] + pub fn stop_capture(&self) -> bool { + self.0 & Self::STOP_CAPTURE != 0 + } + #[inline(always)] pub fn gen_grammar(&self) -> bool { self.0 & Self::GEN_GRAMMAR != 0 @@ -800,28 +868,5 @@ fn rule_to_string( } else if let Some(dot) = dot { rhs.insert(dot, "•"); } - format!( - "{:15} ⇦ {} {}{}{}", - lhs, - rhs.join(" "), - if props.commit_point { - if props.hidden { - " HIDDEN-COMMIT" - } else { - " COMMIT" - } - } else { - "" - }, - if props.capture_name.is_some() { - " CAPTURE" - } else { - "" - }, - if props.max_tokens < 10000 { - format!(" max_tokens={}", props.max_tokens) - } else { - "".to_string() - }, - ) + format!("{:15} ⇦ {} {}", lhs, rhs.join(" "), props.to_string()) } diff --git a/controllers/llguidance_ctrl/src/earley/parser.rs b/controllers/llguidance_ctrl/src/earley/parser.rs index 6f382f9a..8dd8ba76 100644 --- a/controllers/llguidance_ctrl/src/earley/parser.rs +++ b/controllers/llguidance_ctrl/src/earley/parser.rs @@ -984,6 +984,18 @@ impl Parser { let flags = self.grammar.sym_flags_of(rule); let lhs = self.grammar.sym_idx_of(rule); + if self.scratch.definitive && flags.stop_capture() { + let var_name = self + .grammar + .sym_data(lhs) + .props + .stop_capture_name + .as_ref() + .unwrap(); + let bytes = lexeme.hidden_bytes(); + self.captures.push((var_name.clone(), bytes.to_vec())); + } + if self.scratch.definitive && flags.capture() { let var_name = self .grammar diff --git a/py/guidance b/py/guidance index 4db20be1..8966fcbd 160000 --- a/py/guidance +++ b/py/guidance @@ -1 +1 @@ -Subproject commit 4db20be1c2422be8c46cd499aff265907a2b58ca +Subproject commit 8966fcbd7b7498188b214f7422e90cb239ae564c