From 74839d5ca2da5703151d246a68a81563ffe43c34 Mon Sep 17 00:00:00 2001 From: dscripka Date: Sun, 8 Oct 2023 21:03:21 -0400 Subject: [PATCH] Added model download utility functions and updated model metadata for official models --- openwakeword/__init__.py | 35 +++++++++++++++++------ openwakeword/data.py | 4 +-- openwakeword/model.py | 4 +-- openwakeword/utils.py | 62 ++++++++++++++++++++++++++++++++++++++++ setup.py | 3 +- 5 files changed, 94 insertions(+), 14 deletions(-) diff --git a/openwakeword/__init__.py b/openwakeword/__init__.py index d49e9e2..b74f8bc 100755 --- a/openwakeword/__init__.py +++ b/openwakeword/__init__.py @@ -5,24 +5,41 @@ __all__ = ['Model', 'VAD', 'train_custom_verifier'] -models = { +FEATURE_MODELS = { + "embedding": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/embedding_model.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/embedding_model.tflite" + }, + "melspectrogram": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/melspectrogram_model.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/melspectrogram_model.tflite" + }, +} + +MODELS = { "alexa": { - "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/alexa_v0.1.tflite") + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/alexa_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/alexa_v0.1.tflite" }, "hey_mycroft": { - "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_mycroft_v0.1.tflite") + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_mycroft_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_mycroft_v0.1.tflite" }, "hey_jarvis": { - 
"model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_jarvis_v0.1.tflite") + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_jarvis_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_jarvis_v0.1.tflite" }, "hey_rhasspy": { - "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_rhasspy_v0.1.tflite") + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_rhasspy_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_rhasspy_v0.1.tflite" }, "timer": { - "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/timer_v0.1.tflite") + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/timer_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/timer_v0.1.tflite" }, "weather": { - "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/weather_v0.1.tflite") + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/weather_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/weather_v0.1.tflite" } } @@ -40,6 +57,6 @@ def get_pretrained_model_paths(inference_framework="tflite"): if inference_framework == "tflite": - return [models[i]["model_path"] for i in models.keys()] + return [MODELS[i]["model_path"] for i in MODELS.keys()] elif inference_framework == "onnx": - return [models[i]["model_path"].replace(".tflite", ".onnx") for i in models.keys()] + return [MODELS[i]["model_path"].replace(".tflite", ".onnx") for i in MODELS.keys()] diff --git a/openwakeword/data.py b/openwakeword/data.py index 7c34549..b5db3c2 100755 --- a/openwakeword/data.py +++ b/openwakeword/data.py @@ -445,7 +445,7 @@ 
def mix_clips_batch( # Apply volume augmentation if volume_augmentation: volume_levels = np.random.uniform(0.02, 1.0, mixed_clips_batch.shape[0]) - mixed_clips_batch = (volume_levels/mixed_clips_batch.max(axis=1)[0])[..., None]*mixed_clips_batch + mixed_clips_batch = (volume_levels/mixed_clips_batch.max(dim=1)[0])[..., None]*mixed_clips_batch else: # Normalize clips only if max value is outside of [-1, 1] abs_max, _ = torch.max( @@ -457,7 +457,7 @@ def mix_clips_batch( mixed_clips_batch = (mixed_clips_batch.numpy()*32767).astype(np.int16) # Remove any clips that are silent (happens rarely when mixing/reverberating) - error_index = np.where(mixed_clips_batch.max(axis=1) != 0)[0] + error_index = torch.from_numpy(np.where(mixed_clips_batch.max(dim=1) != 0)[0]) mixed_clips_batch = mixed_clips_batch[error_index] labels_batch = labels_batch[error_index] sequence_labels_batch = sequence_labels_batch[error_index] diff --git a/openwakeword/model.py b/openwakeword/model.py index 46f603a..6ae820c 100755 --- a/openwakeword/model.py +++ b/openwakeword/model.py @@ -67,7 +67,7 @@ def __init__( with VAD scores above the threshold will be returned. The default value (0), disables voice activity detection entirely. custom_verifier_models (dict): A dictionary of paths to custom verifier models, where - the keys are the model names (corresponding to the openwakeword.models + the keys are the model names (corresponding to the openwakeword.MODELS attribute) and the values are the filepaths of the custom verifier models. custom_verifier_threshold (float): The score threshold to use a custom verifier model. 
If the score @@ -85,7 +85,7 @@ def __init__( wakeword_model_names = [] if wakeword_models == []: wakeword_models = pretrained_model_paths - wakeword_model_names = list(openwakeword.models.keys()) + wakeword_model_names = list(openwakeword.MODELS.keys()) elif len(wakeword_models) >= 1: for ndx, i in enumerate(wakeword_models): if os.path.exists(i): diff --git a/openwakeword/utils.py b/openwakeword/utils.py index c4f9b15..4aa932d 100644 --- a/openwakeword/utils.py +++ b/openwakeword/utils.py @@ -23,6 +23,8 @@ import logging import openwakeword from typing import Union, List, Callable, Deque +import requests +from tqdm import tqdm # Base class for computing audio features using Google's speech_embedding @@ -526,6 +528,66 @@ def f(clips): return {list(i.keys())[0]: list(i.values())[0] for i in results} +# Function to download files from a URL with a progress bar +def download_file(url, target_directory, file_size=None): + """A simple function to download a file from a URL with a progress bar using only the requests library""" + local_filename = url.split('/')[-1] + + with requests.get(url, stream=True) as r: + if file_size is not None: + progress_bar = tqdm(total=file_size, unit='iB', unit_scale=True, desc=f"{local_filename}") + else: + total_size = int(r.headers.get('content-length', 0)) + progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"{local_filename}") + + with open(local_filename, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + progress_bar.update(len(chunk)) + + progress_bar.close() + + +# Function to download models from GitHub release assets +def download_models( + model_names: List[str] = [], + target_directory: str = os.path.join(pathlib.Path(__file__).parent.resolve(), "resources", "models") + ): + """ + Download the specified models from the release assets in the openWakeWord GitHub repository. + Uses the official urls in the MODELS dictionary in openwakeword/__init__.py. 
+ + Args: + model_names (List[str]): The names of the models to download (e.g., hey_jarvis_v0.1). Both ONNX and + tflite models will be downloaded. If not provided (the default), + the latest versions of all models will be downloaded. + target_directory (str): The directory to save the models to. Defaults to the install location + of openWakeWord (i.e., the `resources/models` directory). + Returns: + None + """ + + # Always download melspectrogram and embedding models, if they don't already exist + for feature_model in openwakeword.FEATURE_MODELS.values(): + if not os.path.exists(os.path.join(target_directory, feature_model["download_url"].split("/")[-1])): + download_file(feature_model["download_url"], target_directory) + download_file(feature_model["download_url"].replace(".tflite", ".onnx"), target_directory) + + # Get all model urls + official_model_urls = [i["download_url"] for i in openwakeword.MODELS.values()] + official_model_names = [i["download_url"].split("/")[-1] for i in openwakeword.MODELS.values()] + + if model_names != []: + for model_name in model_names: + url = [i for i, j in zip(official_model_urls, official_model_names) if model_name in j] + if url != []: + download_file(url[0], target_directory) + else: + for official_model_url in official_model_urls: + download_file(official_model_url, target_directory) + download_file(official_model_url.replace(".tflite", ".onnx"), target_directory) + + # Handle deprecated arguments and naming (thanks to https://stackoverflow.com/a/74564394) def re_arg(kwarg_map): def decorator(func): diff --git a/setup.py b/setup.py index 1c7aecb..1773a99 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,8 @@ def build_additional_requires(): 'tflite-runtime>=2.8.0,<3; platform_system == "Linux"', 'tqdm>=4.0,<5.0', 'scipy>=1.3,<2', - 'scikit-learn>=1,<2' + 'scikit-learn>=1,<2', + 'requests>=2.0,<3', ], extras_require={ 'test': [