From 6fa14ca699b8fe3524796428570b7dd9579d01e3 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 24 Jun 2024 23:05:35 +0000 Subject: [PATCH] fix eos handling --- controllers/llguidance_ctrl/run_g.py | 6 +++++- controllers/llguidance_ctrl/src/earley/parser.rs | 11 +++++++++-- controllers/llguidance_ctrl/src/tokenparser.rs | 8 ++++++-- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/controllers/llguidance_ctrl/run_g.py b/controllers/llguidance_ctrl/run_g.py index c4483098..f3322304 100644 --- a/controllers/llguidance_ctrl/run_g.py +++ b/controllers/llguidance_ctrl/run_g.py @@ -263,7 +263,7 @@ def character_maker2(lm, id, description, valid_weapons): prompt = "" grm = guidance.json(schema={"type": "null"}) - assert grm.match("null") + # assert grm.match("null") grm = guidance.json( "OBJ", @@ -280,6 +280,10 @@ def character_maker2(lm, id, description, valid_weapons): # g = zero_or_more("a") + "b" # assert not g.match("b") + # lm = guidance.models.Mock(b"1234233234") + # grammar = one_or_more(select(["1", "2"])) + # lm += grammar + max_tokens = 250 serialized = grm.ll_serialize() diff --git a/controllers/llguidance_ctrl/src/earley/parser.rs b/controllers/llguidance_ctrl/src/earley/parser.rs index f0b6ad28..4a36af74 100644 --- a/controllers/llguidance_ctrl/src/earley/parser.rs +++ b/controllers/llguidance_ctrl/src/earley/parser.rs @@ -894,8 +894,15 @@ impl Parser { debug!(" flush_lexer() OK"); - if lexer_eos { - return true; + if mv == ModelVariable::eos_token() { + if lexer_eos { + return true; + } + // This is really for EOS tokens in the middle of the grammar + // that need to be eaten; so don't check for accepting state here + // if self.is_accepting() { + // return true; + // } } self.scratch.new_row(self.curr_row().last_item); diff --git a/controllers/llguidance_ctrl/src/tokenparser.rs b/controllers/llguidance_ctrl/src/tokenparser.rs index db8b13b1..44acd21e 100644 --- a/controllers/llguidance_ctrl/src/tokenparser.rs +++ b/controllers/llguidance_ctrl/src/tokenparser.rs @@ -202,6 +202,8 @@ impl TokenParser { trie.tokens_dbg(&arg.tokens) ); + let mut has_eos = false; + if arg.tokens.contains(&trie.eos_token()) { assert!(arg.tokens.len() == 1); if self.parser.scan_model_variable(ModelVariable::eos_token()) { @@ -209,7 +211,9 @@ impl TokenParser { infoln!(self, "scanned eos_token"); arg.tokens.clear(); } else { + infoln!(self, "didn't scan eos_token; saving"); arg.save_tokens(&mut self.llm_tokens); + has_eos = true; } } else { arg.save_tokens(&mut self.llm_tokens); @@ -345,10 +349,10 @@ impl TokenParser { let no_pending_bytes = !self.parser.has_pending_lexeme_bytes(); let is_accepting = no_pending_bytes && row_accepting; let can_advance = self.parser.can_advance(); - let inner_done = empty_token_prefix && is_accepting && !can_advance; + let inner_done = empty_token_prefix && is_accepting && (!can_advance || has_eos); infoln!( self, - "inner_done: {inner_done}; can_advance: {can_advance}; \ + "inner_done: {inner_done}; can_advance: {can_advance} (eos:{has_eos}); \ accept: {is_accepting} (row:{row_accepting} & lexer:{no_pending_bytes}); \ empty_token_prefix: {empty_token_prefix}" );