Skip to content

Commit

Permalink
Promote ONNX model to top namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
keveman committed Oct 24, 2024
1 parent 888f729 commit 76de868
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 51 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ export KERAS_BACKEND=jax
# Use useful-moonshine[jax-cuda] for jax on GPU
```

To run with ONNX runtime that is supported on platforms, run the following:

```shell
uv pip install useful-moonshine[onnx]@git+https://github.com/usefulsensors/moonshine.git
```

### 3. Try it out

You can test Moonshine by transcribing the provided example audio file with the `.transcribe` function:
Expand All @@ -100,6 +106,7 @@ python
```

The first argument is a path to an audio file and the second is the name of a Moonshine model. `moonshine/tiny` and `moonshine/base` are the currently available models.
Use the `moonshine.transcribe_with_onnx` function to use the ONNX runtime for inference. The parameters are the same as they are for `moonshine.transcribe`.

## TODO
* [ ] Live transcription demo
Expand Down
9 changes: 8 additions & 1 deletion moonshine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,11 @@
ASSETS_DIR = Path(__file__).parents[0] / "assets"

from .model import load_model, Moonshine
from .transcribe import transcribe, benchmark, load_tokenizer, load_audio
from .transcribe import (
transcribe,
benchmark,
load_tokenizer,
load_audio,
transcribe_with_onnx,
)
from .onnx_model import MoonshineOnnxModel
70 changes: 70 additions & 0 deletions moonshine/onnx_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
def _get_onnx_weights(model_name):
from huggingface_hub import hf_hub_download

repo = "UsefulSensors/moonshine"

return (
hf_hub_download(repo, f"{x}.onnx", subfolder=f"onnx/{model_name}")
for x in ("preprocess", "encode", "uncached_decode", "cached_decode")
)


class MoonshineOnnxModel(object):
def __init__(self, models_dir=None, model_name=None):
import onnxruntime

if models_dir is None:
assert (
model_name is not None
), "model_name should be specified if models_dir is not"
preprocess, encode, uncached_decode, cached_decode = (
self._load_weights_from_hf_hub(model_name)
)
else:
preprocess, encode, uncached_decode, cached_decode = [
f"{models_dir}/{x}.onnx"
for x in ["preprocess", "encode", "uncached_decode", "cached_decode"]
]
self.preprocess = onnxruntime.InferenceSession(preprocess)
self.encode = onnxruntime.InferenceSession(encode)
self.uncached_decode = onnxruntime.InferenceSession(uncached_decode)
self.cached_decode = onnxruntime.InferenceSession(cached_decode)

def _load_weights_from_hf_hub(self, model_name):
model_name = model_name.split("/")[-1]
return _get_onnx_weights(model_name)

def generate(self, audio, max_len=None):
"audio has to be a numpy array of shape [1, num_audio_samples]"
if max_len is None:
# max 6 tokens per second of audio
max_len = int((audio.shape[-1] / 16_000) * 6)
preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
seq_len = [preprocessed.shape[-2]]

context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
inputs = [[1]]
seq_len = [1]

tokens = [1]
logits, *cache = self.uncached_decode.run(
[], dict(args_0=inputs, args_1=context, args_2=seq_len)
)
for i in range(max_len):
next_token = logits.squeeze().argmax()
tokens.extend([next_token])
if next_token == 2:
break

seq_len[0] += 1
inputs = [[next_token]]
logits, *cache = self.cached_decode.run(
[],
dict(
args_0=inputs,
args_1=context,
args_2=seq_len,
**{f"args_{i+3}": x for i, x in enumerate(cache)},
),
)
return [tokens]
50 changes: 0 additions & 50 deletions moonshine/tools/onnx_model.py

This file was deleted.

15 changes: 15 additions & 0 deletions moonshine/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,21 @@ def transcribe(audio, model="moonshine/base"):
return load_tokenizer().decode_batch(tokens)


def transcribe_with_onnx(audio, model="moonshine/base"):
from .onnx_model import MoonshineOnnxModel

if isinstance(model, str):
model = MoonshineOnnxModel(model_name=model)
assert isinstance(
model, MoonshineOnnxModel
), f"Expected a MoonshineOnnxModel model or a model name, not a {type(model)}"
audio = load_audio(audio, return_numpy=True)
assert_audio_size(audio)

tokens = model.generate(audio)
return load_tokenizer().decode_batch(tokens)


def load_tokenizer():
tokenizer_file = ASSETS_DIR / "tokenizer.json"
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_file))
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def read_version(fname="moonshine/version.py"):
"tensorflow": ["tensorflow==2.17.0"],
"jax": ["jax==0.4.34", "keras==3.6.0"],
"jax-cuda": ["jax[cuda12]", "keras==3.6.0"],
"onnx": ["onnxruntime>=1.19.2"],
},
include_package_data=True,
)

0 comments on commit 76de868

Please sign in to comment.