support save_stop_text (stop_capture_name)
mmoskal committed Jun 14, 2024
1 parent 790ec09 commit 8a287f7
Showing 7 changed files with 107 additions and 66 deletions.
6 changes: 1 addition & 5 deletions controllers/llguidance_ctrl/README.md
@@ -15,9 +15,7 @@ Guidance branch: https://github.com/hudson-ai/guidance/tree/lazy_grammars

## Status in Guidance

- [ ] `gen_mode` in `_gen.py` needs to become a flag on the model/engine
- [ ] `gen_json()` needs to be re-implemented
- [ ] `save_stop_text=` doesn't work on `gen()`
- [x] `save_stop_text=` doesn't work on `gen()`
- [ ] `substring()` needs to be re-implemented (translate to RegexAst)
- [ ] translate `commit_point(grammar)` into a RegexAst if
(the grammar is non-recursive;
@@ -27,7 +25,5 @@ Guidance branch: https://github.com/hudson-ai/guidance/tree/lazy_grammars
## TODO

- [ ] `to_regex_vec()` in lexerspec.rs - non-contextual keywords
- [x] handle stop tokens in `gen_grammar()` - stop tokens removed
- [x] use `RegexAst::Byte(0xff)` for `EOS_MARKER`
- [ ] fix derivative computation to be non-recursive (critical for `substring()`)
- [ ] add stats about how many parser transitions are made in a token trie traversal
6 changes: 6 additions & 0 deletions controllers/llguidance_ctrl/run_g.py
@@ -224,6 +224,12 @@ def character_maker2(lm, id, description, valid_weapons):
+ gen("score", regex=r"[0-9]")
)

prompt = ""
grm = "Name: " + \
gen('name', regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"], save_stop_text="saved_name_stop") + \
"\nName: " + \
gen('name2', regex="E[a-z]+", stop_regex=["[a-b]", "[x-z]"], save_stop_text="saved_name_stop2")


# grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=20) + "\n"

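For context, a rough sketch of what the new grammar is expected to yield on the guidance side. Only the capture names come from the diff above; the concrete generated strings, the `model` object, and the `lm[...]` capture access are illustrative assumptions:

# Illustration only, not part of the commit. Suppose the model generates
# "Emily" for the first gen(): E[a-z]+ consumes "Emil", then "y" matches the
# stop regex [x-z] and ends the lexeme.
lm = model + grm                  # `model` stands in for any guidance model object
print(lm["name"])                 # e.g. "Emil" -- text matched by regex=
print(lm["saved_name_stop"])      # e.g. "y"    -- stop text saved via save_stop_text=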
4 changes: 4 additions & 0 deletions controllers/llguidance_ctrl/src/api.rs
@@ -111,6 +111,10 @@ pub struct GenOptions {
/// If `stop_rx` is empty, it's assumed to be EOS.
pub stop_rx: RegexSpec,

/// When set, the string matching `stop_rx` will be output as a capture
/// with the given name.
pub stop_capture_name: Option<String>,

/// Override sampling temperature.
pub temperature: Option<f32>,
}
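For reference, a sketch of how a guidance-side `gen()` call (mirroring the run_g.py change above) is expected to map onto these options; the field-by-field mapping in the comments is inferred from this commit rather than documented API:

# Hypothetical mapping from gen() keywords to GenOptions fields.
gen(
    "name",                            # -> capture_name on the resulting symbol
    regex="E[a-z]+",                   # -> body regex of the lexeme
    stop_regex=["[a-b]", "[x-z]"],     # -> stop_rx
    save_stop_text="saved_name_stop",  # -> stop_capture_name (the new field)
)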
46 changes: 12 additions & 34 deletions controllers/llguidance_ctrl/src/earley/from_guidance.rs
@@ -7,39 +7,6 @@ use crate::api::{
use anyhow::{bail, ensure, Result};
use derivre::{ExprRef, RegexAst, RegexBuilder};

#[derive(Debug)]
pub struct NodeProps {
pub nullable: bool,
pub name: String,
pub hidden: bool,
pub commit_point: bool,
pub capture_name: String,
pub max_tokens: i32,
pub temperature: f32,
}

impl NodeProps {
#[allow(dead_code)]
pub fn to_symbol_props(&self) -> SymbolProps {
SymbolProps {
commit_point: self.commit_point,
hidden: self.hidden && self.commit_point,
max_tokens: if self.max_tokens == i32::MAX {
usize::MAX
} else {
self.max_tokens.try_into().unwrap()
},
model_variable: None,
capture_name: if self.capture_name.is_empty() {
None
} else {
Some(self.capture_name.clone())
},
temperature: self.temperature,
}
}
}

fn resolve_rx(rx_refs: &[ExprRef], node: &RegexSpec) -> Result<RegexAst> {
match node {
RegexSpec::Regex(rx) => Ok(RegexAst::Regex(rx.clone())),
@@ -128,6 +95,7 @@ fn grammar_from_json(input: GrammarWithLexer) -> Result<(LexerSpec, Grammar)> {
model_variable: None,
capture_name: props.capture_name.clone(),
temperature: 0.0,
stop_capture_name: None,
};
grm.fresh_symbol_ext(&name, symprops)
})
@@ -165,11 +133,21 @@ fn grammar_from_json(input: GrammarWithLexer) -> Result<(LexerSpec, Grammar)> {
body_rx,
stop_rx,
)?;
grm.make_terminal(lhs, idx)?;

let symprops = grm.sym_props_mut(lhs);
if let Some(t) = data.temperature {
symprops.temperature = t;
}
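// When a stop capture is requested, route the lexeme through a fresh wrapper
// symbol: `lhs` keeps the stop_capture_name props and gets a single rule
// pointing at the wrapper, so the parser can spot the flag on `lhs` and
// record the stop text.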
if data.stop_capture_name.is_some() {
symprops.stop_capture_name = data.stop_capture_name.clone();
let wrap_props = symprops.for_wrapper();
let wrap_name = format!("stop_wrap_{}", grm.sym_name(lhs));
let wrap_sym = grm.fresh_symbol_ext(&wrap_name, wrap_props);
grm.make_terminal(wrap_sym, idx)?;
grm.add_rule(lhs, vec![wrap_sym])?;
} else {
grm.make_terminal(lhs, idx)?;
}
}
Node::Lexeme { rx, contextual, .. } => {
ensure!(is_greedy, "lexeme() only allowed in greedy grammars");
97 changes: 71 additions & 26 deletions controllers/llguidance_ctrl/src/earley/grammar.rs
@@ -66,6 +66,7 @@ pub struct SymbolProps {
pub max_tokens: usize,
pub commit_point: bool,
pub capture_name: Option<String>,
pub stop_capture_name: Option<String>,
pub hidden: bool,
pub model_variable: Option<ModelVariable>,
pub temperature: f32,
@@ -79,6 +80,7 @@ impl Default for SymbolProps {
max_tokens: usize::MAX,
model_variable: None,
capture_name: None,
stop_capture_name: None,
temperature: 0.0,
}
}
@@ -91,6 +93,51 @@ impl SymbolProps {
|| self.hidden
|| self.max_tokens < usize::MAX
|| self.capture_name.is_some()
|| self.stop_capture_name.is_some()
}

pub fn for_wrapper(&self) -> Self {
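// Props for the wrapper symbol introduced when stop_capture_name is set:
// keep hidden/max_tokens/temperature, drop capture and commit-point markers.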
SymbolProps {
commit_point: false,
hidden: self.hidden,
max_tokens: self.max_tokens,
model_variable: None,
capture_name: None,
stop_capture_name: None,
temperature: self.temperature,
}
}

pub fn to_string(&self) -> String {
let props = self;
let mut outp = String::new();

if props.commit_point {
if props.hidden {
outp.push_str(" HIDDEN-COMMIT");
} else {
outp.push_str(" COMMIT");
}
}
if props.capture_name.is_some() {
outp.push_str(" CAPTURE");
}

if props.stop_capture_name.is_some() {
outp.push_str(
format!(
" STOP-CAPTURE={}",
props.stop_capture_name.as_ref().unwrap()
)
.as_str(),
);
}

if props.max_tokens < 10000 {
outp.push_str(format!(" max_tokens={}", props.max_tokens).as_str());
}

outp
}
}

@@ -411,8 +458,20 @@ impl Debug for Grammar {
num_non_term += 1;
num_rules += sym.rules.len();
}
for rule in &sym.rules {
writeln!(f, "{}", self.rule_to_string(rule, None))?;
if sym.rules.is_empty() {
if sym.props.is_special() {
writeln!(
f,
"{:15} ⇦ {:?} {}",
sym.name,
sym.lexeme,
sym.props.to_string()
)?;
}
} else {
for rule in &sym.rules {
writeln!(f, "{}", self.rule_to_string(rule, None))?;
}
}
}
writeln!(
@@ -495,6 +554,7 @@ impl SymFlags {
const HIDDEN: u8 = 1 << 2;
const CAPTURE: u8 = 1 << 3;
const GEN_GRAMMAR: u8 = 1 << 4;
const STOP_CAPTURE: u8 = 1 << 5;

fn from_csymbol(sym: &CSymbol) -> Self {
let mut flags = 0;
@@ -510,6 +570,9 @@ impl SymFlags {
if sym.gen_grammar.is_some() {
flags |= Self::GEN_GRAMMAR;
}
if sym.props.stop_capture_name.is_some() {
flags |= Self::STOP_CAPTURE;
}
SymFlags(flags)
}

@@ -529,6 +592,11 @@ impl SymFlags {
self.0 & Self::CAPTURE != 0
}

#[inline(always)]
pub fn stop_capture(&self) -> bool {
self.0 & Self::STOP_CAPTURE != 0
}

#[inline(always)]
pub fn gen_grammar(&self) -> bool {
self.0 & Self::GEN_GRAMMAR != 0
@@ -800,28 +868,5 @@ fn rule_to_string(
} else if let Some(dot) = dot {
rhs.insert(dot, "•");
}
format!(
"{:15} ⇦ {} {}{}{}",
lhs,
rhs.join(" "),
if props.commit_point {
if props.hidden {
" HIDDEN-COMMIT"
} else {
" COMMIT"
}
} else {
""
},
if props.capture_name.is_some() {
" CAPTURE"
} else {
""
},
if props.max_tokens < 10000 {
format!(" max_tokens={}", props.max_tokens)
} else {
"".to_string()
},
)
format!("{:15} ⇦ {} {}", lhs, rhs.join(" "), props.to_string())
}
12 changes: 12 additions & 0 deletions controllers/llguidance_ctrl/src/earley/parser.rs
@@ -984,6 +984,18 @@ impl Parser {
let flags = self.grammar.sym_flags_of(rule);
let lhs = self.grammar.sym_idx_of(rule);

if self.scratch.definitive && flags.stop_capture() {
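// Definitive phase: record the lexeme's hidden suffix (the text matched by
// the stop regex) under the requested stop capture name.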
let var_name = self
.grammar
.sym_data(lhs)
.props
.stop_capture_name
.as_ref()
.unwrap();
let bytes = lexeme.hidden_bytes();
self.captures.push((var_name.clone(), bytes.to_vec()));
}

if self.scratch.definitive && flags.capture() {
let var_name = self
.grammar
2 changes: 1 addition & 1 deletion py/guidance
