diff --git a/.vscode/settings.json b/.vscode/settings.json index 86dfe621..b9e68f91 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,6 @@ { "[python]": { - "editor.defaultFormatter": "eeyore.yapf" + "editor.defaultFormatter": "ms-python.black-formatter" }, "python.formatting.provider": "none", "rust-analyzer.linkedProjects": [ diff --git a/controllers/guidance_ctrl/run_g.py b/controllers/guidance_ctrl/run_g.py index ae3d8113..a9fe8bab 100644 --- a/controllers/guidance_ctrl/run_g.py +++ b/controllers/guidance_ctrl/run_g.py @@ -80,6 +80,10 @@ def main(): + "/10\n" ) grm = "this is a test" + gen("test", max_tokens=10) + grm = "Tweak this proverb to apply to model instructions instead.\n" + gen( + "verse", max_tokens=2 + ) + # grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(") # read current script file # with open(__file__) as f: diff --git a/controllers/guidance_ctrl/src/earley/parser.rs b/controllers/guidance_ctrl/src/earley/parser.rs index 25832c53..20721710 100644 --- a/controllers/guidance_ctrl/src/earley/parser.rs +++ b/controllers/guidance_ctrl/src/earley/parser.rs @@ -340,29 +340,39 @@ impl Parser { pub fn filter_max_tokens(&mut self) { let mut dst = 0; + + self.row_infos.push(RowInfo { + byte: 0, + commit_item: Item::NULL, + token_idx: self.token_idx, + }); + for idx in 0..self.rows.len() { let range = self.rows[idx].item_indices(); self.rows[idx].first_item = dst; for i in range { let item = self.scratch.items[i]; let sym_data = self.item_sym_data(&item); - if sym_data.props.max_tokens != usize::MAX - && self.token_idx - self.row_infos[item.start_pos()].token_idx - >= sym_data.props.max_tokens - { - debug!( - " remove: {}-{} {}", - self.token_idx, - self.row_infos[item.start_pos()].token_idx, - self.item_to_string(&item) - ); - continue; + let max_tokens = sym_data.props.max_tokens; + if max_tokens != usize::MAX { + let start_token_idx = self.row_infos[item.start_pos() + 1].token_idx; + if self.token_idx - start_token_idx >= max_tokens { + debug!( + " remove: {}-{} {}", + self.token_idx, + start_token_idx, + self.item_to_string(&item) + ); + continue; + } } self.scratch.items[dst] = item; dst += 1; } self.rows[idx].last_item = dst; } + + self.row_infos.pop(); } pub fn force_bytes(&mut self) -> Vec { @@ -537,6 +547,11 @@ impl Parser { .collect::>(); } bytes.push(byte); + debug!( + " capture: {} {:?}", + var_name, + String::from_utf8_lossy(&bytes) + ); self.captures.push((var_name.clone(), bytes)); } diff --git a/controllers/guidance_ctrl/src/gctrl.rs b/controllers/guidance_ctrl/src/gctrl.rs index 8e12310b..6f101100 100644 --- a/controllers/guidance_ctrl/src/gctrl.rs +++ b/controllers/guidance_ctrl/src/gctrl.rs @@ -72,8 +72,8 @@ impl Runner { .iter() .rev() .filter(|(name, _)| seen.insert(name)) - .rev(); - for (name, val) in captures { + .collect::>(); + for (name, val) in captures.iter().rev() { let cap = Capture { object: "capture", name: name.clone(), diff --git a/controllers/guidance_ctrl/src/tokenparser.rs b/controllers/guidance_ctrl/src/tokenparser.rs index 266fec3a..993c61b3 100644 --- a/controllers/guidance_ctrl/src/tokenparser.rs +++ b/controllers/guidance_ctrl/src/tokenparser.rs @@ -48,7 +48,7 @@ impl TokenParser { pub fn bytes_since(&self, mut idx: usize) -> &[u8] { idx += self.grm_prefix.len(); - if idx >= self.llm_tokens.len() { + if idx >= self.llm_bytes.len() { return &[]; } &self.llm_bytes[idx..] @@ -182,6 +182,10 @@ impl TokenParser { trie.token_set_dbg(&set) ); + if set.num_set() == 0 { + return MidProcessResult::stop(); + } + return MidProcessResult::sample(set); } } diff --git a/pytest.ini b/pytest.ini index 7a8a55e9..417d7cdb 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ -[pytest] -testpaths = py/tests -addopts = -n 1 +;[pytest] +;testpaths = py/tests +;addopts = -n 1 diff --git a/scripts/test-guidance.sh b/scripts/test-guidance.sh new file mode 100755 index 00000000..e206b514 --- /dev/null +++ b/scripts/test-guidance.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +if [ "X$AZURE_GUIDANCE_URL" = "X" ] ; then + if [ "X$AICI_API_BASE" = "X" ] ; then + AICI_API_BASE="http://127.0.0.1:4242/v1/" + fi + AZURE_GUIDANCE_URL="$AICI_API_BASE" +fi +export AZURE_GUIDANCE_URL + +cd $(dirname $0)/../py/guidance +pytest --selected_model azure_guidance tests/models/test_azure_guidance.py