diff --git a/README.md b/README.md
index 3629c0c..96b49f3 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,12 @@ export KERAS_BACKEND=jax
 # Use useful-moonshine[jax-cuda] for jax on GPU
 ```
 
+To run with the ONNX runtime, which is supported on a wide range of platforms, install the package with the `onnx` extra:
+
+```shell
+uv pip install useful-moonshine[onnx]@git+https://github.com/usefulsensors/moonshine.git
+```
+
 ### 3. Try it out
 
 You can test Moonshine by transcribing the provided example audio file with the `.transcribe` function:
@@ -100,6 +106,7 @@ python
 ```
 
 The first argument is a path to an audio file and the second is the name of a Moonshine model. `moonshine/tiny` and `moonshine/base` are the currently available models.
+To run inference with the ONNX runtime instead, use the `moonshine.transcribe_with_onnx` function. Its parameters are the same as those of `moonshine.transcribe`.
 
 ## TODO
 * [ ] Live transcription demo
diff --git a/moonshine/__init__.py b/moonshine/__init__.py
index c49adf9..ec7a840 100644
--- a/moonshine/__init__.py
+++ b/moonshine/__init__.py
@@ -4,4 +4,11 @@
 ASSETS_DIR = Path(__file__).parents[0] / "assets"
 
 from .model import load_model, Moonshine
-from .transcribe import transcribe, benchmark, load_tokenizer, load_audio
+from .transcribe import (
+    transcribe,
+    benchmark,
+    load_tokenizer,
+    load_audio,
+    transcribe_with_onnx,
+)
+from .onnx_model import MoonshineOnnxModel
diff --git a/moonshine/onnx_model.py b/moonshine/onnx_model.py
new file mode 100644
index 0000000..4ce6609
--- /dev/null
+++ b/moonshine/onnx_model.py
@@ -0,0 +1,70 @@
+def _get_onnx_weights(model_name):
+    from huggingface_hub import hf_hub_download
+
+    repo = "UsefulSensors/moonshine"
+
+    return (
+        hf_hub_download(repo, f"{x}.onnx", subfolder=f"onnx/{model_name}")
+        for x in ("preprocess", "encode", "uncached_decode", "cached_decode")
+    )
+
+
+class MoonshineOnnxModel(object):
+    def __init__(self, models_dir=None, model_name=None):
+        import onnxruntime
+
+        if models_dir is None:
+            assert (
+                model_name is not None
+            ), "model_name should be specified if models_dir is not"
+            preprocess, encode, uncached_decode, cached_decode = (
+                self._load_weights_from_hf_hub(model_name)
+            )
+        else:
+            preprocess, encode, uncached_decode, cached_decode = [
+                f"{models_dir}/{x}.onnx"
+                for x in ["preprocess", "encode", "uncached_decode", "cached_decode"]
+            ]
+        self.preprocess = onnxruntime.InferenceSession(preprocess)
+        self.encode = onnxruntime.InferenceSession(encode)
+        self.uncached_decode = onnxruntime.InferenceSession(uncached_decode)
+        self.cached_decode = onnxruntime.InferenceSession(cached_decode)
+
+    def _load_weights_from_hf_hub(self, model_name):
+        model_name = model_name.split("/")[-1]
+        return _get_onnx_weights(model_name)
+
+    def generate(self, audio, max_len=None):
+        "audio has to be a numpy array of shape [1, num_audio_samples]"
+        if max_len is None:
+            # max 6 tokens per second of audio
+            max_len = int((audio.shape[-1] / 16_000) * 6)
+        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
+        seq_len = [preprocessed.shape[-2]]
+
+        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
+        inputs = [[1]]
+        seq_len = [1]
+
+        tokens = [1]
+        logits, *cache = self.uncached_decode.run(
+            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
+        )
+        for i in range(max_len):
+            next_token = logits.squeeze().argmax()
+            tokens.extend([next_token])
+            if next_token == 2:
+                break
+
+            seq_len[0] += 1
+            inputs = [[next_token]]
+            logits, *cache = self.cached_decode.run(
+                [],
+                dict(
+                    args_0=inputs,
+                    args_1=context,
+                    args_2=seq_len,
+                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
+                ),
+            )
+        return [tokens]
diff --git a/moonshine/tools/onnx_model.py b/moonshine/tools/onnx_model.py
deleted file mode 100644
index 2f82b0c..0000000
--- a/moonshine/tools/onnx_model.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import onnxruntime
-import moonshine
-
-
-class MoonshineOnnxModel(object):
-    def __init__(self, models_dir):
-        self.preprocess = onnxruntime.InferenceSession(f"{models_dir}/preprocess.onnx")
-        self.encode = onnxruntime.InferenceSession(f"{models_dir}/encode.onnx")
-        self.uncached_decode = onnxruntime.InferenceSession(
-            f"{models_dir}/uncached_decode.onnx"
-        )
-        self.cached_decode = onnxruntime.InferenceSession(
-            f"{models_dir}/cached_decode.onnx"
-        )
-        self.tokenizer = moonshine.load_tokenizer()
-
-    def generate(self, audio, max_len=None):
-        audio = moonshine.load_audio(audio, return_numpy=True)
-        if max_len is None:
-            # max 6 tokens per second of audio
-            max_len = int((audio.shape[-1] / 16_000) * 6)
-        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
-        seq_len = [preprocessed.shape[-2]]
-
-        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
-        inputs = [[1]]
-        seq_len = [1]
-
-        tokens = [1]
-        logits, *cache = self.uncached_decode.run(
-            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
-        )
-        for i in range(max_len):
-            next_token = logits.squeeze().argmax()
-            tokens.extend([next_token])
-            if next_token == 2:
-                break
-
-            seq_len[0] += 1
-            inputs = [[next_token]]
-            logits, *cache = self.cached_decode.run(
-                [],
-                dict(
-                    args_0=inputs,
-                    args_1=context,
-                    args_2=seq_len,
-                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
-                ),
-            )
-        return [tokens]
diff --git a/moonshine/transcribe.py b/moonshine/transcribe.py
index 786b6e6..2c1197c 100644
--- a/moonshine/transcribe.py
+++ b/moonshine/transcribe.py
@@ -40,6 +40,21 @@ def transcribe(audio, model="moonshine/base"):
     return load_tokenizer().decode_batch(tokens)
 
 
+def transcribe_with_onnx(audio, model="moonshine/base"):
+    from .onnx_model import MoonshineOnnxModel
+
+    if isinstance(model, str):
+        model = MoonshineOnnxModel(model_name=model)
+    assert isinstance(
+        model, MoonshineOnnxModel
+    ), f"Expected a MoonshineOnnxModel model or a model name, not a {type(model)}"
+    audio = load_audio(audio, return_numpy=True)
+    assert_audio_size(audio)
+
+    tokens = model.generate(audio)
+    return load_tokenizer().decode_batch(tokens)
+
+
 def load_tokenizer():
     tokenizer_file = ASSETS_DIR / "tokenizer.json"
     tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_file))
diff --git a/setup.py b/setup.py
index d3181c9..14b0265 100644
--- a/setup.py
+++ b/setup.py
@@ -32,6 +32,7 @@ def read_version(fname="moonshine/version.py"):
         "tensorflow": ["tensorflow==2.17.0"],
         "jax": ["jax==0.4.34", "keras==3.6.0"],
         "jax-cuda": ["jax[cuda12]", "keras==3.6.0"],
+        "onnx": ["onnxruntime>=1.19.2"],
     },
     include_package_data=True,
 )
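
For quick review, here is a minimal usage sketch of the API this patch adds. It is not part of the diff itself: it assumes the package was installed with the `onnx` extra as shown in the README change, and `path/to/audio.wav` is a placeholder for any audio file (e.g. the example file shipped with the repo).

```python
import moonshine

# One-shot call: on first use this downloads the four ONNX graphs
# (preprocess, encode, uncached_decode, cached_decode) from the
# Hugging Face hub, then runs preprocess -> encode -> greedy decode.
print(moonshine.transcribe_with_onnx("path/to/audio.wav", "moonshine/tiny"))

# transcribe_with_onnx also accepts a MoonshineOnnxModel instance, so the
# four InferenceSessions can be reused across calls instead of being rebuilt.
model = moonshine.MoonshineOnnxModel(model_name="moonshine/tiny")
print(moonshine.transcribe_with_onnx("path/to/audio.wav", model))

# Loading from a local directory of exported graphs is also supported; the
# directory must contain preprocess.onnx, encode.onnx, uncached_decode.onnx,
# and cached_decode.onnx.
local_model = moonshine.MoonshineOnnxModel(models_dir="path/to/onnx/dir")
```

Accepting either a model name or a model instance mirrors the existing `transcribe` entry point, which is why the README change can describe the two functions as taking the same parameters.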