[Example] Add Qwen example (#137)
Signed-off-by: jokemanfire <[email protected]>
jokemanfire authored May 10, 2024
1 parent f0762dd commit feee750
Showing 3 changed files with 168 additions and 0 deletions.
10 changes: 10 additions & 0 deletions wasmedge-ggml/qwen/Cargo.toml
@@ -0,0 +1,10 @@
[package]
name = "wasmedge-ggml-qwen"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
serde_json = "1.0"
wasmedge-wasi-nn = "0.7.1"
35 changes: 35 additions & 0 deletions wasmedge-ggml/qwen/README.md
@@ -0,0 +1,35 @@
# `通义千问`

## Execute - Tong Yi Qwen

### Model Download Link

```console
wget https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat-GGUF/resolve/main/qwen1_5-0_5b-chat-q2_k.gguf
```
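
### Build the Wasm Module

If you do not already have `wasmedge-ggml-qwen.wasm`, a typical way to build it from this crate (assuming the Rust `wasm32-wasi` target is installed) is:

```console
cargo build --target wasm32-wasi --release
cp target/wasm32-wasi/release/wasmedge-ggml-qwen.wasm .
```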

### Execution Command

Make sure the `qwen1_5-0_5b-chat-q2_k.gguf` file is in the current directory.
To enable GPU support, set the `n_gpu_layers` environment variable.
You can also enlarge the context window by setting `ctx_size`, e.g. `--env ctx_size=8192`; the default is 1024. A combined example follows the transcript below.

```console
$ wasmedge --dir .:. \
--env n_gpu_layers=10 \
--nn-preload default:GGML:AUTO:qwen1_5-0_5b-chat-q2_k.gguf \
wasmedge-ggml-qwen.wasm default

USER:
你好
ASSISTANT:
你好!有什么我能帮你的吗?
USER:
你是谁
ASSISTANT:
我是一个人工智能助手,我叫通义千问。有什么我可以帮助你的吗?
USER:
能帮助我写Rust代码吗?
ASSISTANT:
当然可以!我可以帮助你使用Rust语言编写代码,帮助你理解和编写出高质量的代码。我可以帮你编写函数、函数和类,提供使用Python解释器编译和运行代码的建议,提供使用Node.js、Python或Java的代码示例,以及更多关于如何使用Python、Java或C++的代码示例。
```
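
To combine GPU offloading with a larger context window, both environment variables can be set at once. A sketch of the same command with a bigger context (adjust the values for your hardware and model):

```console
$ wasmedge --dir .:. \
  --env n_gpu_layers=10 \
  --env ctx_size=8192 \
  --nn-preload default:GGML:AUTO:qwen1_5-0_5b-chat-q2_k.gguf \
  wasmedge-ggml-qwen.wasm default
```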
123 changes: 123 additions & 0 deletions wasmedge-ggml/qwen/src/main.rs
@@ -0,0 +1,123 @@
use serde_json::json;
use serde_json::Value;
use std::env;
use std::io;
use wasmedge_wasi_nn::{
self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
TensorType,
};

fn read_input() -> String {
loop {
let mut answer = String::new();
io::stdin()
.read_line(&mut answer)
.expect("Failed to read line");
if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
return answer.trim().to_string();
}
}
}

fn get_options_from_env() -> Value {
let mut options = json!({});
if let Ok(val) = env::var("enable_log") {
options["enable-log"] = serde_json::from_str(val.as_str())
.expect("invalid value for enable-log option (true/false)")
} else {
options["enable-log"] = serde_json::from_str("false").unwrap()
}
if let Ok(val) = env::var("n_gpu_layers") {
options["n-gpu-layers"] =
            serde_json::from_str(val.as_str()).expect("invalid n_gpu_layers value (unsigned integer)")
} else {
options["n-gpu-layers"] = serde_json::from_str("0").unwrap()
}
options["ctx-size"] = serde_json::from_str("1024").unwrap();

options
}

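/// Pass the prompt to the model by writing it into input tensor 0 as raw
/// UTF-8 bytes; a single-element dimension is used because the GGML backend
/// in these examples treats the whole byte buffer as one text input.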
fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
context.set_input(0, TensorType::U8, &[1], &data)
}

fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Reserve space for 4096 tokens, assuming an average token length of 6 bytes.
const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
let mut output_size = context
.get_output(index, &mut output_buffer)
.expect("Failed to get output");
output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    String::from_utf8(output_buffer[..output_size].to_vec())
        .expect("Failed to decode the model output as UTF-8")
}

fn get_output_from_context(context: &GraphExecutionContext) -> String {
get_data_from_context(context, 0)
}

fn main() {
let args: Vec<String> = env::args().collect();
let model_name: &str = &args[1];

// Set options for the graph. Check our README for more details:
// https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
let options = get_options_from_env();

// Create graph and initialize context.
let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
.config(serde_json::to_string(&options).expect("Failed to serialize options"))
.build_from_cache(model_name)
.expect("Failed to build graph");
let mut context = graph
.init_execution_context()
.expect("Failed to init context");

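    // Qwen chat models use the ChatML prompt template:
    //   <|im_start|>{role}\n{content}<|im_end|>\n
    // The conversation so far is accumulated in `saved_prompt`, and an
    // `<|im_start|>assistant\n` header is appended right before each inference.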
let mut saved_prompt =
String::from("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n");

loop {
println!("USER:");
let input = read_input();

saved_prompt = format!("{}\n<|im_start|>user\n{}<|im_end|>\n", saved_prompt, input);
        let context_prompt = saved_prompt.clone() + "<|im_start|>assistant\n";

// Set prompt to the input tensor.
        set_data_to_context(&mut context, context_prompt.as_bytes().to_vec())
.expect("Failed to set input");

// Execute the inference.
let mut reset_prompt = false;
match context.compute() {
Ok(_) => (),
Err(Error::BackendError(BackendError::ContextFull)) => {
println!("\n[INFO] Context full, we'll reset the context and continue.");
reset_prompt = true;
}
Err(Error::BackendError(BackendError::PromptTooLong)) => {
println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
reset_prompt = true;
}
Err(err) => {
println!("\n[ERROR] {}", err);
}
}

// Retrieve the output.
let mut output = get_output_from_context(&context);
println!("ASSISTANT:\n{}", output.trim());

// Update the saved prompt.
if reset_prompt {
saved_prompt.clear();
} else {
output = output.trim().to_string();
            saved_prompt = format!("{}<|im_start|>assistant\n{}<|im_end|>\n", saved_prompt, output);
}
}
}
