Skip to content

Commit

Permalink
wasi-nn: remove BackendKind, add wrapper structs (#6893)
Browse files Browse the repository at this point in the history
* wasi-nn: accept a list of backends instead of a hash map

* wasi-nn: refactor with new `Backend` and `Registry` wrappers

The wasi-nn crate uses virtual dispatch to handle different kinds of
backends and registries. For easier use, these items are now wrapped in
their own `Box<dyn ...>` structs.

* wasi-nn: completely remove `BackendKind`

This change completely removes the `BackendKind` enum, which uniquely
listed the kinds of backends that the wasi-nn crate contained. Since we
allow many kinds of backends in `WasiNnCtx` which can be implemented
even outside the crate, this `BackendKind` enum no longer makes sense.
This change uses `GraphEncoding` instead.

* nit: move `pub mod backend` with other exports

prtest:full

* fix: remove broken doc links
  • Loading branch information
abrown authored Aug 25, 2023
1 parent f8c03d5 commit ac7d070
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 73 deletions.
44 changes: 11 additions & 33 deletions crates/wasi-nn/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,31 @@
//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
//! implementations to maintain backend-specific state between calls.
mod openvino;
pub mod openvino;

use self::openvino::OpenvinoBackend;
use crate::wit::types::{ExecutionTarget, Tensor};
use crate::{ExecutionContext, Graph};
use std::{error::Error, fmt, path::Path, str::FromStr};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor};
use crate::{Backend, ExecutionContext, Graph};
use std::path::Path;
use thiserror::Error;
use wiggle::GuestError;

/// Return a list of all available backend frameworks.
pub fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))]
pub fn list() -> Vec<crate::Backend> {
vec![Backend::from(OpenvinoBackend::default())]
}

/// A [Backend] contains the necessary state to load [Graph]s.
pub trait Backend: Send + Sync {
fn name(&self) -> &str;
pub trait BackendInner: Send + Sync {
fn encoding(&self) -> GraphEncoding;
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;
}

/// Some [Backend]s support loading a [Graph] from a directory on the
/// filesystem; this is not a general requirement for backends but is useful for
/// the Wasmtime CLI.
pub trait BackendFromDir: Backend {
pub trait BackendFromDir: BackendInner {
fn load_from_dir(
&mut self,
builders: &Path,
Expand All @@ -35,13 +35,13 @@ pub trait BackendFromDir: Backend {
}

/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
/// implementation for a [crate::witx::types::Graph].
/// implementation for the user-facing graph.
pub trait BackendGraph: Send + Sync {
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;
}

/// A [BackendExecutionContext] performs the actual inference; this is the
/// backing implementation for a [crate::witx::types::GraphExecutionContext].
/// backing implementation for a user-facing execution context.
pub trait BackendExecutionContext: Send + Sync {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>;
fn compute(&mut self) -> Result<(), BackendError>;
Expand All @@ -61,25 +61,3 @@ pub enum BackendError {
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(usize),
}

#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)]
pub enum BackendKind {
OpenVINO,
}
impl FromStr for BackendKind {
type Err = BackendKindParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"openvino" => Ok(BackendKind::OpenVINO),
_ => Err(BackendKindParseError(s.into())),
}
}
}
#[derive(Debug)]
pub struct BackendKindParseError(String);
impl fmt::Display for BackendKindParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "unknown backend: {}", self.0)
}
}
impl Error for BackendKindParseError {}
14 changes: 7 additions & 7 deletions crates/wasi-nn/src/backend/openvino.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
//! Implements a `wasi-nn` [`Backend`] using OpenVINO.
//! Implements a `wasi-nn` [`BackendInner`] using OpenVINO.
use super::{Backend, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph};
use crate::wit::types::{ExecutionTarget, Tensor, TensorType};
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
use std::sync::{Arc, Mutex};
use std::{fs::File, io::Read, path::Path};

#[derive(Default)]
pub(crate) struct OpenvinoBackend(Option<openvino::Core>);
pub struct OpenvinoBackend(Option<openvino::Core>);
unsafe impl Send for OpenvinoBackend {}
unsafe impl Sync for OpenvinoBackend {}

