[WASI-NN] Add metadata support (WasmEdge#2957)
- Use `setInput()` with index = 1 for llama.cpp options
- Encode options with JSON string

Signed-off-by: dm4 <[email protected]>
dm4 authored Oct 20, 2023
1 parent 8b4a00d commit cea4cbd
Showing 3 changed files with 139 additions and 67 deletions.
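For reference, the metadata tensor passed at index 1 carries a single JSON object. The keys and their types below are taken from the new setInput() path in plugins/wasi_nn/ggml.cpp; the values shown are illustrative, not defaults defined by this commit:

  {
    "enable-log": false,
    "stream-stdout": true,
    "ctx-size": 512,
    "n-predict": 128,
    "n-gpu-layers": 0
  }

Keys that are omitted keep the defaults seeded in initExecCtx(), so a guest only needs to encode the options it wants to override.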
59 changes: 58 additions & 1 deletion plugins/wasi_nn/CMakeLists.txt
@@ -28,6 +28,63 @@ else()
  set(LLAMA_METAL OFF)
endif()

# simdjson for ggml backend
find_package(simdjson QUIET)
if(simdjson_FOUND)
  message(STATUS "SIMDJSON found")
else()
  message(STATUS "Downloading SIMDJSON source")
  include(FetchContent)
  FetchContent_Declare(
    simdjson
    GIT_REPOSITORY https://github.com/simdjson/simdjson.git
    GIT_TAG tags/v3.2.1
    GIT_SHALLOW TRUE)

  if(MSVC)
    if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
      get_property(
        compile_options
        DIRECTORY
        PROPERTY COMPILE_OPTIONS
      )
      set_property(
        DIRECTORY
        APPEND
        PROPERTY COMPILE_OPTIONS
        -Wno-undef
        -Wno-suggest-override
        -Wno-documentation
        -Wno-sign-conversion
        -Wno-extra-semi-stmt
        -Wno-old-style-cast
        -Wno-error=unused-parameter
        -Wno-error=unused-template
        -Wno-conditional-uninitialized
        -Wno-implicit-int-conversion
        -Wno-shorten-64-to-32
        -Wno-range-loop-bind-reference
        -Wno-format-nonliteral
        -Wno-unused-exception-parameter
        -Wno-unused-member-function
      )
      unset(compile_options)
    elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
      set_property(
        DIRECTORY
        APPEND
        PROPERTY COMPILE_OPTIONS
        /wd4100 # unreferenced formal parameter
      )
    endif()
  endif()

  FetchContent_MakeAvailable(simdjson)
  set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON)

message(STATUS "Downloading SIMDJSON source -- done")
endif()

add_subdirectory(thirdparty)

wasmedge_add_library(wasmedgePluginWasiNN
@@ -69,7 +126,7 @@ endif()

string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND)
if(BACKEND STREQUAL "ggml")
  target_link_libraries(wasmedgePluginWasiNN PRIVATE llama)
  target_link_libraries(wasmedgePluginWasiNN PRIVATE llama simdjson)
endif()

include(WASINNDeps)
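As a usage note, the ggml backend (and with it the simdjson dependency added above) is selected through the WASMEDGE_PLUGIN_WASI_NN_BACKEND cache variable checked at the bottom of this file. A typical configure line is sketched below; the generator and build-type flags are illustrative, only the backend variable is consumed by this file:

  cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_PLUGIN_WASI_NN_BACKEND=GGML .

Since string(TOLOWER ...) normalizes the value, GGML and ggml are both accepted.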
142 changes: 76 additions & 66 deletions plugins/wasi_nn/ggml.cpp
@@ -5,38 +5,14 @@
#include "wasinnenv.h"

#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML
#include "simdjson.h"
#include <common.h>
#include <cstdlib>
#include <llama.h>
#endif

namespace WasmEdge::Host::WASINN::GGML {
#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML
ErrNo wasmedge_llama_context_params(llama_context_params &Params) noexcept {
  const char *LlamaNContextEnv = std::getenv("LLAMA_N_CTX");
  const char *LlamaLogEnv = std::getenv("LLAMA_LOG");
  if (LlamaNContextEnv != nullptr) {
    try {
      Params.n_ctx = std::stoi(LlamaNContextEnv);
    } catch (const std::out_of_range &e) {
      spdlog::error(
          "[WASI-NN] GGML backend: set n_ctx failed: out_of_range {}"sv,
          e.what());
      return ErrNo::InvalidArgument;
    } catch (const std::invalid_argument &e) {
      spdlog::error(
          "[WASI-NN] GGML backend: set n_ctx failed: invalid_argument {}"sv,
          e.what());
      return ErrNo::InvalidArgument;
    }
    if (LlamaLogEnv != nullptr) {
      spdlog::info("[WASI-NN] GGML backend: set n_ctx to {}"sv, Params.n_ctx);
    }
  }

  return ErrNo::Success;
}

Expect<ErrNo> load(WasiNNEnvironment &Env, Span<const Span<uint8_t>> Builders,
                   [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept {
  // The graph builder length must be 1.
@@ -77,12 +53,6 @@ Expect<ErrNo> load(WasiNNEnvironment &Env, Span<const Span<uint8_t>> Builders,
  // Initialize ggml model.
  gpt_params Params;
  llama_backend_init(Params.numa);
  llama_context_params ContextParams = llama_context_default_params();
  ErrNo Err = wasmedge_llama_context_params(ContextParams);
  if (Err != ErrNo::Success) {
    spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv);
    return ErrNo::InvalidArgument;
  }
  llama_model_params ModelParams = llama_model_default_params();
  GraphRef.LlamaModel =
      llama_load_model_from_file(ModelFilePath.c_str(), ModelParams);
@@ -104,24 +74,86 @@ Expect<ErrNo> load(WasiNNEnvironment &Env, Span<const Span<uint8_t>> Builders,
Expect<ErrNo> initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId,
                          uint32_t &ContextId) noexcept {
  Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]);

  ContextId = Env.NNContext.size() - 1;

  // Set the default context options.
  auto &CxtRef = Env.NNContext[ContextId].get<Context>();
  auto ContextDefault = llama_context_default_params();
  CxtRef.EnableLog = false;
  CxtRef.StreamStdout = false;
  CxtRef.CtxSize = ContextDefault.n_ctx;
  CxtRef.NPredict = ContextDefault.n_ctx;
  CxtRef.NGPULayers = 0;

  return ErrNo::Success;
}

Expect<ErrNo> setInput(WasiNNEnvironment &Env, uint32_t ContextId,
                       [[maybe_unused]] uint32_t Index,
                       const TensorData &Tensor) noexcept {
                       uint32_t Index, const TensorData &Tensor) noexcept {
  auto &CxtRef = Env.NNContext[ContextId].get<Context>();
  auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get<Graph>();

  // Use index 1 for metadata.
  if (Index == 1) {
    // Decode metadata.
    std::string Metadata(reinterpret_cast<char *>(Tensor.Tensor.data()),
                         Tensor.Tensor.size());
    simdjson::dom::parser Parser;
    simdjson::dom::element Doc;
    auto ParseError = Parser.parse(Metadata).get(Doc);
    if (ParseError) {
      spdlog::error("[WASI-NN] GGML backend: Parse metadata error"sv);
      return ErrNo::InvalidEncoding;
    }

    // Get metadata from the json.
    if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) {
      auto Err = Doc["enable-log"].get<bool>().get(CxtRef.EnableLog);
      if (Err) {
        spdlog::error(
            "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv);
        return ErrNo::InvalidArgument;
      }
    }
    if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) {
      auto Err = Doc["stream-stdout"].get<bool>().get(CxtRef.StreamStdout);
      if (Err) {
        spdlog::error(
            "[WASI-NN] GGML backend: Unable to retrieve the stream-stdout option."sv);
        return ErrNo::InvalidArgument;
      }
    }
    if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) {
      auto Err = Doc["ctx-size"].get<uint64_t>().get(CxtRef.CtxSize);
      if (Err) {
        spdlog::error(
            "[WASI-NN] GGML backend: Unable to retrieve the ctx-size option."sv);
        return ErrNo::InvalidArgument;
      }
    }
    if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) {
      auto Err = Doc["n-predict"].get<uint64_t>().get(CxtRef.NPredict);
      if (Err) {
        spdlog::error(
            "[WASI-NN] GGML backend: Unable to retrieve the n-predict option."sv);
        return ErrNo::InvalidArgument;
      }
    }
    if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) {
      auto Err = Doc["n-gpu-layers"].get<uint64_t>().get(CxtRef.NGPULayers);
      if (Err) {
        spdlog::error(
            "[WASI-NN] GGML backend: Unable to retrieve the n-gpu-layers option."sv);
        return ErrNo::InvalidArgument;
      }
    }

    return ErrNo::Success;
  }

  // Initialize the llama context.
  llama_context_params ContextParams = llama_context_default_params();
  ErrNo Err = wasmedge_llama_context_params(ContextParams);
  if (Err != ErrNo::Success) {
    spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv);
    return ErrNo::InvalidArgument;
  }
  ContextParams.n_ctx = CxtRef.CtxSize;
  GraphRef.LlamaContext =
      llama_new_context_with_model(GraphRef.LlamaModel, ContextParams);

@@ -160,9 +192,7 @@ Expect<ErrNo> compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept {
    return ErrNo::InvalidArgument;
  }

  // Use env LLAMA_LOG=1 to enable llama log.
  const char *LlamaLogEnv = std::getenv("LLAMA_LOG");
  if (LlamaLogEnv != nullptr) {
  if (CxtRef.EnableLog) {
    spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv,
                 llama_print_system_info());
  }
@@ -176,26 +206,7 @@ Expect<ErrNo> compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept {
  const int MaxContextSize = llama_n_ctx(GraphRef.LlamaContext);
  // NPredict is the number of tokens to predict. Same as -n, --n-predict in
  // llama.cpp.
  int NPredict = MaxContextSize;
  const char *LlamaNPredictEnv = std::getenv("LLAMA_N_PREDICT");
  if (LlamaNPredictEnv != nullptr) {
    try {
      NPredict = std::stoi(LlamaNPredictEnv);
    } catch (const std::out_of_range &e) {
      spdlog::error(
          "[WASI-NN] GGML backend: set n_predict failed: out_of_range {}"sv,
          e.what());
      return ErrNo::InvalidArgument;
    } catch (const std::invalid_argument &e) {
      spdlog::error(
          "[WASI-NN] GGML backend: set n_predict failed: invalid_argument {}"sv,
          e.what());
      return ErrNo::InvalidArgument;
    }
    if (LlamaLogEnv != nullptr) {
      spdlog::info("[WASI-NN] GGML backend: set n_predict to {}"sv, NPredict);
    }
  }
  int NPredict = CxtRef.NPredict;

  // Evaluate the initial prompt.
  llama_batch LlamaBatch = llama_batch_init(NPredict, 0);
@@ -242,9 +253,8 @@ Expect<ErrNo> compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept {
    std::string NextToken =
        llama_token_to_piece(GraphRef.LlamaContext, NewTokenId);

    // When setting STREAM_TO_STDOUT, we print the output to stdout.
    const char *StreamOutput = std::getenv("STREAM_TO_STDOUT");
    if (StreamOutput != nullptr) {
    // When setting StreamStdout, we print the output to stdout.
    if (CxtRef.StreamStdout) {
      std::cout << NextToken << std::flush;
    }

@@ -269,7 +279,7 @@ Expect<ErrNo> compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept {
    }
  }

  if (LlamaLogEnv != nullptr) {
  if (CxtRef.EnableLog) {
    llama_print_timings(GraphRef.LlamaContext);
  }

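To exercise the decode path above in isolation, here is a minimal standalone sketch that mirrors the simdjson dom calls setInput() now uses. The main() wrapper and the sample document are illustrative, not code from this commit; it assumes the simdjson v3.x API fetched in CMakeLists.txt:

  // Sketch: decode llama.cpp options from a JSON string the way setInput()
  // does for input index 1. Absent keys keep their defaults.
  #include "simdjson.h"
  #include <cstdint>
  #include <iostream>
  #include <string>

  int main() {
    const std::string Metadata = R"({"stream-stdout": true, "ctx-size": 512})";

    simdjson::dom::parser Parser;
    simdjson::dom::element Doc;
    if (Parser.parse(Metadata).get(Doc)) {
      std::cerr << "Parse metadata error\n";
      return 1;
    }

    // Defaults mirror initExecCtx(); each key is optional.
    bool StreamStdout = false;
    uint64_t CtxSize = 512;
    if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS &&
        Doc["stream-stdout"].get<bool>().get(StreamStdout)) {
      return 1; // Present but not a bool.
    }
    if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS &&
        Doc["ctx-size"].get<uint64_t>().get(CtxSize)) {
      return 1; // Present but not an unsigned integer.
    }

    std::cout << "stream-stdout=" << StreamStdout << " ctx-size=" << CtxSize
              << '\n';
    return 0;
  }

Note that get(...) returns a nonzero error_code on a type mismatch, which is why the plugin maps that case to ErrNo::InvalidArgument.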
5 changes: 5 additions & 0 deletions plugins/wasi_nn/ggml.h
@@ -28,6 +28,11 @@ struct Context {
  size_t GraphId;
  std::vector<llama_token> LlamaInputs;
  std::string LlamaOutputs;
  bool EnableLog;
  bool StreamStdout;
  uint64_t CtxSize;
  uint64_t NPredict;
  uint64_t NGPULayers;
};
#else
struct Graph {};
