diff --git a/controllers/guidance_ctrl/src/gctrl.rs b/controllers/guidance_ctrl/src/gctrl.rs
index 05892718..5427712e 100644
--- a/controllers/guidance_ctrl/src/gctrl.rs
+++ b/controllers/guidance_ctrl/src/gctrl.rs
@@ -20,6 +20,8 @@ macro_rules! infoln {
 pub struct Runner {
     tok_parser: TokenParser,
     reported_captures: usize,
+    text_ptr: usize,
+    token_ptr: usize,
 }
 
 #[derive(Serialize, Deserialize)]
@@ -34,13 +36,17 @@ impl Runner {
         let guidance = base64::engine::general_purpose::STANDARD
             .decode(arg.guidance_b64)
             .expect("invalid base64");
+        let tok_parser = TokenParser::from_guidance_protobuf(
+            Box::new(aici_abi::WasmTokenizerEnv::default()),
+            &guidance,
+        )
+        .expect("invalid guidance protobuf");
+        let token_ptr = tok_parser.num_tokens();
         Runner {
-            tok_parser: TokenParser::from_guidance_protobuf(
-                Box::new(aici_abi::WasmTokenizerEnv::default()),
-                &guidance,
-            )
-            .expect("invalid guidance protobuf"),
+            tok_parser,
             reported_captures: 0,
+            text_ptr: 0,
+            token_ptr,
         }
     }
 
@@ -53,9 +59,20 @@ impl Runner {
                 name: name.clone(),
                 str: String::from_utf8_lossy(val).to_string(),
                 hex: to_hex_string(val),
+                log_prob: 0.0, // TODO
             };
             json_out(&cap);
         }
+
+        let new_text = self.tok_parser.bytes_since(self.text_ptr);
+        if new_text.len() > 0 {
+            // TODO log_prob
+            let text =
+                Text::from_bytes(new_text, 0.0, self.tok_parser.num_tokens() - self.token_ptr);
+            json_out(&text);
+            self.text_ptr += new_text.len();
+            self.token_ptr = self.tok_parser.num_tokens();
+        }
     }
 }
 
@@ -69,6 +86,7 @@ struct Capture {
     name: String,
     str: String,
     hex: String,
+    log_prob: f64,
 }
 
 #[derive(Serialize, Deserialize)]
@@ -78,6 +96,27 @@ struct FinalText {
     hex: String,
 }
 
+#[derive(Serialize, Deserialize)]
+struct Text {
+    object: &'static str, // "text"
+    str: String,
+    hex: String,
+    log_prob: f64,
+    num_tokens: usize,
+}
+
+impl Text {
+    pub fn from_bytes(bytes: &[u8], log_prob: f64, num_tokens: usize) -> Self {
+        Text {
+            object: "text",
+            str: String::from_utf8_lossy(bytes).to_string(),
+            hex: to_hex_string(bytes),
+            log_prob,
+            num_tokens,
+        }
+    }
+}
+
 impl FinalText {
     pub fn from_bytes(bytes: &[u8]) -> Self {
         FinalText {
diff --git a/controllers/guidance_ctrl/src/tokenparser.rs b/controllers/guidance_ctrl/src/tokenparser.rs
index ad834d89..7844b60f 100644
--- a/controllers/guidance_ctrl/src/tokenparser.rs
+++ b/controllers/guidance_ctrl/src/tokenparser.rs
@@ -38,10 +38,18 @@ impl TokenParser {
         })
     }
 
+    pub fn num_tokens(&self) -> usize {
+        self.llm_tokens.len()
+    }
+
     pub fn final_bytes(&self) -> &[u8] {
         &self.llm_bytes[self.grm_prefix.len()..]
     }
 
+    pub fn bytes_since(&self, idx: usize) -> &[u8] {
+        &self.llm_bytes[self.grm_prefix.len() + idx..]
+    }
+
     pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
         assert!(self.llm_tokens.is_empty());
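
Note (illustrative, not part of the diff): a minimal sketch of the JSON that json_out would emit for the new incremental "text" objects, assuming it serializes with serde_json. The Text struct and Text::from_bytes are copied from the hunk above; to_hex_string here is a stand-in for the helper already defined in gctrl.rs, and the byte/token values are made up.

use serde::Serialize;

// Stand-in for the to_hex_string helper already present in gctrl.rs.
fn to_hex_string(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{:02x}", b)).collect()
}

// Copied from the diff above (Deserialize dropped to keep the sketch minimal).
#[derive(Serialize)]
struct Text {
    object: &'static str, // "text"
    str: String,
    hex: String,
    log_prob: f64,
    num_tokens: usize,
}

impl Text {
    pub fn from_bytes(bytes: &[u8], log_prob: f64, num_tokens: usize) -> Self {
        Text {
            object: "text",
            str: String::from_utf8_lossy(bytes).to_string(),
            hex: to_hex_string(bytes),
            log_prob,
            num_tokens,
        }
    }
}

fn main() {
    // Illustrative values: 5 new bytes covering 2 new tokens; log_prob is still hard-coded to 0.0.
    let t = Text::from_bytes(b"Hello", 0.0, 2);
    // Prints: {"object":"text","str":"Hello","hex":"48656c6c6f","log_prob":0.0,"num_tokens":2}
    println!("{}", serde_json::to_string(&t).unwrap());
}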