Skip to content

Commit

Permalink
Merge pull request #50 from dscripka/remove_model_files
Browse files Browse the repository at this point in the history
Remove model files
  • Loading branch information
dscripka authored Oct 10, 2023
2 parents dbd3f7a + ed90629 commit fd15e8c
Show file tree
Hide file tree
Showing 29 changed files with 130 additions and 79 deletions.
2 changes: 0 additions & 2 deletions .gitattributes

This file was deleted.

2 changes: 0 additions & 2 deletions .github/workflows/build_and_publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
with:
lfs: true
- name: Set up Python 3.8
uses: actions/setup-python@v3
with:
Expand Down
5 changes: 1 addition & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
workflow_dispatch:

jobs:
unit_tests_linux:
Expand All @@ -18,8 +19,6 @@ jobs:

steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
Expand All @@ -42,8 +41,6 @@ jobs:

steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
Expand Down
2 changes: 0 additions & 2 deletions MANIFEST.in

This file was deleted.

10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,20 @@ Many thanks to [TeaPoly](https://github.com/TeaPoly/speexdsp-ns-python) for thei

# Usage

For quick local testing, clone this repository and use the included [example script](examples/detect_from_microphone.py) to try streaming detection from a local microphone. **Important note!** The model files are stored in this repo using [git-lfs](https://git-lfs.com/); make sure it is installed on your system and if needed use `git-lfs fetch --all` to make sure the the models download correctly.
For quick local testing, clone this repository and use the included [example script](examples/detect_from_microphone.py) to try streaming detection from a local microphone. You can individually download pre-trained models from current and past [releases](https://github.com/dscripka/openWakeWord/releases/), or you can download them using Python (see below).

Adding openWakeWord to your own Python code requires just a few lines:

```python
import openwakeword
from openwakeword.model import Model

# Instantiate the model
# One-time download of all pre-trained models (or only select models)
openwakeword.utils.download_models()

# Instantiate the model(s)
model = Model(
wakeword_models=["path/to/model.onnx"], # can also leave this argument empty to load all of the included pre-trained models
wakeword_models=["path/to/model.tflite"], # can also leave this argument empty to load all of the included pre-trained models
)

# Get audio data containing 16-bit 16khz PCM audio data from a file, microphone, network stream, etc.
Expand Down
42 changes: 33 additions & 9 deletions openwakeword/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,48 @@

__all__ = ['Model', 'VAD', 'train_custom_verifier']

models = {
FEATURE_MODELS = {
"embedding": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/embedding_model.tflite"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/embedding_model.tflite"
},
"melspectrogram": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/melspectrogram.tflite"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/melspectrogram.tflite"
}
}

VAD_MODELS = {
"silero_vad": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/silero_vad.onnx"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/silero_vad.onnx"
}
}

MODELS = {
"alexa": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/alexa_v0.1.tflite")
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/alexa_v0.1.tflite"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/alexa_v0.1.tflite"
},
"hey_mycroft": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_mycroft_v0.1.tflite")
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_mycroft_v0.1.tflite"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_mycroft_v0.1.tflite"
},
"hey_jarvis": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_jarvis_v0.1.tflite")
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_jarvis_v0.1.tflite"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_jarvis_v0.1.tflite"
},
"hey_rhasspy": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_rhasspy_v0.1.tflite")
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_rhasspy_v0.1.tflite"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_rhasspy_v0.1.tflite"
},
"timer": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/timer_v0.1.tflite")
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/timer_v0.1.tflite"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/timer_v0.1.tflite"
},
"weather": {
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/weather_v0.1.tflite")
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/weather_v0.1.tflite"),
"download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/weather_v0.1.tflite"
}
}

Expand All @@ -40,6 +64,6 @@

def get_pretrained_model_paths(inference_framework="tflite"):
if inference_framework == "tflite":
return [models[i]["model_path"] for i in models.keys()]
return [MODELS[i]["model_path"] for i in MODELS.keys()]
elif inference_framework == "onnx":
return [models[i]["model_path"].replace(".tflite", ".onnx") for i in models.keys()]
return [MODELS[i]["model_path"].replace(".tflite", ".onnx") for i in MODELS.keys()]
4 changes: 2 additions & 2 deletions openwakeword/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def mix_clips_batch(
# Apply volume augmentation
if volume_augmentation:
volume_levels = np.random.uniform(0.02, 1.0, mixed_clips_batch.shape[0])
mixed_clips_batch = (volume_levels/mixed_clips_batch.max(axis=1)[0])[..., None]*mixed_clips_batch
mixed_clips_batch = (volume_levels/mixed_clips_batch.max(dim=1)[0])[..., None]*mixed_clips_batch
else:
# Normalize clips only if max value is outside of [-1, 1]
abs_max, _ = torch.max(
Expand All @@ -457,7 +457,7 @@ def mix_clips_batch(
mixed_clips_batch = (mixed_clips_batch.numpy()*32767).astype(np.int16)

# Remove any clips that are silent (happens rarely when mixing/reverberating)
error_index = np.where(mixed_clips_batch.max(axis=1) != 0)[0]
error_index = torch.from_numpy(np.where(mixed_clips_batch.max(dim=1) != 0)[0])
mixed_clips_batch = mixed_clips_batch[error_index]
labels_batch = labels_batch[error_index]
sequence_labels_batch = sequence_labels_batch[error_index]
Expand Down
4 changes: 2 additions & 2 deletions openwakeword/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
with VAD scores above the threshold will be returned. The default value (0),
disables voice activity detection entirely.
custom_verifier_models (dict): A dictionary of paths to custom verifier models, where
the keys are the model names (corresponding to the openwakeword.models
the keys are the model names (corresponding to the openwakeword.MODELS
attribute) and the values are the filepaths of the
custom verifier models.
custom_verifier_threshold (float): The score threshold to use a custom verifier model. If the score
Expand All @@ -85,7 +85,7 @@ def __init__(
wakeword_model_names = []
if wakeword_models == []:
wakeword_models = pretrained_model_paths
wakeword_model_names = list(openwakeword.models.keys())
wakeword_model_names = list(openwakeword.MODELS.keys())
elif len(wakeword_models) >= 1:
for ndx, i in enumerate(wakeword_models):
if os.path.exists(i):
Expand Down
3 changes: 0 additions & 3 deletions openwakeword/resources/models/alexa_v0.1.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/alexa_v0.1.tflite

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/embedding_model.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/embedding_model.tflite

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/hey_jarvis_v0.1.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/hey_jarvis_v0.1.tflite

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/hey_mycroft_v0.1.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/hey_mycroft_v0.1.tflite

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/hey_rhasspy_v0.1.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/hey_rhasspy_v0.1.tflite

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/melspectrogram.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/melspectrogram.tflite

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/silero_vad.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/timer_v0.1.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/timer_v0.1.tflite

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/weather_v0.1.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions openwakeword/resources/models/weather_v0.1.tflite

This file was deleted.

75 changes: 75 additions & 0 deletions openwakeword/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import logging
import openwakeword
from typing import Union, List, Callable, Deque
import requests
from tqdm import tqdm


# Base class for computing audio features using Google's speech_embedding
Expand Down Expand Up @@ -526,6 +528,79 @@ def f(clips):
return {list(i.keys())[0]: list(i.values())[0] for i in results}


# Function to download files from a URL with a progress bar
def download_file(url, target_directory, file_size=None):
"""A simpel function to download a file from a URL with a progress bar using only the requests library"""
local_filename = url.split('/')[-1]

with requests.get(url, stream=True) as r:
if file_size is not None:
progress_bar = tqdm(total=file_size, unit='iB', unit_scale=True, desc=f"{local_filename}")
else:
total_size = int(r.headers.get('content-length', 0))
progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"{local_filename}")

with open(os.path.join(target_directory, local_filename), 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
progress_bar.update(len(chunk))

progress_bar.close()


# Function to download models from GitHub release assets
def download_models(
model_names: List[str] = [],
target_directory: str = os.path.join(pathlib.Path(__file__).parent.resolve(), "resources", "models")
):
"""
Download the specified models from the release assets in the openWakeWord GitHub repository.
Uses the official urls in the MODELS dictionary in openwakeword/__init__.py.
Args:
model_names (List[str]): The names of the models to download (e.g., hey_jarvis_v0.1). Both ONNX and
tflite models will be downloaded. If not provided (the default),
the latest versions of all models will be downloaded.
target_directory (str): The directory to save the models to. Defaults to the install location
of openWakeWord (i.e., the `resources/models` directory).
Returns:
None
"""
if not isinstance(model_names, list):
raise ValueError("The model_names argument must be a list of strings")

# Always download melspectrogram and embedding models, if they don't already exist
if not os.path.exists(target_directory):
os.makedirs(target_directory)
for feature_model in openwakeword.FEATURE_MODELS.values():
if not os.path.exists(os.path.join(target_directory, feature_model["download_url"].split("/")[-1])):
download_file(feature_model["download_url"], target_directory)
download_file(feature_model["download_url"].replace(".tflite", ".onnx"), target_directory)

# Always download VAD models, if they don't already exist
for vad_model in openwakeword.VAD_MODELS.values():
if not os.path.exists(os.path.join(target_directory, vad_model["download_url"].split("/")[-1])):
download_file(vad_model["download_url"], target_directory)

# Get all model urls
official_model_urls = [i["download_url"] for i in openwakeword.MODELS.values()]
official_model_names = [i["download_url"].split("/")[-1] for i in openwakeword.MODELS.values()]

if model_names != []:
for model_name in model_names:
url = [i for i, j in zip(official_model_urls, official_model_names) if model_name in j]
if url != []:
if not os.path.exists(os.path.join(target_directory, url[0].split("/")[-1])):
download_file(url[0], target_directory)
download_file(url[0].replace(".tflite", ".onnx"), target_directory)
else:
print(official_model_urls)
for official_model_url in official_model_urls:
if not os.path.exists(os.path.join(target_directory, official_model_url.split("/")[-1])):
download_file(official_model_url, target_directory)
download_file(official_model_url.replace(".tflite", ".onnx"), target_directory)


# Handle deprecated arguments and naming (thanks to https://stackoverflow.com/a/74564394)
def re_arg(kwarg_map):
def decorator(func):
Expand Down
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def build_additional_requires():
'tflite-runtime>=2.8.0,<3; platform_system == "Linux"',
'tqdm>=4.0,<5.0',
'scipy>=1.3,<2',
'scikit-learn>=1,<2'
'scikit-learn>=1,<2',
'requests>=2.0,<3',
],
extras_require={
'test': [
Expand All @@ -42,7 +43,8 @@ def build_additional_requires():
'flake8>=4.0,<4.1',
'pytest-mypy>=0.10.0,<1',
'mock>=5.1,<6',
'types-mock>=5.1,<6'
'types-mock>=5.1,<6',
'types-requests>=2.0,<3'
],
'full': [
'mutagen>=1.46.0,<2',
Expand Down
3 changes: 3 additions & 0 deletions tests/test_custom_verifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import tempfile
import pytest

# Download models needed for tests
openwakeword.utils.download_models(model_names=["alexa_v0.1", "hey_mycroft_v0.1"])


# Tests
class TestModels:
Expand Down
3 changes: 3 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
import tempfile
import mock

# Download models needed for tests
openwakeword.utils.download_models()


# Tests
class TestModels:
Expand Down

0 comments on commit fd15e8c

Please sign in to comment.