Skip to content

Commit

Permalink
guidance ctrl fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed May 6, 2024
1 parent 31f2512 commit 1fd4e4d
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"[python]": {
"editor.defaultFormatter": "eeyore.yapf"
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"rust-analyzer.linkedProjects": [
Expand Down
4 changes: 4 additions & 0 deletions controllers/guidance_ctrl/run_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def main():
+ "/10\n"
)
grm = "this is a test" + gen("test", max_tokens=10)
grm = "Tweak this proverb to apply to model instructions instead.\n" + gen(
"verse", max_tokens=2
)
# grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(")

# read current script file
# with open(__file__) as f:
Expand Down
37 changes: 26 additions & 11 deletions controllers/guidance_ctrl/src/earley/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,29 +340,39 @@ impl Parser {

/// Removes items whose symbol has used up its `max_tokens` budget, compacting
/// `self.scratch.items` in place and rewriting every row's
/// `first_item`/`last_item` bounds to match the compacted layout.
///
/// NOTE(review): this listing appears to be a rendered diff without +/-
/// markers, so the earlier single-condition check and the nested check that
/// replaces it both show up in the loop body below — confirm against the
/// actual file before relying on the exact statement sequence.
pub fn filter_max_tokens(&mut self) {
// Write cursor: kept items are copied down to this index.
let mut dst = 0;

// Temporary entry carrying the current token index; it is popped again at
// the end of this function.
// NOTE(review): presumably this keeps `start_pos() + 1` in range for items
// that start in the last row — TODO confirm.
self.row_infos.push(RowInfo {
byte: 0,
commit_item: Item::NULL,
token_idx: self.token_idx,
});

for idx in 0..self.rows.len() {
// Capture the row's old item range before its bounds are overwritten.
let range = self.rows[idx].item_indices();
self.rows[idx].first_item = dst;
for i in range {
let item = self.scratch.items[i];
let sym_data = self.item_sym_data(&item);
// Check that reads the token index at `start_pos()` itself
// (NOTE(review): looks like the pre-change form of the test below).
if sym_data.props.max_tokens != usize::MAX
&& self.token_idx - self.row_infos[item.start_pos()].token_idx
>= sym_data.props.max_tokens
{
debug!(
" remove: {}-{} {}",
self.token_idx,
self.row_infos[item.start_pos()].token_idx,
self.item_to_string(&item)
);
continue;
// Check that reads the token index at `start_pos() + 1`.
// `usize::MAX` means the limit test is skipped entirely.
let max_tokens = sym_data.props.max_tokens;
if max_tokens != usize::MAX {
let start_token_idx = self.row_infos[item.start_pos() + 1].token_idx;
// Skip (i.e. drop) the item once the tokens elapsed since its
// start reach the limit.
if self.token_idx - start_token_idx >= max_tokens {
debug!(
" remove: {}-{} {}",
self.token_idx,
start_token_idx,
self.item_to_string(&item)
);
continue;
}
}
// Item survives: copy it down to the write cursor.
self.scratch.items[dst] = item;
dst += 1;
}
self.rows[idx].last_item = dst;
}

// Remove the temporary entry pushed at the top.
self.row_infos.pop();
}

pub fn force_bytes(&mut self) -> Vec<u8> {
Expand Down Expand Up @@ -537,6 +547,11 @@ impl Parser {
.collect::<Vec<_>>();
}
bytes.push(byte);
debug!(
" capture: {} {:?}",
var_name,
String::from_utf8_lossy(&bytes)
);
self.captures.push((var_name.clone(), bytes));
}

Expand Down
4 changes: 2 additions & 2 deletions controllers/guidance_ctrl/src/gctrl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ impl Runner {
.iter()
.rev()
.filter(|(name, _)| seen.insert(name))
.rev();
for (name, val) in captures {
.collect::<Vec<_>>();
for (name, val) in captures.iter().rev() {
let cap = Capture {
object: "capture",
name: name.clone(),
Expand Down
6 changes: 5 additions & 1 deletion controllers/guidance_ctrl/src/tokenparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl TokenParser {

pub fn bytes_since(&self, mut idx: usize) -> &[u8] {
idx += self.grm_prefix.len();
if idx >= self.llm_tokens.len() {
if idx >= self.llm_bytes.len() {
return &[];
}
&self.llm_bytes[idx..]
Expand Down Expand Up @@ -182,6 +182,10 @@ impl TokenParser {
trie.token_set_dbg(&set)
);

if set.num_set() == 0 {
return MidProcessResult::stop();
}

return MidProcessResult::sample(set);
}
}
6 changes: 3 additions & 3 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[pytest]
testpaths = py/tests
addopts = -n 1
;[pytest]
;testpaths = py/tests
;addopts = -n 1
12 changes: 12 additions & 0 deletions scripts/test-guidance.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/sh
# Runs tests/models/test_azure_guidance.py from py/guidance with
# AZURE_GUIDANCE_URL exported for the test process.
#
# Endpoint resolution (first non-empty value wins):
#   1. $AZURE_GUIDANCE_URL
#   2. $AICI_API_BASE
#   3. http://127.0.0.1:4242/v1/

if [ -z "$AZURE_GUIDANCE_URL" ] ; then
    if [ -z "$AICI_API_BASE" ] ; then
        AICI_API_BASE="http://127.0.0.1:4242/v1/"
    fi
    AZURE_GUIDANCE_URL="$AICI_API_BASE"
fi
export AZURE_GUIDANCE_URL

# Quote the path so a checkout location containing whitespace still works,
# and stop if the directory is missing rather than running pytest from
# whatever directory the caller happened to be in.
cd "$(dirname "$0")/../py/guidance" || exit 1
pytest --selected_model azure_guidance tests/models/test_azure_guidance.py

0 comments on commit 1fd4e4d

Please sign in to comment.