impl Backend for OpenvinoBackend {
fn name(&self) -> &str {
"openvino"
impl BackendInner for OpenvinoBackend {
fn encoding(&self) -> GraphEncoding {
GraphEncoding::Openvino
}

fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
Expand Down
24 changes: 13 additions & 11 deletions crates/wasi-nn/src/ctx.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
//! Implements the host state for the `wasi-nn` API: [WasiNnCtx].
use crate::backend::{Backend, BackendError, BackendKind};
use crate::backend::{self, BackendError};
use crate::wit::types::GraphEncoding;
use crate::{ExecutionContext, Graph, GraphRegistry, InMemoryRegistry};
use crate::{Backend, ExecutionContext, Graph, InMemoryRegistry, Registry};
use anyhow::anyhow;
use std::{collections::HashMap, hash::Hash, path::Path};
use thiserror::Error;
use wiggle::GuestError;

type Backends = HashMap<BackendKind, Box<dyn Backend>>;
type Registry = Box<dyn GraphRegistry>;
type GraphId = u32;
type GraphExecutionContextId = u32;
type BackendName = String;
Expand All @@ -21,31 +19,34 @@ type GraphDirectory = String;
/// model types.
pub fn preload(
preload_graphs: &[(BackendName, GraphDirectory)],
) -> anyhow::Result<(Backends, Registry)> {
let mut backends: HashMap<_, _> = crate::backend::list().into_iter().collect();
) -> anyhow::Result<(impl IntoIterator<Item = Backend>, Registry)> {
let mut backends = backend::list();
let mut registry = InMemoryRegistry::new();
for (kind, path) in preload_graphs {
let kind_ = kind.parse()?;
let backend = backends
.get_mut(&kind.parse()?)
.iter_mut()
.find(|b| b.encoding() == kind_)
.ok_or(anyhow!("unsupported backend: {}", kind))?
.as_dir_loadable()
.ok_or(anyhow!("{} does not support directory loading", kind))?;
registry.load(backend, Path::new(path))?;
}
Ok((backends, Box::new(registry)))
Ok((backends, Registry::from(registry)))
}

/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: Backends,
pub(crate) backends: HashMap<GraphEncoding, Backend>,
pub(crate) registry: Registry,
pub(crate) graphs: Table<GraphId, Graph>,
pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}

impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new(backends: Backends, registry: Registry) -> Self {
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
Self {
backends,
registry,
Expand Down Expand Up @@ -129,6 +130,7 @@ where
#[cfg(test)]
mod test {
use super::*;
use crate::registry::GraphRegistry;

#[test]
fn example() {
Expand All @@ -139,6 +141,6 @@ mod test {
}
}

let ctx = WasiNnCtx::new(HashMap::new(), Box::new(FakeRegistry));
let _ctx = WasiNnCtx::new([], Registry::from(FakeRegistry));
}
}
43 changes: 42 additions & 1 deletion crates/wasi-nn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
mod backend;
mod ctx;
mod registry;

pub mod backend;
pub use ctx::{preload, WasiNnCtx};
pub use registry::{GraphRegistry, InMemoryRegistry};
pub mod wit;
pub mod witx;

use std::sync::Arc;

/// A machine learning backend.
pub struct Backend(Box<dyn backend::BackendInner>);
impl std::ops::Deref for Backend {
type Target = dyn backend::BackendInner;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
impl std::ops::DerefMut for Backend {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut()
}
}
impl<T: backend::BackendInner + 'static> From<T> for Backend {
fn from(value: T) -> Self {
Self(Box::new(value))
}
}

/// A backend-defined graph (i.e., ML model).
#[derive(Clone)]
pub struct Graph(Arc<dyn backend::BackendGraph>);
Expand Down Expand Up @@ -42,3 +61,25 @@ impl std::ops::DerefMut for ExecutionContext {
self.0.as_mut()
}
}

/// A container for graphs.
pub struct Registry(Box<dyn GraphRegistry>);
impl std::ops::Deref for Registry {
type Target = dyn GraphRegistry;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
impl std::ops::DerefMut for Registry {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut()
}
}
impl<T> From<T> for Registry
where
T: GraphRegistry + 'static,
{
fn from(value: T) -> Self {
Self(Box::new(value))
}
}
39 changes: 29 additions & 10 deletions crates/wasi-nn/src/wit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
//! computation to a [`Backend`]
//! 3. convert some types
//!
//! [`Backend`]: crate::backend::Backend
//! [`Backend`]: crate::Backend
//! [`types`]: crate::wit::types
use crate::{backend::BackendKind, ctx::UsageError, WasiNnCtx};
use crate::{ctx::UsageError, WasiNnCtx};
use std::{error::Error, fmt, hash::Hash, str::FromStr};

/// Generate the traits and types from the `wasi-nn` WIT specification.
mod gen_ {
Expand All @@ -40,8 +41,7 @@ impl gen::graph::Host for WasiNnCtx {
encoding: gen::graph::GraphEncoding,
target: gen::graph::ExecutionTarget,
) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
let backend_kind: BackendKind = encoding.try_into()?;
let graph = if let Some(backend) = self.backends.get_mut(&backend_kind) {
let graph = if let Some(backend) = self.backends.get_mut(&encoding) {
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
backend.load(&slices, target.into())?
} else {
Expand Down Expand Up @@ -137,12 +137,31 @@ impl gen::errors::Host for WasiNnCtx {}

impl gen::tensor::Host for WasiNnCtx {}

impl TryFrom<gen::graph::GraphEncoding> for crate::backend::BackendKind {
type Error = UsageError;
fn try_from(value: gen::graph::GraphEncoding) -> Result<Self, Self::Error> {
match value {
gen::graph::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
_ => Err(UsageError::InvalidEncoding(value.into())),
impl Hash for gen::graph::GraphEncoding {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
}
}

impl FromStr for gen::graph::GraphEncoding {
type Err = GraphEncodingParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"openvino" => Ok(gen::graph::GraphEncoding::Openvino),
"onnx" => Ok(gen::graph::GraphEncoding::Onnx),
"pytorch" => Ok(gen::graph::GraphEncoding::Pytorch),
"tensorflow" => Ok(gen::graph::GraphEncoding::Tensorflow),
"tensorflowlite" => Ok(gen::graph::GraphEncoding::Tensorflowlite),
"autodetect" => Ok(gen::graph::GraphEncoding::Autodetect),
_ => Err(GraphEncodingParseError(s.into())),
}
}
}
#[derive(Debug)]
pub struct GraphEncodingParseError(String);
impl fmt::Display for GraphEncodingParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "unknown graph encoding: {}", self.0)
}
}
impl Error for GraphEncodingParseError {}
13 changes: 2 additions & 11 deletions crates/wasi-nn/src/witx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
encoding: gen::types::GraphEncoding,
target: gen::types::ExecutionTarget,
) -> Result<gen::types::Graph> {
let graph = if let Some(backend) = self.backends.get_mut(&encoding.try_into()?) {
let graph = if let Some(backend) = self.backends.get_mut(&encoding.into()) {
// Retrieve all of the "builder lists" from the Wasm memory (see
// $graph_builder_array) as slices for a backend to operate on.
let mut slices = vec![];
Expand Down Expand Up @@ -149,15 +149,6 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {

// Implement some conversion from `witx::types::*` to this crate's version.

impl TryFrom<gen::types::GraphEncoding> for crate::backend::BackendKind {
type Error = UsageError;
fn try_from(value: gen::types::GraphEncoding) -> std::result::Result<Self, Self::Error> {
match value {
gen::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
_ => Err(UsageError::InvalidEncoding(value.into())),
}
}
}
impl From<gen::types::ExecutionTarget> for crate::wit::types::ExecutionTarget {
fn from(value: gen::types::ExecutionTarget) -> Self {
match value {
Expand All @@ -177,7 +168,7 @@ impl From<gen::types::GraphEncoding> for crate::wit::types::GraphEncoding {
gen::types::GraphEncoding::Tensorflowlite => {
crate::wit::types::GraphEncoding::Tensorflowlite
}
gen::types::GraphEncoding::Autodetect => todo!("autodetect not supported"),
gen::types::GraphEncoding::Autodetect => crate::wit::types::GraphEncoding::Autodetect,
}
}
}
Expand Down

0 comments on commit ac7d070

Please sign in to comment.