Promote ONNX model to top namespace #31

Merged
merged 1 commit on Oct 24, 2024
7 changes: 7 additions & 0 deletions README.md
@@ -88,6 +88,12 @@ export KERAS_BACKEND=jax
# Use useful-moonshine[jax-cuda] for jax on GPU
```

To run with the ONNX runtime, which is supported on a wide range of platforms, run the following:

```shell
uv pip install useful-moonshine[onnx]@git+https://github.com/usefulsensors/moonshine.git
```

### 3. Try it out

You can test Moonshine by transcribing the provided example audio file with the `.transcribe` function:
@@ -100,6 +106,7 @@ python
```

The first argument is a path to an audio file and the second is the name of a Moonshine model. `moonshine/tiny` and `moonshine/base` are the currently available models.
Use the `moonshine.transcribe_with_onnx` function to run inference with the ONNX runtime. Its parameters are the same as those of `moonshine.transcribe`.

## TODO
* [ ] Live transcription demo
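A minimal use of the new README instructions might look like the following sketch (the audio path is a placeholder; substitute any speech recording):

```python
import moonshine

# "example.wav" is a placeholder path; replace it with your own audio file.
text = moonshine.transcribe_with_onnx("example.wav", "moonshine/tiny")
print(text)
```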
9 changes: 8 additions & 1 deletion moonshine/__init__.py
@@ -4,4 +4,11 @@
ASSETS_DIR = Path(__file__).parents[0] / "assets"

from .model import load_model, Moonshine
from .transcribe import transcribe, benchmark, load_tokenizer, load_audio
from .transcribe import (
    transcribe,
    benchmark,
    load_tokenizer,
    load_audio,
    transcribe_with_onnx,
)
from .onnx_model import MoonshineOnnxModel
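With these re-exports in place, the ONNX entry points become importable from the package root:

```python
from moonshine import MoonshineOnnxModel, transcribe_with_onnx
```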
70 changes: 70 additions & 0 deletions moonshine/onnx_model.py
@@ -0,0 +1,70 @@
def _get_onnx_weights(model_name):
    from huggingface_hub import hf_hub_download

    repo = "UsefulSensors/moonshine"

    return (
        hf_hub_download(repo, f"{x}.onnx", subfolder=f"onnx/{model_name}")
        for x in ("preprocess", "encode", "uncached_decode", "cached_decode")
    )


class MoonshineOnnxModel(object):
    def __init__(self, models_dir=None, model_name=None):
        import onnxruntime

        if models_dir is None:
            assert (
                model_name is not None
            ), "model_name should be specified if models_dir is not"
            preprocess, encode, uncached_decode, cached_decode = (
                self._load_weights_from_hf_hub(model_name)
            )
        else:
            preprocess, encode, uncached_decode, cached_decode = [
                f"{models_dir}/{x}.onnx"
                for x in ["preprocess", "encode", "uncached_decode", "cached_decode"]
            ]
        self.preprocess = onnxruntime.InferenceSession(preprocess)
        self.encode = onnxruntime.InferenceSession(encode)
        self.uncached_decode = onnxruntime.InferenceSession(uncached_decode)
        self.cached_decode = onnxruntime.InferenceSession(cached_decode)

    def _load_weights_from_hf_hub(self, model_name):
        model_name = model_name.split("/")[-1]
        return _get_onnx_weights(model_name)

    def generate(self, audio, max_len=None):
        """audio has to be a numpy array of shape [1, num_audio_samples]"""
        if max_len is None:
            # max 6 tokens per second of audio
            max_len = int((audio.shape[-1] / 16_000) * 6)
        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
        seq_len = [preprocessed.shape[-2]]

        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
        inputs = [[1]]
        seq_len = [1]

        # Token 1 is the start-of-sequence marker; token 2 ends the sequence.
        tokens = [1]
        # The first step has no key/value cache yet, so use the uncached decoder.
        logits, *cache = self.uncached_decode.run(
            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
        )
        for i in range(max_len):
            next_token = logits.squeeze().argmax()
            tokens.append(next_token)
            if next_token == 2:
                break

            seq_len[0] += 1
            inputs = [[next_token]]
            # Later steps feed the cached key/value tensors back in as extra args.
            logits, *cache = self.cached_decode.run(
                [],
                dict(
                    args_0=inputs,
                    args_1=context,
                    args_2=seq_len,
                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
                ),
            )
        return [tokens]
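As a quick sanity check, the class can also be driven directly; a minimal sketch, assuming `onnxruntime`, `huggingface_hub`, and NumPy are installed and that the model expects float32 samples at 16 kHz:

```python
import numpy as np

from moonshine import MoonshineOnnxModel, load_tokenizer

model = MoonshineOnnxModel(model_name="moonshine/tiny")
# One second of silence with the documented [1, num_audio_samples] shape.
audio = np.zeros((1, 16_000), dtype=np.float32)
tokens = model.generate(audio)
print(load_tokenizer().decode_batch(tokens))
```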
50 changes: 0 additions & 50 deletions moonshine/tools/onnx_model.py

This file was deleted.

15 changes: 15 additions & 0 deletions moonshine/transcribe.py
@@ -40,6 +40,21 @@ def transcribe(audio, model="moonshine/base"):
    return load_tokenizer().decode_batch(tokens)


def transcribe_with_onnx(audio, model="moonshine/base"):
    from .onnx_model import MoonshineOnnxModel

    if isinstance(model, str):
        model = MoonshineOnnxModel(model_name=model)
    assert isinstance(
        model, MoonshineOnnxModel
    ), f"Expected a MoonshineOnnxModel model or a model name, not a {type(model)}"
    audio = load_audio(audio, return_numpy=True)
    assert_audio_size(audio)

    tokens = model.generate(audio)
    return load_tokenizer().decode_batch(tokens)


def load_tokenizer():
    tokenizer_file = ASSETS_DIR / "tokenizer.json"
    tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_file))
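Since `transcribe_with_onnx` also accepts a preloaded model, repeated calls can reuse one set of ONNX inference sessions instead of rebuilding them per file; a sketch with placeholder paths:

```python
from moonshine import MoonshineOnnxModel, transcribe_with_onnx

model = MoonshineOnnxModel(model_name="moonshine/base")
for path in ["clip_01.wav", "clip_02.wav"]:  # placeholder file names
    print(transcribe_with_onnx(path, model))
```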
1 change: 1 addition & 0 deletions setup.py
@@ -32,6 +32,7 @@ def read_version(fname="moonshine/version.py"):
"tensorflow": ["tensorflow==2.17.0"],
"jax": ["jax==0.4.34", "keras==3.6.0"],
"jax-cuda": ["jax[cuda12]", "keras==3.6.0"],
"onnx": ["onnxruntime>=1.19.2"],
},
include_package_data=True,
)
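With the new `onnx` entry in `extras_require`, the ONNX dependency can be pulled in at install time, for example from a local checkout:

```shell
pip install .[onnx]
```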