diff --git a/.github/workflows/llama.yml b/.github/workflows/llama.yml
index 7b0b0da..8c88255 100644
--- a/.github/workflows/llama.yml
+++ b/.github/workflows/llama.yml
@@ -327,6 +327,19 @@ jobs:
           default \
           $'<|user|>\nWhat is the capital of Japan?<|end|>\n<|assistant|>'
 
+      - name: JSON Schema
+        run: |
+          test -f ~/.wasmedge/env && source ~/.wasmedge/env
+          cd wasmedge-ggml/json-schema
+          curl -LO https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
+          cargo build --target wasm32-wasi --release
+          time wasmedge --dir .:. \
+            --env n_gpu_layers="$NGL" \
+            --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf \
+            target/wasm32-wasi/release/wasmedge-ggml-json-schema.wasm \
+            default \
+            $'[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always output JSON format string.\n<</SYS>>\nGive me a JSON array of Apple products. [/INST]'
+
       - name: Build llama-stream
         run: |
           cd wasmedge-ggml/llama-stream
diff --git a/wasmedge-ggml/README.md b/wasmedge-ggml/README.md
index 4ed2380..ca8148d 100644
--- a/wasmedge-ggml/README.md
+++ b/wasmedge-ggml/README.md
@@ -151,6 +151,8 @@ Supported parameters include:
 - `threads`: Set the number of threads for the inference, the same as the `--threads` parameter in llama.cpp.
 - `mmproj`: Set the path to the multimodal projector file for llava, the same as the `--mmproj` parameter in llama.cpp.
 - `image`: Set the path to the image file for llava, the same as the `--image` parameter in llama.cpp.
+- `grammar`: Specify a grammar to constrain model output to a specific format, the same as the `--grammar` parameter in llama.cpp.
+- `json-schema`: Specify a JSON schema to constrain model output, the same as the `--json-schema` parameter in llama.cpp.
 
 (For more detailed instructions on usage or default values for the parameters, please refer to [WasmEdge](https://github.com/WasmEdge/WasmEdge/blob/master/plugins/wasi_nn/ggml.cpp).)
 
diff --git a/wasmedge-ggml/json-schema/Cargo.toml b/wasmedge-ggml/json-schema/Cargo.toml
new file mode 100644
index 0000000..da3fc2f
--- /dev/null
+++ b/wasmedge-ggml/json-schema/Cargo.toml
@@ -0,0 +1,8 @@
+[package]
+name = "wasmedge-ggml-json-schema"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+serde_json = "1.0"
+wasmedge-wasi-nn = "0.8.0"
diff --git a/wasmedge-ggml/json-schema/README.md b/wasmedge-ggml/json-schema/README.md
new file mode 100644
index 0000000..d1db910
--- /dev/null
+++ b/wasmedge-ggml/json-schema/README.md
@@ -0,0 +1,61 @@
+# JSON Schema Example For WASI-NN with GGML Backend
+
+> [!NOTE]
+> Please refer to the [wasmedge-ggml/README.md](../README.md) for the general introduction and the setup of the WASI-NN plugin with the GGML backend. This document focuses on the specific example of using a JSON schema with the GGML backend.
+
+## Get the Model
+
+In this example, we are going to use the [llama-2-7b-chat](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF) model, a chat fine-tuned variant of Llama 2.
+
+```bash
+curl -LO https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
+```
+
+## Parameters
+
+> [!NOTE]
+> Please check the parameters section of [wasmedge-ggml/README.md](https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters) first.
+
+In this example, we are going to use the `json-schema` option to constrain the model to generate JSON output in a specific format.
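+
+For reference, the options that this example builds and passes to the WASI-NN plugin (via `GraphBuilder::config`) end up looking roughly like the following; the schema string is shortened here, and the full schema lives in `src/main.rs`:
+
+```json
+{
+  "enable-log": false,
+  "n-gpu-layers": 0,
+  "temp": 0.1,
+  "json-schema": "{ \"items\": { ... }, \"minItems\": 5 }"
+}
+```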
+
+You can check [the documents at llama.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/main#grammars--json-schemas) for more details about this.
+
+## Execute
+
+```console
+$ wasmedge --dir .:. \
+  --env n_predict=99 \
+  --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf \
+  wasmedge-ggml-json-schema.wasm default
+
+USER:
+Give me a JSON array of Apple products.
+ASSISTANT:
+[
+{
+"productId": 1,
+"productName": "iPhone 12 Pro",
+"price": 799.99
+},
+{
+"productId": 2,
+"productName": "iPad Air",
+"price": 599.99
+},
+{
+"productId": 3,
+"productName": "MacBook Air",
+"price": 999.99
+},
+{
+"productId": 4,
+"productName": "Apple Watch Series 7",
+"price": 399.99
+},
+{
+"productId": 5,
+"productName": "AirPods Pro",
+"price": 249.99
+}
+]
+```
diff --git a/wasmedge-ggml/json-schema/src/main.rs b/wasmedge-ggml/json-schema/src/main.rs
new file mode 100644
index 0000000..ff3e272
--- /dev/null
+++ b/wasmedge-ggml/json-schema/src/main.rs
@@ -0,0 +1,202 @@
+use serde_json::json;
+use serde_json::Value;
+use std::env;
+use std::io;
+use wasmedge_wasi_nn::{
+    self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
+    TensorType,
+};
+
+fn read_input() -> String {
+    loop {
+        let mut answer = String::new();
+        io::stdin()
+            .read_line(&mut answer)
+            .expect("Failed to read line");
+        if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
+            return answer.trim().to_string();
+        }
+    }
+}
+
+fn get_options_from_env() -> Value {
+    let mut options = json!({});
+    if let Ok(val) = env::var("enable_log") {
+        options["enable-log"] = serde_json::from_str(val.as_str())
+            .expect("invalid value for enable-log option (true/false)")
+    } else {
+        options["enable-log"] = serde_json::from_str("false").unwrap()
+    }
+    if let Ok(val) = env::var("n_gpu_layers") {
+        options["n-gpu-layers"] =
+            serde_json::from_str(val.as_str()).expect("invalid ngl value (unsigned integer)")
+    } else {
+        options["n-gpu-layers"] = serde_json::from_str("0").unwrap()
+    }
+    if let Ok(val) = env::var("n_predict") {
+        options["n-predict"] =
+            serde_json::from_str(val.as_str()).expect("invalid n-predict value (unsigned integer)")
+    }
+    if let Ok(val) = env::var("json_schema") {
+        options["json-schema"] =
+            serde_json::from_str(val.as_str()).expect("invalid json-schema value (JSON)")
+    }
+
+    options
+}
+
+fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
+    context.set_input(0, TensorType::U8, &[1], &data)
+}
+
+#[allow(dead_code)]
+fn set_metadata_to_context(
+    context: &mut GraphExecutionContext,
+    data: Vec<u8>,
+) -> Result<(), Error> {
+    context.set_input(1, TensorType::U8, &[1], &data)
+}
+
+fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
+    // Reserve room for 4096 tokens, assuming an average token length of 6 bytes.
+    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
+    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
+    let mut output_size = context
+        .get_output(index, &mut output_buffer)
+        .expect("Failed to get output");
+    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
+
+    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
+}
+
+fn get_output_from_context(context: &GraphExecutionContext) -> String {
+    get_data_from_context(context, 0)
+}
+
+fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
+    serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
+}
+
+const JSON_SCHEMA: &str = r#"
+{
+  "items": {
+    "title": "Product",
+    "description": "A product from the catalog",
+    "type": "object",
+    "properties": {
+      "productId": {
+        "description": "The unique identifier for a product",
+        "type": "integer"
+      },
+      "productName": {
+        "description": "Name of the product",
+        "type": "string"
+      },
+      "price": {
+        "description": "The price of the product",
+        "type": "number",
+        "exclusiveMinimum": 0
+      }
+    },
+    "required": [
+      "productId",
+      "productName",
+      "price"
+    ]
+  },
+  "minItems": 5
+}
+"#;
+
+fn main() {
+    let args: Vec<String> = env::args().collect();
+    let model_name: &str = &args[1];
+
+    // Set options for the graph. Check our README for more details:
+    // https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
+    let mut options = get_options_from_env();
+
+    // Add the JSON schema to constrain the output format.
+    // Check [here](https://github.com/ggerganov/llama.cpp/tree/master/grammars) for more details.
+    options["json-schema"] = JSON_SCHEMA.into();
+
+    // Make the output more consistent.
+    options["temp"] = json!(0.1);
+
+    // Create graph and initialize context.
+    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
+        .config(serde_json::to_string(&options).expect("Failed to serialize options"))
+        .build_from_cache(model_name)
+        .expect("Failed to build graph");
+    let mut context = graph
+        .init_execution_context()
+        .expect("Failed to init context");
+
+    // If there is a third argument, use it as the prompt and enter non-interactive mode.
+    // This is mainly for the CI workflow.
+    if args.len() >= 3 {
+        let prompt = &args[2];
+        // Set the prompt.
+        println!("Prompt:\n{}", prompt);
+        let tensor_data = prompt.as_bytes().to_vec();
+        context
+            .set_input(0, TensorType::U8, &[1], &tensor_data)
+            .expect("Failed to set input");
+        println!("Response:");
+
+        // Get the number of input tokens and the llama.cpp version.
+        let input_metadata = get_metadata_from_context(&context);
+        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
+        println!(
+            "[INFO] llama_build_number: {}",
+            input_metadata["llama_build_number"]
+        );
+        println!(
+            "[INFO] Number of input tokens: {}",
+            input_metadata["input_tokens"]
+        );
+
+        // Get the output.
+        context.compute().expect("Failed to compute");
+        let output = get_output_from_context(&context);
+        println!("{}", output.trim());
+
+        // Retrieve the output metadata.
+        let metadata = get_metadata_from_context(&context);
+        println!(
+            "[INFO] Number of input tokens: {}",
+            metadata["input_tokens"]
+        );
+        println!(
+            "[INFO] Number of output tokens: {}",
+            metadata["output_tokens"]
+        );
+        std::process::exit(0);
+    }
+
+    loop {
+        println!("USER:");
+        let input = read_input();
+
+        // Set the prompt as the input tensor.
+        set_data_to_context(&mut context, input.as_bytes().to_vec()).expect("Failed to set input");
+
+        // Execute the inference.
+        match context.compute() {
+            Ok(_) => (),
+            Err(Error::BackendError(BackendError::ContextFull)) => {
+                println!("\n[INFO] Context full, we'll reset the context and continue.");
+            }
+            Err(Error::BackendError(BackendError::PromptTooLong)) => {
+                println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
+            }
+            Err(err) => {
+                println!("\n[ERROR] {}", err);
+            }
+        }
+
+        // Retrieve the output.
+        let output = get_output_from_context(&context);
+        println!("ASSISTANT:\n{}", output.trim());
+    }
+}
diff --git a/wasmedge-ggml/json-schema/wasmedge-ggml-json-schema.wasm b/wasmedge-ggml/json-schema/wasmedge-ggml-json-schema.wasm
new file mode 100755
index 0000000..f32f852
Binary files /dev/null and b/wasmedge-ggml/json-schema/wasmedge-ggml-json-schema.wasm differ
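
The `json-schema` option constrains sampling on the llama.cpp side, so the string returned by `get_output_from_context` should already match the schema above; the example simply prints it. A caller that wants to consume the result programmatically can parse the output with `serde_json`, which is already a dependency of this example. A minimal sketch (the `print_products` helper below is illustrative, not part of this PR), assuming the output follows the product schema defined in `src/main.rs`:

```rust
use serde_json::Value;

/// Parse the constrained output and print one line per product.
/// The schema itself is enforced by the GGML backend during sampling;
/// this is only a consumer-side sanity check.
fn print_products(output: &str) {
    match serde_json::from_str::<Value>(output.trim()) {
        Ok(Value::Array(products)) => {
            for product in &products {
                println!("{} -> {}", product["productName"], product["price"]);
            }
        }
        Ok(other) => println!("Unexpected JSON shape: {}", other),
        Err(err) => println!("Output is not valid JSON: {}", err),
    }
}
```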