[Example] ggml: add Qwen2-VL example
Signed-off-by: dm4 <[email protected]>
Showing 5 changed files with 269 additions and 0 deletions.
Cargo.toml
@@ -0,0 +1,8 @@
[package]
name = "wasmedge-ggml-qwen2vl"
version = "0.1.0"
edition = "2021"

[dependencies]
serde_json = "1.0"
wasmedge-wasi-nn = "0.7.1"
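The manifest above is all that is needed to compile the example to WebAssembly. A minimal build sketch, assuming a Rust toolchain with the `wasm32-wasip1` target installed (the target name is an assumption; older toolchains use `wasm32-wasi` instead):

```bash
# Build the example to a .wasm module (target name is an assumption, not from the commit).
rustup target add wasm32-wasip1
cargo build --target wasm32-wasip1 --release
# The resulting wasmedge-ggml-qwen2vl.wasm ends up under target/wasm32-wasip1/release/.
```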
README.md
@@ -0,0 +1,47 @@
# Qwen2-VL Example For WASI-NN with GGML Backend

> [!NOTE]
> Please refer to [wasmedge-ggml/README.md](../README.md) for the general introduction and the setup of the WASI-NN plugin with the GGML backend. This document focuses on the Qwen2-VL example specifically.

## Get Qwen2-VL Model

In this example, we are going to use the pre-converted [Qwen2-VL-2B](https://huggingface.co/second-state/Qwen2-VL-2B-Instruct-GGUF/tree/main) model.

Download the model:

```bash
curl -LO https://huggingface.co/second-state/Qwen2-VL-2B-Instruct-GGUF/resolve/main/Qwen2-VL-2B-Instruct-vision-encoder.gguf
curl -LO https://huggingface.co/second-state/Qwen2-VL-2B-Instruct-GGUF/resolve/main/Qwen2-VL-2B-Instruct-Q5_K_M.gguf
```

## Prepare the Image

Download the image you want to perform inference on:

```bash
curl -LO https://llava-vl.github.io/static/images/monalisa.jpg
```

## Parameters

> [!NOTE]
> Please check the parameters section of [wasmedge-ggml/README.md](https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters) first.

For Qwen2-VL inference, we recommend using a `ctx-size` of at least `4096` for better results.

```rust
options.insert("ctx-size", Value::from(4096));
```

## Execute

Execute the WASM file with `wasmedge`, using the named model feature to preload the large model:

```bash
wasmedge --dir .:. \
  --nn-preload default:GGML:AUTO:Qwen2-VL-2B-Instruct-Q5_K_M.gguf \
  --env mmproj=Qwen2-VL-2B-Instruct-vision-encoder.gguf \
  --env image=monalisa.jpg \
  --env ctx_size=4096 \
  wasmedge-ggml-qwen2vl.wasm default
```
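
The example also supports a non-interactive mode, intended mainly for the CI workflow: if an extra argument is passed after the model name, `main.rs` uses it as the prompt and exits after printing a single response. A sketch of such an invocation, where the prompt text is only a placeholder:

```bash
# Non-interactive mode: the trailing argument is used as the prompt (placeholder text).
wasmedge --dir .:. \
  --nn-preload default:GGML:AUTO:Qwen2-VL-2B-Instruct-Q5_K_M.gguf \
  --env mmproj=Qwen2-VL-2B-Instruct-vision-encoder.gguf \
  --env image=monalisa.jpg \
  --env ctx_size=4096 \
  wasmedge-ggml-qwen2vl.wasm default 'Describe this image.'
```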
src/main.rs
@@ -0,0 +1,197 @@
use serde_json::Value;
use std::collections::HashMap;
use std::env;
use std::io;
use wasmedge_wasi_nn::{
    self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
    TensorType,
};

fn read_input() -> String {
    loop {
        let mut answer = String::new();
        io::stdin()
            .read_line(&mut answer)
            .expect("Failed to read line");
        if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
            return answer.trim().to_string();
        }
    }
}

fn get_options_from_env() -> HashMap<&'static str, Value> {
    let mut options = HashMap::new();

    // Required parameters for llava
    if let Ok(val) = env::var("mmproj") {
        options.insert("mmproj", Value::from(val.as_str()));
    } else {
        eprintln!("Failed to get mmproj model.");
        std::process::exit(1);
    }
    if let Ok(val) = env::var("image") {
        options.insert("image", Value::from(val.as_str()));
    } else {
        eprintln!("Failed to get the target image.");
        std::process::exit(1);
    }

    // Optional parameters
    if let Ok(val) = env::var("enable_log") {
        options.insert("enable-log", serde_json::from_str(val.as_str()).unwrap());
    } else {
        options.insert("enable-log", Value::from(false));
    }
    if let Ok(val) = env::var("ctx_size") {
        options.insert("ctx-size", serde_json::from_str(val.as_str()).unwrap());
    } else {
        options.insert("ctx-size", Value::from(4096));
    }
    if let Ok(val) = env::var("n_gpu_layers") {
        options.insert("n-gpu-layers", serde_json::from_str(val.as_str()).unwrap());
    } else {
        options.insert("n-gpu-layers", Value::from(0));
    }
    options
}

fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
    context.set_input(0, TensorType::U8, &[1], &data)
}

fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Reserve space for 4096 tokens with an average token length of 6 bytes.
    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
    let mut output_size = context
        .get_output(index, &mut output_buffer)
        .expect("Failed to get output");
    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    String::from_utf8_lossy(&output_buffer[..output_size]).to_string()
}

fn get_output_from_context(context: &GraphExecutionContext) -> String {
    get_data_from_context(context, 0)
}

fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
    serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
}

fn main() {
    let args: Vec<String> = env::args().collect();
    let model_name: &str = &args[1];

    // Set options for the graph. Check our README for more details:
    // https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
    let options = get_options_from_env();
    // You could also set the options manually like this:
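    // (Illustrative sketch, not part of the original commit; the file names below
    // are simply the ones downloaded in the README and serve as examples.)
    //
    //   let mut options = HashMap::new();
    //   options.insert("mmproj", Value::from("Qwen2-VL-2B-Instruct-vision-encoder.gguf"));
    //   options.insert("image", Value::from("monalisa.jpg"));
    //   options.insert("ctx-size", Value::from(4096));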

    // Create graph and initialize context.
    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .config(serde_json::to_string(&options).expect("Failed to serialize options"))
        .build_from_cache(model_name)
        .expect("Failed to build graph");
    let mut context = graph
        .init_execution_context()
        .expect("Failed to init context");

    // If there is a third argument, use it as the prompt and enter non-interactive mode.
    // This is mainly for the CI workflow.
    if args.len() >= 3 {
        let prompt = &args[2];
        // Set the prompt.
        println!("Prompt:\n{}", prompt);
        let tensor_data = prompt.as_bytes().to_vec();
        context
            .set_input(0, TensorType::U8, &[1], &tensor_data)
            .expect("Failed to set input");
        println!("Response:");

        // Get the number of input tokens and llama.cpp versions.
        let input_metadata = get_metadata_from_context(&context);
        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
        println!(
            "[INFO] llama_build_number: {}",
            input_metadata["llama_build_number"]
        );
        println!(
            "[INFO] Number of input tokens: {}",
            input_metadata["input_tokens"]
        );

        // Get the output.
        context.compute().expect("Failed to compute");
        let output = get_output_from_context(&context);
        println!("{}", output.trim());

        // Retrieve the output metadata.
        let metadata = get_metadata_from_context(&context);
        println!(
            "[INFO] Number of input tokens: {}",
            metadata["input_tokens"]
        );
        println!(
            "[INFO] Number of output tokens: {}",
            metadata["output_tokens"]
        );
        std::process::exit(0);
    }

    let mut saved_prompt = String::new();
    let system_prompt = String::from("You are a helpful assistant.");
    let image_placeholder = "<image>";

    loop {
        println!("USER:");
        let input = read_input();

        // Qwen2-VL prompt format:
        // <|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n<|vision_start|>{image_placeholder}<|vision_end|>{user_prompt}<|im_end|>\n<|im_start|>assistant\n
        if saved_prompt.is_empty() {
            saved_prompt = format!(
                "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n<|vision_start|>{}<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
                system_prompt, image_placeholder, input
            );
        } else {
            saved_prompt = format!(
                "{}<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
                saved_prompt, input
            );
        }

        // Set the prompt to the input tensor.
        set_data_to_context(&mut context, saved_prompt.as_bytes().to_vec())
            .expect("Failed to set input");

        // Execute the inference.
        let mut reset_prompt = false;
        match context.compute() {
            Ok(_) => (),
            Err(Error::BackendError(BackendError::ContextFull)) => {
                println!("\n[INFO] Context full, we'll reset the context and continue.");
                reset_prompt = true;
            }
            Err(Error::BackendError(BackendError::PromptTooLong)) => {
                println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
                reset_prompt = true;
            }
            Err(err) => {
                println!("\n[ERROR] {}", err);
                std::process::exit(1);
            }
        }

        // Retrieve the output.
        let mut output = get_output_from_context(&context);
        println!("ASSISTANT:\n{}", output.trim());

        // Update the saved prompt.
        if reset_prompt {
            saved_prompt.clear();
        } else {
            output = output.trim().to_string();
            saved_prompt = format!("{}{}<|im_end|>\n", saved_prompt, output);
        }
    }
}
Binary file not shown.