From 383cad6698c59124a21261c15bd3444dff1d3b9b Mon Sep 17 00:00:00 2001 From: Muayyad alsadi Date: Mon, 3 Apr 2023 01:16:45 +0300 Subject: [PATCH] FIXES #137: make calling wasm from python 7x faster --- wasmtime/_func.py | 49 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/wasmtime/_func.py b/wasmtime/_func.py index a3ceb871..1a43d0b7 100644 --- a/wasmtime/_func.py +++ b/wasmtime/_func.py @@ -7,7 +7,8 @@ from typing import Callable, Optional, Generic, TypeVar, List, Union, Tuple, cast as cast_type, Sequence from ._exportable import AsExtern from ._store import Storelike -from ._bindings import wasmtime_val_raw_t, wasm_valtype_kind +from ._bindings import wasmtime_val_raw_t, wasm_valtype_kind, wasmtime_val_t, wasmtime_externref_t, wasmtime_func_t +from ._value import _unintern from ._ffi import ( WASMTIME_I32, WASMTIME_I64, @@ -18,6 +19,7 @@ WASMTIME_EXTERNREF, WASM_ANYREF, WASM_FUNCREF, + wasmtime_externref_data, ) @@ -40,11 +42,41 @@ def get_valtype_attr(ty: ValType): return val_id2attr[wasm_valtype_kind(ty._ptr)] +from struct import Struct + +def val_getter(store_id, val_raw, attr): + val = getattr(val_raw, attr) + + if attr=='externref': + ptr = ctypes.POINTER(wasmtime_externref_t) + if not val: return None + ffi = ptr.from_address(val) + if not ffi: return + extern_id = wasmtime_externref_data(ffi) + ret = _unintern(extern_id) + return ret + elif attr=='funcref': + if val==0: return None + f=wasmtime_func_t() + f.store_id=store_id + f.index=val + ret=Func._from_raw(f) + return ret + return val + def val_setter(dst, attr, val): if attr=='externref': - # TODO: handle None - v = Val.externref(val) - casted = ctypes.addressof(v._raw.of.externref) + if isinstance(val, Val) and val._raw.kind==WASMTIME_EXTERNREF.value: + if val._raw.of.externref: + extern_id = wasmtime_externref_data(val._raw.of.externref) + casted = ctypes.addressof(val._raw.of.externref) + else: + v = Val.externref(val) + casted = ctypes.addressof(v._raw.of.externref) + elif attr=='funcref': + if isinstance(val, Val) and val._raw.kind==WASMTIME_FUNCREF.value: + casted = val._raw.of.funcref.index + else: raise RuntimeError("foo") elif isinstance(val, Func): # TODO: handle null_funcref # TODO: validate same val._func.store_id @@ -112,9 +144,10 @@ def _extract_return(self, vals_raw: ctypes.Array[wasmtime_val_raw_t]) -> Union[I if self._results_n==0: return None if self._results_n==1: - return getattr(vals_raw[0], self._results_str0) + ret = val_getter(self._func.store_id, vals_raw[0], self._results_str0) + return ret # we can use tuple construct, but I'm using list for compatability - return [getattr(val_raw, ret_str) for val_raw, ret_str in zip(vals_raw, self._results_str)] + return [val_getter(self._func.store_id, val_raw, ret_str) for val_raw, ret_str in zip(vals_raw, self._results_str)] def _init_call(self, ty: FuncType): """init signature properties used by call""" @@ -123,8 +156,8 @@ def _init_call(self, ty: FuncType): ty_results = ty.results params_n = len(ty_params) results_n = len(ty_results) - self._params_str = (get_valtype_attr(i) for i in ty_params) - self._results_str = (get_valtype_attr(i) for i in ty_results) + self._params_str = [get_valtype_attr(i) for i in ty_params] + self._results_str = [get_valtype_attr(i) for i in ty_results] self._results_str0 = get_valtype_attr(ty_results[0]) if results_n else None self._params_n = params_n self._results_n = results_n