[Example] Give an example of llama tflite (#165)
* add llama

* add llama 2

* fix encoding

* fix error

* Update README.md

* Update Cargo.toml
victoryang00 authored Dec 24, 2024
1 parent d75c738 commit 56d497a
Showing 7 changed files with 307 additions and 4 deletions.
2 changes: 1 addition & 1 deletion openvino-mobilenet-image/rust/Cargo.toml
@@ -8,6 +8,6 @@ publish = false

[dependencies]
image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
wasi-nn = { version = "0.4.0" }
wasi-nn = { version = "0.6.0" }

[workspace]
2 changes: 1 addition & 1 deletion openvino-mobilenet-raw/rust/Cargo.toml
@@ -7,6 +7,6 @@ edition = "2021"
publish = false

[dependencies]
wasi-nn = { version = "0.4.0" }
wasi-nn = { version = "0.6.0" }

[workspace]
@@ -5,5 +5,5 @@ name = "openvino-road-seg-adas"
version = "0.2.0"

[dependencies]
wasi-nn = "0.4.0"
wasi-nn = "0.6.0"
image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
2 changes: 1 addition & 1 deletion tflite-birds_v1-image/rust/Cargo.toml
@@ -8,6 +8,6 @@ publish = false

[dependencies]
image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
wasi-nn = "0.4.0"
wasi-nn = "0.6.0"

[workspace]
59 changes: 59 additions & 0 deletions wasmedge-tf-llama/README.md
@@ -0,0 +1,59 @@
# Llama Example For WasmEdge-TensorFlow Plug-in

This package is a high-level Rust example for the [WasmEdge-TensorFlow plug-in](https://wasmedge.org/docs/develop/rust/tensorflow), running a Llama model in TensorFlow Lite format.

## Dependencies

This crate declares the following dependencies, including `wasmedge_tensorflow_interface`, in its `Cargo.toml`:

```toml
[dependencies]
wasmedge_tensorflow_interface = "0.3.0"
wasi-nn = "0.1.0" # Ensure you use the latest version
thiserror = "1.0"
bytemuck = "1.13.1"
log = "0.4.19"
env_logger = "0.10.0"
anyhow = "1.0.79"
```

## Build

Compile the application to WebAssembly:

```bash
cd rust && cargo build --target=wasm32-wasip1 --release
```

The output WASM file will be at `rust/target/wasm32-wasip1/release/wasmedge-tf-example-llama.wasm`.
To speed up inference, we can enable the AOT mode in WasmEdge with:

```bash
wasmedge compile rust/target/wasm32-wasip1/release/wasmedge-tf-example-llama.wasm wasmedge-tf-example-llama_aot.wasm
```

## Run

The frozen `.tflite` model should be exported with `ai_edge_torch` from a Hugging Face checkpoint.
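
Below is a minimal sketch of that conversion step, assuming `ai_edge_torch` and PyTorch are installed. The `TinyModel` module, input shape, and output path are placeholders; the actual Llama export (the int8 quantization and external KV cache implied by the `q8_ekv1280` suffix) is typically produced by the LLM conversion tooling shipped with `ai_edge_torch`, using weights from a Hugging Face checkpoint.

```python
# Illustrative only: the basic ai_edge_torch flow that turns a PyTorch module
# plus sample inputs into a .tflite flatbuffer loadable by the plug-in.
import ai_edge_torch
import torch


class TinyModel(torch.nn.Module):  # placeholder standing in for the Llama model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


sample_inputs = (torch.randn(1, 16),)  # example inputs used for tracing
edge_model = ai_edge_torch.convert(TinyModel().eval(), sample_inputs)
edge_model.export("model.tflite")      # write the TFLite flatbuffer
```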

### Execute

Users should [install WasmEdge with the WasmEdge-TensorFlow and WasmEdge-Image plug-ins](https://wasmedge.org/docs/start/install#wasmedge-tensorflow-plug-in).

```bash
curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- --plugins wasmedge_tensorflow wasmedge_image
```

Execute the WASM with `wasmedge` with TensorFlow Lite support:

```bash
wasmedge --dir .:. wasmedge-tf-example-llama.wasm ./llama_1b_q8_ekv1280.tflite
```

You will get the output:

```console
Input the Chatbot:
Hello world
Hello world!
```
17 changes: 17 additions & 0 deletions wasmedge-tf-llama/rust/Cargo.toml
@@ -0,0 +1,17 @@
[package]
name = "wasmedge-tf-example-llama"
version = "0.1.0"
authors = ["victoryang00"]
readme = "README.md"
edition = "2021"
publish = false

[dependencies]
wasmedge_tensorflow_interface = "0.3.0"
wasi-nn = "0.1.0" # Ensure you use the latest version
thiserror = "1.0"
bytemuck = "1.13.1"
log = "0.4.19"
env_logger = "0.10.0"
anyhow = "1.0.79"
[workspace]
227 changes: 227 additions & 0 deletions wasmedge-tf-llama/rust/src/main.rs
@@ -0,0 +1,227 @@
use bytemuck::{cast_slice, cast_slice_mut};
use std::collections::HashMap;
use std::env;
use std::fs::File;
use std::io::Read;
use std::io::{self, BufRead, Write};
use std::process;
use thiserror::Error;
use wasi_nn;

// Define a custom error type to handle multiple error sources
#[derive(Error, Debug)]
pub enum ChatbotError {
    #[error("IO Error: {0}")]
    Io(#[from] std::io::Error),

    #[error("WASI NN Error: {0}")]
    WasiNn(String),

    #[error("Other Error: {0}")]
    Other(String),
}

impl From<wasi_nn::Error> for ChatbotError {
    fn from(error: wasi_nn::Error) -> Self {
        ChatbotError::WasiNn(error.to_string())
    }
}

type Result<T> = std::result::Result<T, ChatbotError>;

// Tokenizer Struct
struct Tokenizer {
    vocab: HashMap<String, i32>,
    vocab_reverse: HashMap<i32, String>,
    next_id: i32,
}

impl Tokenizer {
    fn new(initial_vocab: Vec<(&str, i32)>) -> Self {
        let mut vocab = HashMap::new();
        let mut vocab_reverse = HashMap::new();
        let mut next_id = 1;

        for (word, id) in initial_vocab {
            vocab.insert(word.to_string(), id);
            vocab_reverse.insert(id, word.to_string());
            if id >= next_id {
                next_id = id + 1;
            }
        }

        // Add special tokens
        vocab.insert("<UNK>".to_string(), 0);
        vocab_reverse.insert(0, "<UNK>".to_string());
        vocab.insert("<PAD>".to_string(), -1);
        vocab_reverse.insert(-1, "<PAD>".to_string());

        Tokenizer {
            vocab,
            vocab_reverse,
            next_id,
        }
    }

    fn tokenize(&mut self, input: &str) -> Vec<i32> {
        input
            .split_whitespace()
            .map(|word| {
                self.vocab.get(word).cloned().unwrap_or_else(|| {
                    let id = self.next_id;
                    self.vocab.insert(word.to_string(), id);
                    self.vocab_reverse.insert(id, word.to_string());
                    self.next_id += 1;
                    id
                })
            })
            .collect()
    }

    fn tokenize_with_fixed_length(&mut self, input: &str, max_length: usize) -> Vec<i32> {
        let mut tokens = self.tokenize(input);

        if tokens.len() > max_length {
            tokens.truncate(max_length);
        } else if tokens.len() < max_length {
            tokens.extend(vec![-1; max_length - tokens.len()]); // Assuming -1 is the <PAD> token
        }

        tokens
    }

    fn detokenize(&self, tokens: &[i32]) -> String {
        tokens
            .iter()
            .map(|&token| self.vocab_reverse.get(&token).map_or("<UNK>", String::as_str))
            .collect::<Vec<&str>>()
            .join(" ")
    }
}

// Function to load the TFLite model
fn load_model(model_path: &str) -> Result<wasi_nn::Graph> {
    let mut file = File::open(model_path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;
    let model_segments = &[&buffer[..]];
    // Graph encoding 4 selects TensorFlow Lite in the WASI-NN encoding enum.
    Ok(unsafe { wasi_nn::load(model_segments, 4, wasi_nn::EXECUTION_TARGET_CPU)? })
}

// Function to initialize the execution context
fn init_context(graph: wasi_nn::Graph) -> Result<wasi_nn::GraphExecutionContext> {
    Ok(unsafe { wasi_nn::init_execution_context(graph)? })
}

fn main() -> Result<()> {
    // Parse command-line arguments
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("Usage: {} <model_file>", args[0]);
        process::exit(1);
    }
    let model_file = &args[1];

    // Load the model
    let graph = load_model(model_file)?;

    // Initialize execution context
    let ctx = init_context(graph)?;

    let mut stdout = io::stdout();
    let stdin = io::stdin();

    // Initialize KV cache data
    println!("Chatbot is ready! Type your messages below:");

    for line in stdin.lock().lines() {
        let user_input = line?;
        if user_input.trim().is_empty() {
            continue;
        }
        if user_input.to_lowercase() == "exit" {
            break;
        }
        // Initialize tokenizer
        let initial_vocab = vec![
            ("hello", 1),
            ("world", 2),
            ("this", 3),
            ("is", 4),
            ("a", 5),
            ("test", 6),
            ("<PAD>", -1),
        ];

        let mut tokenizer = Tokenizer::new(initial_vocab);
        // let user_input = "hello world this is a test with more words";

        // Tokenize with fixed length
        let max_length = 655360;
        let tokens = tokenizer.tokenize_with_fixed_length(&user_input, max_length);
        let tokens_dims = &[1u32, max_length as u32];
        let tokens_tensor = wasi_nn::Tensor {
            dimensions: tokens_dims,
            r#type: wasi_nn::TENSOR_TYPE_I32,
            data: cast_slice(&tokens),
        };

        // Create input_pos tensor
        let input_pos: Vec<i32> = (0..max_length as i32).collect();
        let input_pos_dims = &[1u32, max_length as u32];
        let input_pos_tensor = wasi_nn::Tensor {
            dimensions: input_pos_dims,
            r#type: wasi_nn::TENSOR_TYPE_I32,
            data: cast_slice(&input_pos),
        };

        // Create kv tensor (ensure kv_data has the correct size)
        let kv_data = vec![0.0_f32; max_length]; // Example initialization
        let kv_dims = &[32u32, 2u32, 1u32, 16u32, 10u32, 64u32];
        let kv_tensor = wasi_nn::Tensor {
            dimensions: kv_dims,
            r#type: wasi_nn::TENSOR_TYPE_F32,
            data: cast_slice(&kv_data),
        };

        // Set inputs
        unsafe {
            wasi_nn::set_input(ctx, 0, tokens_tensor)?;
            wasi_nn::set_input(ctx, 1, input_pos_tensor)?;
            wasi_nn::set_input(ctx, 2, kv_tensor)?;
        }

        // Run inference
        run_inference(&ctx)?;
        // Get output
        let output = get_model_output(&ctx, 0, 655360)?;

        // Detokenize output
        let response = tokenizer.detokenize(&output.as_slice());

        // Display response
        writeln!(stdout, "Bot: {}", response)?;
    }

    println!("Chatbot session ended.");
    Ok(())
}

// Function to run inference
fn run_inference(ctx: &wasi_nn::GraphExecutionContext) -> Result<()> {
    unsafe { Ok(wasi_nn::compute(*ctx)?) }
}

// Function to get model output
fn get_model_output(
    ctx: &wasi_nn::GraphExecutionContext,
    index: u32,
    size: usize,
) -> Result<Vec<i32>> {
    let mut buffer = vec![0i32; size];
    let buffer_ptr = cast_slice_mut(&mut buffer).as_mut_ptr();
    let byte_len = (size * std::mem::size_of::<i32>()) as u32;
    unsafe { wasi_nn::get_output(*ctx, index, buffer_ptr, byte_len)? };
    Ok(buffer)
}
