From 56d497af0ca77b07188d047cf9961a422d02cfb9 Mon Sep 17 00:00:00 2001
From: Yiwei Yang <40686366+victoryang00@users.noreply.github.com>
Date: Tue, 24 Dec 2024 19:29:31 +0800
Subject: [PATCH] [Example] Give an example of llama tflite (#165)

* add llama

* add llama 2

* fix encoding

* fix error

* Update README.md

* Update Cargo.toml
---
 openvino-mobilenet-image/rust/Cargo.toml      |   2 +-
 openvino-mobilenet-raw/rust/Cargo.toml        |   2 +-
 .../openvino-road-seg-adas/Cargo.toml         |   2 +-
 tflite-birds_v1-image/rust/Cargo.toml         |   2 +-
 wasmedge-tf-llama/README.md                   |  59 +++++
 wasmedge-tf-llama/rust/Cargo.toml             |  17 ++
 wasmedge-tf-llama/rust/src/main.rs            | 227 ++++++++++++++++++
 7 files changed, 307 insertions(+), 4 deletions(-)
 create mode 100644 wasmedge-tf-llama/README.md
 create mode 100644 wasmedge-tf-llama/rust/Cargo.toml
 create mode 100644 wasmedge-tf-llama/rust/src/main.rs

diff --git a/openvino-mobilenet-image/rust/Cargo.toml b/openvino-mobilenet-image/rust/Cargo.toml
index d09e0a4..c7083fb 100644
--- a/openvino-mobilenet-image/rust/Cargo.toml
+++ b/openvino-mobilenet-image/rust/Cargo.toml
@@ -8,6 +8,6 @@ publish = false
 
 [dependencies]
 image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
-wasi-nn = { version = "0.4.0" }
+wasi-nn = { version = "0.6.0" }
 
 [workspace]
diff --git a/openvino-mobilenet-raw/rust/Cargo.toml b/openvino-mobilenet-raw/rust/Cargo.toml
index 8eab25b..3f00aec 100644
--- a/openvino-mobilenet-raw/rust/Cargo.toml
+++ b/openvino-mobilenet-raw/rust/Cargo.toml
@@ -7,6 +7,6 @@ edition = "2021"
 publish = false
 
 [dependencies]
-wasi-nn = { version = "0.4.0" }
+wasi-nn = { version = "0.6.0" }
 
 [workspace]
diff --git a/openvino-road-segmentation-adas/openvino-road-seg-adas/Cargo.toml b/openvino-road-segmentation-adas/openvino-road-seg-adas/Cargo.toml
index 998f391..93f91e0 100644
--- a/openvino-road-segmentation-adas/openvino-road-seg-adas/Cargo.toml
+++ b/openvino-road-segmentation-adas/openvino-road-seg-adas/Cargo.toml
@@ -5,5 +5,5 @@ name = "openvino-road-seg-adas"
 version = "0.2.0"
 
 [dependencies]
-wasi-nn = "0.4.0"
+wasi-nn = "0.6.0"
 image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
diff --git a/tflite-birds_v1-image/rust/Cargo.toml b/tflite-birds_v1-image/rust/Cargo.toml
index 572ecb9..9e89e87 100644
--- a/tflite-birds_v1-image/rust/Cargo.toml
+++ b/tflite-birds_v1-image/rust/Cargo.toml
@@ -8,6 +8,6 @@ publish = false
 
 [dependencies]
 image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
-wasi-nn = "0.4.0"
+wasi-nn = "0.6.0"
 
 [workspace]
diff --git a/wasmedge-tf-llama/README.md b/wasmedge-tf-llama/README.md
new file mode 100644
index 0000000..6ca9619
--- /dev/null
+++ b/wasmedge-tf-llama/README.md
@@ -0,0 +1,59 @@
+# Llama Example for the WasmEdge-TensorFlow Plug-in
+
+This package is a high-level Rust example of running a Llama TFLite model with the [WasmEdge-TensorFlow plug-in](https://wasmedge.org/docs/develop/rust/tensorflow).
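+
+The chat loop expects a Llama model that has already been exported to TensorFlow Lite (see the Run section below). As a rough sketch only, where the repository id is a placeholder for wherever your pre-converted model lives, such a file can be fetched from HuggingFace with:
+
+```bash
+pip install -U "huggingface_hub[cli]"
+huggingface-cli download <your-hf-repo-with-a-llama-tflite> llama_1b_q8_ekv1280.tflite --local-dir .
+```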
+
+## Dependencies
+
+This crate depends on the `wasmedge_tensorflow_interface` crate, declared in `Cargo.toml`:
+
+```toml
+[dependencies]
+wasmedge_tensorflow_interface = "0.3.0"
+wasi-nn = "0.1.0" # pinned: this example uses the 0.1.0-style wasi-nn API
+thiserror = "1.0"
+bytemuck = "1.13.1"
+log = "0.4.19"
+env_logger = "0.10.0"
+anyhow = "1.0.79"
+```
+
+## Build
+
+Compile the application to WebAssembly:
+
+```bash
+cd rust && cargo build --target=wasm32-wasip1 --release
+```
+
+The output WASM file will be at `rust/target/wasm32-wasip1/release/wasmedge-tf-example-llama.wasm`.
+To speed up inference, we can enable the AOT mode in WasmEdge with:
+
+```bash
+wasmedge compile rust/target/wasm32-wasip1/release/wasmedge-tf-example-llama.wasm wasmedge-tf-example-llama_aot.wasm
+```
+
+## Run
+
+The frozen `tflite` model should be exported from a HuggingFace checkpoint with `ai_edge_torch`, or downloaded pre-converted as sketched at the top of this README.
+
+### Execute
+
+Users should [install WasmEdge with the WasmEdge-TensorFlow and WasmEdge-Image plug-ins](https://wasmedge.org/docs/start/install#wasmedge-tensorflow-plug-in):
+
+```bash
+curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- --plugins wasmedge_tensorflow wasmedge_image
+```
+
+Execute the WASM with `wasmedge` (with TensorFlow Lite support):
+
+```bash
+wasmedge --dir .:. wasmedge-tf-example-llama.wasm ./llama_1b_q8_ekv1280.tflite
+```
+
+You will see output similar to:
+
+```console
+Input the Chatbot:
+Hello world
+Hello world!
+```
diff --git a/wasmedge-tf-llama/rust/Cargo.toml b/wasmedge-tf-llama/rust/Cargo.toml
new file mode 100644
index 0000000..229ec0e
--- /dev/null
+++ b/wasmedge-tf-llama/rust/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "wasmedge-tf-example-llama"
+version = "0.1.0"
+authors = ["victoryang00"]
+readme = "README.md"
+edition = "2021"
+publish = false
+
+[dependencies]
+wasmedge_tensorflow_interface = "0.3.0"
+wasi-nn = "0.1.0" # pinned: this example uses the 0.1.0-style wasi-nn API
+thiserror = "1.0"
+bytemuck = "1.13.1"
+log = "0.4.19"
+env_logger = "0.10.0"
+anyhow = "1.0.79"
+[workspace]
diff --git a/wasmedge-tf-llama/rust/src/main.rs b/wasmedge-tf-llama/rust/src/main.rs
new file mode 100644
index 0000000..8e855af
--- /dev/null
+++ b/wasmedge-tf-llama/rust/src/main.rs
@@ -0,0 +1,227 @@
+#![feature(str_as_str)]
+use bytemuck::{cast_slice, cast_slice_mut};
+use std::collections::HashMap;
+use std::env;
+use std::fs::File;
+use std::io::Read;
+use std::io::{self, BufRead, Write};
+use std::process;
+use thiserror::Error;
+use wasi_nn;
+
+// Define a custom error type to handle multiple error sources
+#[derive(Error, Debug)]
+pub enum ChatbotError {
+    #[error("IO Error: {0}")]
+    Io(#[from] std::io::Error),
+
+    #[error("WASI NN Error: {0}")]
+    WasiNn(String),
+
+    #[error("Other Error: {0}")]
+    Other(String),
+}
+
+impl From<wasi_nn::Error> for ChatbotError {
+    fn from(error: wasi_nn::Error) -> Self {
+        ChatbotError::WasiNn(error.to_string())
+    }
+}
+
+type Result<T> = std::result::Result<T, ChatbotError>;
+
+// Tokenizer struct
+struct Tokenizer {
+    vocab: HashMap<String, i32>,
+    vocab_reverse: HashMap<i32, String>,
+    next_id: i32,
+}
+
+impl Tokenizer {
+    fn new(initial_vocab: Vec<(&str, i32)>) -> Self {
+        let mut vocab = HashMap::new();
+        let mut vocab_reverse = HashMap::new();
+        let mut next_id = 1;
+
+        for (word, id) in initial_vocab {
+            vocab.insert(word.to_string(), id);
+            vocab_reverse.insert(id, word.to_string());
+            if id >= next_id {
+                next_id = id + 1;
+            }
+        }
+
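+        // NOTE: <unk>/<pad> below are illustrative placeholders for this toy
+        // whitespace tokenizer; a real Llama TFLite export expects ids from the
+        // model's own SentencePiece/BPE vocabulary, so this demo mapping will not
+        // match the model's actual token ids.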
+        // Add special tokens
+        vocab.insert("<unk>".to_string(), 0);
+        vocab_reverse.insert(0, "<unk>".to_string());
+        vocab.insert("<pad>".to_string(), -1);
+        vocab_reverse.insert(-1, "<pad>".to_string());
+
+        Tokenizer {
+            vocab,
+            vocab_reverse,
+            next_id,
+        }
+    }
+
+    fn tokenize(&mut self, input: &str) -> Vec<i32> {
+        input
+            .split_whitespace()
+            .map(|word| {
+                self.vocab.get(word).cloned().unwrap_or_else(|| {
+                    let id = self.next_id;
+                    self.vocab.insert(word.to_string(), id);
+                    self.vocab_reverse.insert(id, word.to_string());
+                    self.next_id += 1;
+                    id
+                })
+            })
+            .collect()
+    }
+
+    fn tokenize_with_fixed_length(&mut self, input: &str, max_length: usize) -> Vec<i32> {
+        let mut tokens = self.tokenize(input);
+
+        if tokens.len() > max_length {
+            tokens.truncate(max_length);
+        } else if tokens.len() < max_length {
+            tokens.extend(vec![-1; max_length - tokens.len()]); // Assuming -1 is the <pad> token
+        }
+
+        tokens
+    }
+
+    fn detokenize(&self, tokens: &[i32]) -> String {
+        tokens
+            .iter()
+            .map(|&token| self.vocab_reverse.get(&token).map_or("", |v| v).as_str())
+            .collect::<Vec<&str>>()
+            .join(" ")
+    }
+}
+
+// Function to load the TFLite model
+fn load_model(model_path: &str) -> Result<wasi_nn::Graph> {
+    let mut file = File::open(model_path)?;
+    let mut buffer = Vec::new();
+    file.read_to_end(&mut buffer)?;
+    let model_segments = &[&buffer[..]];
+    // Encoding 4 selects the TensorFlow Lite backend in WasmEdge's WASI-NN numbering.
+    Ok(unsafe { wasi_nn::load(model_segments, 4, wasi_nn::EXECUTION_TARGET_CPU)? })
+}
+
+// Function to initialize the execution context
+fn init_context(graph: wasi_nn::Graph) -> Result<wasi_nn::GraphExecutionContext> {
+    Ok(unsafe { wasi_nn::init_execution_context(graph)? })
+}
+
+fn main() -> Result<()> {
+    // Parse command-line arguments
+    let args: Vec<String> = env::args().collect();
+    if args.len() < 2 {
+        eprintln!("Usage: {} <model_path>", args[0]);
+        process::exit(1);
+    }
+    let model_file = &args[1];
+
+    // Load the model
+    let graph = load_model(model_file)?;
+
+    // Initialize execution context
+    let ctx = init_context(graph)?;
+
+    let mut stdout = io::stdout();
+    let stdin = io::stdin();
+
+    // Start the interactive chat loop
+    println!("Chatbot is ready! Type your messages below:");
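+
+    // Assumption: the exported Llama TFLite graph takes three inputs (token ids at
+    // index 0, positions at index 1, an external KV cache at index 2) and produces
+    // token ids at output 0. Adjust the indices, shapes, and dtypes below to match
+    // the signature of the model you actually export.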
+
+    for line in stdin.lock().lines() {
+        let user_input = line?;
+        if user_input.trim().is_empty() {
+            continue;
+        }
+        if user_input.to_lowercase() == "exit" {
+            break;
+        }
+        // Initialize tokenizer
+        let initial_vocab = vec![
+            ("hello", 1),
+            ("world", 2),
+            ("this", 3),
+            ("is", 4),
+            ("a", 5),
+            ("test", 6),
+            ("<pad>", -1),
+        ];
+
+        let mut tokenizer = Tokenizer::new(initial_vocab);
+        // let user_input = "hello world this is a test with more words";
+
+        // Tokenize with fixed length
+        let max_length = 655360;
+        let tokens = tokenizer.tokenize_with_fixed_length(&user_input, max_length);
+        let tokens_dims = &[1u32, max_length as u32];
+        let tokens_tensor = wasi_nn::Tensor {
+            dimensions: tokens_dims,
+            r#type: wasi_nn::TENSOR_TYPE_I32,
+            data: cast_slice(&tokens),
+        };
+
+        // Create input_pos tensor
+        let input_pos: Vec<i32> = (0..max_length as i32).collect();
+        let input_pos_dims = &[1u32, max_length as u32];
+        let input_pos_tensor = wasi_nn::Tensor {
+            dimensions: input_pos_dims,
+            r#type: wasi_nn::TENSOR_TYPE_I32,
+            data: cast_slice(&input_pos),
+        };
+
+        // Create kv tensor (ensure kv_data has the correct size)
+        let kv_data = vec![0.0_f32; max_length]; // Example initialization
+        let kv_dims = &[32u32, 2u32, 1u32, 16u32, 10u32, 64u32];
+        let kv_tensor = wasi_nn::Tensor {
+            dimensions: kv_dims,
+            r#type: wasi_nn::TENSOR_TYPE_F32,
+            data: cast_slice(&kv_data),
+        };
+
+        // Set inputs
+        unsafe {
+            wasi_nn::set_input(ctx, 0, tokens_tensor)?;
+            wasi_nn::set_input(ctx, 1, input_pos_tensor)?;
+            wasi_nn::set_input(ctx, 2, kv_tensor)?;
+        }
+
+        // Run inference
+        run_inference(&ctx)?;
+        // Get output
+        let output = get_model_output(&ctx, 0, 655360)?;
+
+        // Detokenize output
+        let response = tokenizer.detokenize(&output);
+
+        // Display response
+        writeln!(stdout, "Bot: {}", response)?;
+    }
+
+    println!("Chatbot session ended.");
+    Ok(())
+}
+
+// Function to run inference
+fn run_inference(ctx: &wasi_nn::GraphExecutionContext) -> Result<()> {
+    unsafe { Ok(wasi_nn::compute(*ctx)?) }
+}
+
+// Function to get model output
+fn get_model_output(
+    ctx: &wasi_nn::GraphExecutionContext,
+    index: u32,
+    size: usize,
+) -> Result<Vec<i32>> {
+    let mut buffer = vec![0i32; size];
+    let buffer_ptr = cast_slice_mut(&mut buffer).as_mut_ptr();
+    let byte_len = (size * std::mem::size_of::<i32>()) as u32;
+    unsafe { wasi_nn::get_output(*ctx, index, buffer_ptr, byte_len)? };
+    Ok(buffer)
+}