finish nn context
victoryang00 committed Dec 22, 2024
1 parent 3f04350 commit 2500852
Showing 4 changed files with 130 additions and 82 deletions.
1 change: 1 addition & 0 deletions include/wamr_interp_frame.h
@@ -23,6 +23,7 @@ struct WAMRInterpFrame {
/* Instruction pointer of the bytecode array. */
uint32 ip{};
int32 function_index{};
std::string function_name{};
/* Operand stack top pointer of the current frame. The bottom of
the stack is the next cell after the last local variable. */
uint32 sp{};
26 changes: 6 additions & 20 deletions include/wamr_wasi_nn_context.h
@@ -25,31 +25,17 @@
/* Maximum number of graph execution contexts per WASM instance */
#define MAX_GRAPH_EXEC_CONTEXTS_PER_INST 10

struct WAMRWASINNGraph {
std::vector<uint8_t> buffer;
};

struct WAMRWASINNInterpreter {
// std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<uint8_t> interpreter;
uint32_t a; // placeholder
};

struct WAMRWASINNModel {
// std::unique_ptr<tflite::FlatBufferModel> model;
std::vector<uint8_t> model;
execution_target target;
std::string model_name;
std::vector<uint8_t> input_tensor;
std::vector<uint32_t> dims;
};

struct WAMRWASINNContext {
bool is_initialized;
graph_encoding current_encoding;
std::vector<WAMRWASINNGraph> graph; // TODO: support multiple graph
uint32_t current_models;
bool is_initialized = false;
graph_encoding current_encoding = graph_encoding::tensorflow;
uint32_t current_models = 0;
WAMRWASINNModel models[MAX_GRAPHS_PER_INST];
uint32_t current_interpreters;
WAMRWASINNInterpreter interpreters[MAX_GRAPH_EXEC_CONTEXTS_PER_INST];

void dump_impl(WASINNContext *env);
void restore_impl(WASINNContext *env);
};
83 changes: 56 additions & 27 deletions src/wamr_interp_frame.cpp
@@ -18,41 +18,69 @@
#include <memory>
extern WAMRInstance *wamr;
void WAMRInterpFrame::dump_impl(WASMInterpFrame *env) {
SPDLOG_ERROR("not impl");
exit(-1);
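// Persist position-independent state only: ip as an offset from the start of
// the function's bytecode, sp as an offset from the bottom of the WASM stack,
// and the function's field name so the frame can be re-resolved on restore.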
if (env->function) {
wamr->set_func(env->function->u.func);

if (env->ip)
ip = env->ip - env->function->u.func->code; // here we need to get the
// offset from the code start.
#if WASM_ENABLE_FAST_INTERP == 0
if (env->sp) {
sp = reinterpret_cast<uint8 *>(env->sp) -
((uint8 *)wamr->get_exec_env()->wasm_stack.s.bottom); // offset to the wasm_stack_top
}
#endif
if (env->function->u.func->field_name)
function_name = env->function->u.func->field_name;
else
function_name = env->function->u.func_import->field_name;
}
}
std::vector<std::unique_ptr<WAMRBranchBlock>> wasm_replay_csp_bytecode(WASMExecEnv *exec_env, WASMInterpFrame *frame,
const uint8 *target_addr);
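// Rebuild a live interpreter frame from the serialized state: resolve the
// function (by index when valid, otherwise by name), then re-base ip and sp
// against the newly allocated interpreter stack.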
void WAMRInterpFrame::restore_impl(WASMInterpFrame *env) {
auto module_inst = (WASMModuleInstance *)wamr->get_exec_env()->module_inst;
if (0 <= function_index && function_index < module_inst->e->function_count) {
// LOGV(INFO) << fmt::format("function_index {} restored", function_index);
if (0 < function_index && function_index < module_inst->e->function_count) {
// LOGV(INFO) << fmt::format("function_index {} restored",
// function_index);
env->function = &module_inst->e->functions[function_index];
if (env->function->is_import_func) {
LOG_DEBUG("is_import_func");
exit(-1);
}
} else {
// LOGV(ERROR) << fmt::format("function_index {} invalid", function_index);
exit(-1);
}

if (env->function->is_import_func) {
SPDLOG_ERROR("is_import_func");
exit(-1);
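// The saved function_index may not be valid for the restored module
// instance, so fall back to matching the saved function_name against both
// defined and imported functions.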
auto target_module = wamr->get_module_instance()->e;
for (uint32 i = 0; i < target_module->function_count; i++) {
auto cur_func = &target_module->functions[i];
if (!cur_func->is_import_func) {
if (!strcmp(cur_func->u.func->field_name, function_name.c_str())) {
function_index = i;
env->function = cur_func;
break;
}
} else {
if (!strcmp(cur_func->u.func_import->field_name, function_name.c_str())) {
function_index = i;
env->function = cur_func;
break;
}
}
}
}

wamr->set_func(env->function->u.func);
auto cur_func = env->function;
WASMFunction *cur_wasm_func = cur_func->u.func;

SPDLOG_INFO("ip_offset {} sp_offset {}, code start {}", ip, sp, (void *)wasm_get_func_code(env->function));
LOG_DEBUG("ip_offset %d sp_offset %d, code start %p", ip, sp, (void *)wasm_get_func_code(env->function));
env->ip = wasm_get_func_code(env->function) + ip;

memcpy(env->lp, stack_frame.data(), stack_frame.size() * sizeof(uint32));
#if WASM_ENABLE_FAST_INTERP == 0
env->sp_bottom = env->lp + cur_func->param_cell_num + cur_func->local_cell_num;
env->sp = env->lp + sp;
env->sp_boundary = env->sp_bottom + cur_wasm_func->max_stack_cell_num;

memcpy(env->lp, stack_frame.data(), stack_frame.size() * sizeof(uint32));

// print_csps(csp);
SPDLOG_DEBUG("wasm_replay_csp_bytecode {} {} {}", (void *)wamr->get_exec_env(), (void *)env, (void *)env->ip);
LOG_DEBUG("wasm_replay_csp_bytecode %d %d %d", (void *)wamr->get_exec_env(), (void *)env, (void *)env->ip);
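// Branch blocks (csp) live directly above the operand-stack boundary; replay
// the blocks captured at checkpoint time into that region.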
env->csp_bottom = (WASMBranchBlock *)env->sp_boundary;

if (env->function->u.func && !env->function->is_import_func && env->sp_bottom) {
@@ -62,15 +90,16 @@ void WAMRInterpFrame::restore_impl(WASMInterpFrame *env) {
int i = 0;
for (auto &&csp_item : csp) {
restore(csp_item.get(), env->csp_bottom + i);
SPDLOG_ERROR("csp_bottom {}", ((uint8 *)env->csp_bottom + i) - wamr->get_exec_env()->wasm_stack.s.bottom);
LOG_DEBUG("csp_bottom %d", ((uint8 *)env->csp_bottom + i) - wamr->get_exec_env()->wasm_stack.s.bottom);
i++;
}

env->csp = env->csp_bottom + csp.size();
env->csp_boundary = env->csp_bottom + env->function->u.func->max_block_num;
}
SPDLOG_INFO("func_idx {} ip {} sp {} stack bottom {}", function_index, (void *)env->ip, (void *)env->sp,
(void *)wamr->get_exec_env()->wasm_stack.s.bottom);
#endif
LOG_DEBUG("func_idx %d ip %p sp %p stack bottom %p", function_index, (void *)env->ip, (void *)env->sp,
(void *)wamr->get_exec_env()->wasm_stack.s.bottom);
}

#if WASM_ENABLE_AOT != 0
@@ -79,8 +108,8 @@ void WAMRInterpFrame::dump_impl(AOTFrame *env) {
ip = env->ip_offset;
sp = env->sp - env->lp; // offset to the wasm_stack_top

SPDLOG_INFO("function_index {} ip_offset {} lp {} sp {} sp_offset {}", env->func_index, ip, (void *)env->lp,
(void *)env->sp, sp);
LOG_DEBUG("function_index %d ip_offset %d lp %p sp %p sp_offset %lu", env->func_index, ip, (void *)env->lp,
(void *)env->sp, sp);

stack_frame = std::vector(env->lp, env->sp);

Expand All @@ -90,7 +119,7 @@ void WAMRInterpFrame::dump_impl(AOTFrame *env) {
std::cout << std::endl;
}
void WAMRInterpFrame::restore_impl(AOTFrame *env) {
SPDLOG_ERROR("not impl");
LOG_DEBUG("not impl");
exit(-1);
}

Expand Down Expand Up @@ -239,7 +268,7 @@ std::vector<std::unique_ptr<WAMRBranchBlock>> wasm_replay_csp_bytecode(WASMExecE
auto e = std::make_unique<WAMRBranchBlock>(); \
e->cell_num = cell_num; \
e->begin_addr = frame_ip - cur_func->u.func->code; \
e->target_addr = (_target_addr)-cur_func->u.func->code; \
e->target_addr = (_target_addr) - cur_func->u.func->code; \
e->frame_sp = reinterpret_cast<uint8 *>(frame_sp - (param_cell_num)) - exec_env->wasm_stack.s.bottom; \
csp.emplace_back(std::move(e)); \
}
@@ -840,15 +869,15 @@ std::vector<std::unique_ptr<WAMRBranchBlock>> wasm_replay_csp_bytecode(WASMExecE
case SIMD_v128_load64_splat:
case SIMD_v128_store:
/* memarg align */
skip_leb_uint32(frame_ip,frame_ip_end);
skip_leb_uint32(frame_ip, frame_ip_end);
/* memarg offset */
skip_leb_uint32(frame_ip,frame_ip_end);
skip_leb_uint32(frame_ip, frame_ip_end);
break;

case SIMD_v128_const:
case SIMD_v8x16_shuffle:
/* immByte[16] immLaneId[16] */
CHECK_BUF1(frame_ip,frame_ip_end, 16);
CHECK_BUF1(frame_ip, frame_ip_end, 16);
frame_ip += 16;
break;

102 changes: 67 additions & 35 deletions src/wamr_wasi_nn_context.cpp
@@ -13,7 +13,9 @@
#if WASM_ENABLE_WASI_NN != 0
#include "wamr_wasi_nn_context.h"
#include "wamr.h"
#include <wasi-nn.h>
#include "wasi_nn.h"
#include <cstdio>
#include <cstdlib>

#define MAX_MODEL_SIZE 85000000
#define MAX_OUTPUT_TENSOR_SIZE 1000000
@@ -23,50 +25,80 @@ extern WAMRInstance *wamr;
void WAMRWASINNContext::dump_impl(WASINNContext *env) {
fprintf(stderr, "dump_impl\n");
// take the model
// get target
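// Checkpointing backend state is not implemented yet; restore_impl below
// re-runs graph initialization from the recorded models instead.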
}
void WAMRWASINNContext::restore_impl(WASINNContext *env) {
fprintf(stderr, "restore_impl\n");
// replay the graph initialization
wasm_get_output(env->wasm_instance, env->wasm_instance->output_index, env->wasm_instance->output_size);
FILE *pFile = fopen(model_name, "r");
if (pFile == NULL)
return invalid_argument;
// replay the graph initialization
if (!is_initialized) {
for (const auto &model : models) {
FILE *pFile = fopen(model.model_name.c_str(), "r");
if (pFile == nullptr)
return;

uint8_t *buffer;
size_t result;
uint8_t *buffer;
size_t result;

// allocate memory to contain the whole file:
buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE);
if (buffer == NULL) {
fclose(pFile);
return missing_memory;
}
// allocate memory to contain the whole file:
buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE);
if (buffer == nullptr) {
fclose(pFile);
return;
}

result = fread(buffer, 1, MAX_MODEL_SIZE, pFile);
if (result <= 0) {
fclose(pFile);
free(buffer);
return missing_memory;
}
result = fread(buffer, 1, MAX_MODEL_SIZE, pFile);
if (result <= 0) {
fclose(pFile);
free(buffer);
return;
}

graph_builder_array arr;
graph_builder_array arr;

arr.size = 1;
arr.buf = (graph_builder *)malloc(sizeof(graph_builder));
if (arr.buf == NULL) {
fclose(pFile);
free(buffer);
return missing_memory;
}
arr.size = 1;
arr.buf = (graph_builder *)malloc(sizeof(graph_builder));
if (arr.buf == nullptr) {
fclose(pFile);
free(buffer);
return;
}

arr.buf[0].size = result;
arr.buf[0].buf = buffer;
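// Prefer accelerators when re-loading the graph: try GPU first, then TPU,
// and fall back to CPU.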
graph g;
error res = load(&arr, tensorflowlite, execution_target::gpu, &g);
if (res != error::success) {
res = load(&arr, tensorflowlite, execution_target::tpu, &g);
}
if (res != error::success) {
res = load(&arr, tensorflowlite, execution_target::cpu, &g);
}

arr.buf[0].size = result;
arr.buf[0].buf = buffer;
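// Recreate an execution context for the freshly loaded graph.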
graph_execution_context ctx;
if (init_execution_context(g, &ctx) != success) {
return;
}
tensor_dimensions dims;
dims.size = INPUT_TENSOR_DIMS;
dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t));
if (dims.buf == NULL)
return;

error res = load(&arr, tensorflowlite, target, g);

fclose(pFile);
free(buffer);
free(arr.buf);
return res;
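// Rebuild the input tensor from the recorded dims and bytes and re-submit it
// so the context matches its state at checkpoint time.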
tensor tensor;
tensor.dimensions = &dims;
for (int i = 0; i < tensor.dimensions->size; ++i)
tensor.dimensions->buf[i] = model.dims[i];
tensor.type = fp32;
tensor.data = (uint8_t *)model.input_tensor.data();
error err = set_input(ctx, 0, &tensor);

free(dims.buf);

fclose(pFile);
free(buffer);
free(arr.buf);
}
}
}
#endif
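
For reference, the restore path above replays the standard guest-side wasi-nn call sequence: load, init_execution_context, set_input, and eventually compute/get_output. Below is a minimal sketch against WAMR's wasi_nn.h — error handling is elided, and model_bytes, model_size, input_bytes, output_bytes, output_size, and the 1x224x224x3 fp32 shape are illustrative assumptions, not part of this commit:

    // Sketch only: the model and I/O buffers are assumed to exist already;
    // a real caller would check every returned error code.
    graph_builder builder;
    builder.buf = model_bytes;
    builder.size = model_size;

    graph_builder_array arr;
    arr.buf = &builder;
    arr.size = 1;

    // Build the graph from the serialized TFLite model.
    graph g;
    if (load(&arr, tensorflowlite, execution_target::cpu, &g) != success)
        return;

    // One execution context per inference session.
    graph_execution_context ctx;
    if (init_execution_context(g, &ctx) != success)
        return;

    // Describe and bind the input tensor (shape is illustrative).
    uint32_t dim_buf[4] = {1, 224, 224, 3};
    tensor_dimensions dims;
    dims.buf = dim_buf;
    dims.size = 4;

    tensor input;
    input.dimensions = &dims;
    input.type = fp32;
    input.data = input_bytes;
    if (set_input(ctx, 0, &input) != success)
        return;

    // Run inference and fetch the first output.
    compute(ctx);
    get_output(ctx, 0, output_bytes, &output_size);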
