Skip to content

Commit

Permalink
[Example] ggml: show the number of input tokens before compute()
Browse files Browse the repository at this point in the history
Signed-off-by: dm4 <[email protected]>
  • Loading branch information
dm4 committed Dec 25, 2023
1 parent f0f4910 commit cfb35c3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions wasmedge-ggml-llama-interactive/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ fn main() {
let args: Vec<String> = env::args().collect();
let model_name: &str = &args[1];

// Preserve for 4096 tokens with average token length 6
const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;

let mut options = json!({});
match env::var("enable_log") {
Ok(val) => options["enable-log"] = serde_json::from_str(val.as_str()).unwrap(),
Expand Down Expand Up @@ -99,10 +102,9 @@ fn main() {
.unwrap();
println!("Response:");
context.compute().unwrap();
let max_output_size = 4096 * 6;
let mut output_buffer = vec![0u8; max_output_size];
let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
let mut output_size = context.get_output(0, &mut output_buffer).unwrap();
output_size = std::cmp::min(max_output_size, output_size);
output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
let output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
println!("{}", output.trim());
} else {
Expand All @@ -121,6 +123,20 @@ fn main() {
.set_input(0, wasi_nn::TensorType::U8, &[1], &tensor_data)
.unwrap();

// Get the number of input tokens.
let mut input_metadata_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
let mut input_metadata_size =
context.get_output(1, &mut input_metadata_buffer).unwrap();
input_metadata_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, input_metadata_size);
let input_metadata_str =
String::from_utf8_lossy(&input_metadata_buffer[..input_metadata_size]).to_string();
let input_metadata: Value = serde_json::from_str(&input_metadata_str).unwrap();
if let Some(enable_log) = options["enable-log"].as_bool() {
if enable_log {
println!("Number of input tokens: {}", input_metadata["input_tokens"]);
}
}

println!("Answer:");

let mut output = String::new();
Expand All @@ -139,10 +155,9 @@ fn main() {
}
}
// Retrieve the output.
let max_output_size = 4096 * 6;
let mut output_buffer = vec![0u8; max_output_size];
let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
let mut output_size = context.get_output_single(0, &mut output_buffer).unwrap();
output_size = std::cmp::min(max_output_size, output_size);
output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
let token = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
print!("{}", token);
io::stdout().flush().unwrap();
Expand All @@ -154,10 +169,9 @@ fn main() {
context.compute().unwrap();

// Retrieve the output.
let max_output_size = 4096 * 6;
let mut output_buffer = vec![0u8; max_output_size];
let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
let mut output_size = context.get_output(0, &mut output_buffer).unwrap();
output_size = std::cmp::min(max_output_size, output_size);
output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
if let Some(is_stream) = options["stream-stdout"].as_bool() {
if !is_stream {
Expand All @@ -172,10 +186,9 @@ fn main() {
saved_prompt = format!("{} {} ", saved_prompt, output.trim());

// Retrieve the output metadata.
let max_metadata_size = 4096 * 6;
let mut metadata_buffer = vec![0u8; max_metadata_size];
let mut metadata_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
let mut metadata_size = context.get_output(1, &mut metadata_buffer).unwrap();
metadata_size = std::cmp::min(max_metadata_size, metadata_size);
metadata_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, metadata_size);
let metadata_str =
String::from_utf8_lossy(&metadata_buffer[..metadata_size]).to_string();
let metadata: Value = serde_json::from_str(&metadata_str).unwrap();
Expand Down
Binary file not shown.

0 comments on commit cfb35c3

Please sign in to comment.