add LLTokenizer.test_trace_tokens
mmoskal committed Jun 28, 2024
1 parent ea95552 commit dd14b0d
Showing 8 changed files with 160 additions and 196 deletions.
7 changes: 6 additions & 1 deletion .vscode/settings.json
@@ -154,5 +154,10 @@
     "files.readonlyInclude": {
         "**/dist/*": true,
         "**/aici-types.d.ts": true
-    }
+    },
+    "python.testing.pytestArgs": [
+        "py/guidance"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
 }
31 changes: 18 additions & 13 deletions controllers/aici_abi/src/toktree.rs
@@ -241,30 +241,35 @@ impl TokTrie {
         vec![0.0; self.vocab_size() + 1]
     }
 
+    pub fn test_trace_tokens(&self, toks: &[u32]) -> String {
+        toks.iter()
+            .map(|t| {
+                let s = self.token_dbg(*t);
+                if s.starts_with("\"") {
+                    self.token_str(*t)
+                } else {
+                    format!("≺{}≻", s)
+                }
+            })
+            .collect::<Vec<_>>()
+            .join("‧")
+    }
+
     pub fn tokens_dbg(&self, toks: &[u32]) -> String {
-        let minimal = false;
-        let sep = "‧";
         let joined = toks
             .iter()
             .map(|t| {
                 let s = self.token_dbg(*t);
                 if s.starts_with("\"") {
-                    let inner = s[1..s.len() - 1].to_string();
-                    let b = s.as_bytes();
-                    // for " [\w]..." and " " the sep in front is implicit
-                    if minimal && b[1] == b' ' && ((b[2] as char).is_alphanumeric() || b.len() == 3)
-                    {
-                        inner
-                    } else {
-                        format!("{}{}", sep, inner)
-                    }
+                    s[1..s.len() - 1].to_string()
                 } else {
                     format!("≺{}≻", s)
                 }
             })
             .collect::<Vec<_>>()
-            .join("");
-        format!("\"{}\"", joined.trim_start_matches(sep))
+            .join("‧");
+
+        format!("\"{}\"", joined)
     }
 
     pub fn token_dbg(&self, idx: u32) -> String {
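Taken together, the new test_trace_tokens and the reworked tokens_dbg define one trace format: tokens whose token_dbg form is quoted (printable text) are emitted as their raw token_str text, all other tokens (specials, raw bytes) are wrapped in ≺…≻, and the pieces are joined with ‧. A rough Python model of that rule, with invented token strings purely for illustration:

# Sketch of the trace format; (dbg, raw) pairs stand in for
# TokTrie::token_dbg and TokTrie::token_str output and are made up here.
def trace_tokens(tokens):
    parts = []
    for dbg, raw in tokens:
        if dbg.startswith('"'):            # token_dbg quotes printable tokens
            parts.append(raw)              # keep the raw token text
        else:
            parts.append("≺" + dbg + "≻")  # special tokens, e.g. EOS
    return "‧".join(parts)

print(trace_tokens([('"Hello"', "Hello"), ('" world"', " world"), ("EOS", "")]))
# prints: Hello‧ world‧≺EOS≻

The ‧ separator keeps token boundaries visible even when the tokens themselves contain spaces.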
250 changes: 70 additions & 180 deletions controllers/llguidance_ctrl/run_g.py
@@ -3,7 +3,7 @@
 import base64
 import ujson as json
 import binascii
-
+import os
 
 import guidance
 from guidance import (
@@ -24,146 +24,7 @@
 )
 
 
-@guidance(stateless=True)
-def number(lm):
-    n = one_or_more(select(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]))
-    return lm + select(["-" + n, n])
-
-
-@guidance(stateless=True)
-def identifier(lm):
-    letter = select([byte_range(b"a", b"z"), byte_range(b"A", b"Z"), "_"])
-    num = byte_range(b"0", b"9")
-    return lm + letter + zero_or_more(select([letter, num]))
-
-
-@guidance(stateless=True)
-def assignment_stmt(lm):
-    return lm + identifier() + " = " + expression()
-
-
-@guidance(stateless=True)
-def while_stmt(lm):
-    return lm + "while " + expression() + ":" + stmt()
-
-
-@guidance(stateless=True)
-def stmt(lm):
-    return lm + select([assignment_stmt(), while_stmt()])
-
-
-@guidance(stateless=True)
-def operator(lm):
-    return lm + select(["+", "*", "**", "/", "-"])
-
-
-@guidance(stateless=True)
-def expression(lm):
-    return lm + select(
-        [
-            identifier(),
-            expression()
-            + zero_or_more(" ")
-            + operator()
-            + zero_or_more(" ")
-            + expression(),
-            "(" + expression() + ")",
-        ]
-    )
-
-
-@guidance(stateless=True)
-def json_string(lm):
-    return lm + lexeme(r'"(\\(["\\\/bfnrt]|u[a-fA-F0-9]{4})|[^"\\\x00-\x1F\x7F]+)*"')
-
-
-@guidance(stateless=True)
-def json_number(lm):
-    return lm + lexeme(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?")
-
-
-@guidance(stateless=True)
-def json_value(lm):
-    return lm + select(
-        [
-            json_string(),
-            json_number(),
-            json_object(),
-            json_array(),
-            "true",
-            "false",
-            "null",
-        ]
-    )
-
-
-@guidance(stateless=True)
-def json_member(lm):
-    return lm + json_string() + ":" + json_value()
-
-
-@guidance(stateless=True)
-def json_object(lm):
-    return lm + "{" + optional(json_member() + one_or_more("," + json_member())) + "}"
-
-
-@guidance(stateless=True)
-def json_array(lm):
-    return lm + "[" + optional(json_value() + one_or_more("," + json_value())) + "]"
-
-
-@guidance(stateless=True)
-def gen_json_object(lm, name: str, max_tokens=100000000):
-    grm = greedy_grammar(body=json_object(), skip_regex=r"[\x20\x0A\x0D\x09]+")
-    return lm + grm
-
-
 def main():
-    grm = (
-        "Here's a sample arithmetic expression: "
-        + capture(expression(), "expr")
-        + " = "
-        + capture(number(), "num")
-    )
-    grm = (
-        "<joke>Parallel lines have so much in common. It’s a shame they’ll never meet.</joke>\nScore: 8/10\n"
-        + "<joke>"
-        + capture(gen(regex=r"[A-Z\(].*", max_tokens=50, stop="</joke>"), "joke")
-        + "</joke>\nScore: "
-        + capture(gen(regex=r"\d{1,3}"), "score")
-        + "/10\n"
-    )
-    grm = "this is a test" + gen("test", max_tokens=10)
-    grm = "Tweak this proverb to apply to model instructions instead.\n" + gen(
-        "verse", max_tokens=2
-    )
-    grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(")
-    grm = "<color>red</color>\n<color>" + gen(stop="</color>") + " and test2"
-
-    lm = "Here's a "
-    lm += select(["joke", "poem"], name="type")
-    lm += ": "
-    lm += gen("words", regex=r"[A-Z ]+", stop="\n")
-    grm = lm
-
-    @guidance(stateless=True, dedent=True)
-    def character_maker(lm, id, description, valid_weapons):
-        lm += f"""\
-        The following is a character profile for an RPG game in JSON format.
-        ```json
-        {{
-            "id": "{id}",
-            "description": "{description}",
-            "name": "{gen('name', stop='"')}",
-            "age": {gen('age', regex='[0-9]+', stop=',')},
-            "armor": "{select(options=['leather', 'chainmail', 'plate'], name='armor')}",
-            "weapon": "{select(options=valid_weapons, name='weapon')}",
-            "class": "{gen('class', stop='"')}",
-            "mantra": "{gen('mantra', stop='"')}",
-            "strength": {gen('strength', regex='[0-9]+', stop=',')},
-            "items": ["{gen('item', list_append=True, stop='"')}", "{gen('item', list_append=True, stop='"')}", "{gen('item', list_append=True, stop='"')}"]
-        }}```"""
-        return lm
 
     @guidance(stateless=True, dedent=True)
     def character_maker2(lm, id, description, valid_weapons):
@@ -234,13 +95,6 @@ def character_maker2(lm, id, description, valid_weapons):
         )
     )
 
-    prompt = "Three things about J. Random Hacker:\n"
-    grm = (
-        gen_json_object("hacker", max_tokens=150)
-        + "\nScore (0-9): "
-        + gen("score", regex=r"[0-9]")
-    )
-
     grm = character_maker2(1, "A nimble fighter", ["axe", "sword", "bow"])
     prompt = ""
 
@@ -280,16 +134,33 @@ def character_maker2(lm, id, description, valid_weapons):
     prompt = ""
     grm = optional("A")
 
-    grm = one_or_more(gen(regex="[a-z]"))
     grm = "A odd number is " + gen(
         "number", regex="[0-9]+", max_tokens=5, temperature=0
     )
+
+
     grm = (
         "Q: Are dolphins fish?\nA: "
        + gen("dolphins", regex="Yes|No", max_tokens=10)
        + "\nQ: Are sharks fish?\nA: "
        + gen("sharks", regex="Yes|No", max_tokens=10)
     )
+
+    grm = one_or_more(gen(regex="[a-z]"))
+    grm = (
+        "Power frequency is "
+        + gen("number", regex="[0-9]+", max_tokens=5, temperature=0)
+        + "Hz; voltage is "
+        + gen("number", regex="[0-9]+", max_tokens=5, temperature=0)
+        + "V"
+    )
+
-    grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=5)
+
+    grm = "Dolphin name: " + commit_point(
+        '"' + byte_range(b"A", b"Z") + one_or_more(byte_range(b"a", b"z")) + '"'
+    ) + ","
+
+    # grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=5)
+
     # g = zero_or_more("a") + "b"
     # assert g.match("b")
@@ -299,13 +170,17 @@
     # grammar = one_or_more(select(["1", "2"]))
     # lm += grammar
 
+    # grm = greedy_grammar(
+    #     body = lexeme("[0-9]+")
+    # )
+
     max_tokens = 250
 
     serialized = grm.ll_serialize()
 
     # with open("tmp/long_json_grammar_req.json", "r") as f:
-    # with open("tmp/email_regex_grammar.json", "r") as f:
-    #     max_tokens = 2000
+    # # with open("tmp/email_regex_grammar.json", "r") as f:
+    #     max_tokens = 1000
     #     serialized = json.load(f)
 
     x_serialized = {
@@ -325,35 +200,8 @@
         ]
     }
 
-    x_serialized = {
-        "grammars": [
-            {
-                "greedy_lexer": False,
-                "nodes": [
-                    {
-                        "GenGrammar": {
-                            "grammar": 1,
-                            "stop_rx": "",
-                            "no_initial_skip": True,
-                            "temperature": 0.0,
-                        }
-                    }
-                ],
-                "rx_nodes": [],
-            },
-            {
-                "greedy_lexer": True,
-                "greedy_skip_rx": "[\\x20\\x0A\\x0D\\x09]+",
-                "nodes": [
-                    {"Lexeme": {"rx": "-?(?:0|[1-9][0-9]*)", "contextual": False}}
-                    # {"Lexeme": {"rx": "[ab][ab]", "contextual": False}}
-                ],
-                "rx_nodes": [],
-            },
-        ]
-    }
-
     serialized["max_tokens"] = max_tokens
+    serialized["test_trace"] = True
     llguidance_json = {"grammar": serialized}
 
     llguidance_arg = json.dumps(llguidance_json, indent=1)
@@ -371,7 +219,10 @@
     # script = f.read()
     # grm = "```python\n" + substring(script[0:1400])
 
-    mod_id = pyaici.cli.build_rust(".", features=["logging"])
+    features = ["logging"]
+    if "FAST" in os.environ:
+        features = []
+    mod_id = pyaici.cli.build_rust(".", features=features)
     if "127.0.0.1" in pyaici.rest.base_url:
         pyaici.rest.tag_module(mod_id, ["llguidance_ctrl-latest", "llguidance"])
     pyaici.rest.log_level = 2
Expand All @@ -388,6 +239,8 @@ def character_maker2(lm, id, description, valid_weapons):
print("Storage:", res["storage"])
print()

testcase_from_logs(res["logs"][0])

text = b""
captures = {}
for j in res["json_out"][0]:
@@ -402,4 +255,41 @@
     print()
 
+
+def testcase_from_logs(logs: str):
+    sep = "‧"
+    pairs = []
+    prev_res = None
+    prompt = None
+    for line in logs.split("\n"):
+        if line.startswith("TEST: "):
+            obj = json.loads(line[6:])
+            if prompt is None:
+                prompt = obj["res_prompt"]
+                continue
+            if prev_res:
+                pairs.append((prev_res, obj["arg"]))
+            prev_res = obj["res"]
+            print(obj)
+    assert prev_res == "stop"
+    testcase = [prompt]
+    gen_tokens = []
+
+    def flush_gen_tokens():
+        testcase.append(sep.join(gen_tokens))
+        gen_tokens.clear()
+
+    for res, arg in pairs:
+        if res["sample_mask"]:
+            gen_tokens.append(arg["tokens"])
+        else:
+            t0 = res["splices"][0]["tokens"]
+            assert t0 == arg["tokens"]
+            flush_gen_tokens()
+            testcase.append(t0)
+    if gen_tokens:
+        flush_gen_tokens()
+
+    print("Testcase:", testcase)
+
 
 main()
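The new testcase_from_logs turns a controller run into a compact test case: it expects one "TEST: {...}" JSON line per step, takes the prompt from the first line's res_prompt, pairs each step's result with the next step's argument, and ‧-joins consecutively sampled tokens. A sketch of log input in the shape the parser accepts; the token values are invented, only the field names come from the code above:

# Hypothetical TEST: lines matching what testcase_from_logs() parses.
example_logs = "\n".join([
    'TEST: {"res_prompt": "Hello"}',
    'TEST: {"arg": {"tokens": " wor"}, "res": {"sample_mask": true, "splices": []}}',
    'TEST: {"arg": {"tokens": "ld"}, "res": "stop"}',
])
# testcase_from_logs(example_logs) prints each parsed object and then:
# Testcase: ['Hello', 'ld']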
2 changes: 2 additions & 0 deletions controllers/llguidance_ctrl/src/api.rs
@@ -7,6 +7,8 @@ use serde::{Deserialize, Serialize};
 pub struct TopLevelGrammar {
     pub grammars: Vec<GrammarWithLexer>,
     pub max_tokens: Option<usize>,
+    #[serde(default)]
+    pub test_trace: bool,
 }
 
 pub const DEFAULT_CONTEXTUAL: bool = true;
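On the Rust side, test_trace is a new optional field of TopLevelGrammar; with #[serde(default)] it deserializes to false when absent, so existing requests are unaffected. Setting it from Python mirrors what run_g.py now does (the grammar payload here is abbreviated to a stub, not a real grammar):

# Sketch of a request body carrying the new flag; the single grammar entry
# is a placeholder, real requests carry full node lists.
serialized = {"grammars": [{"greedy_lexer": False, "nodes": [], "rx_nodes": []}]}
serialized["max_tokens"] = 250
serialized["test_trace"] = True  # omitted -> false, via #[serde(default)]
llguidance_json = {"grammar": serialized}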