diff --git a/openwakeword/custom_verifier_model.py b/openwakeword/custom_verifier_model.py index 8846624..1c8c55f 100644 --- a/openwakeword/custom_verifier_model.py +++ b/openwakeword/custom_verifier_model.py @@ -13,17 +13,19 @@ # limitations under the License. # Imports -import os -from tqdm import tqdm import collections -import openwakeword -import numpy as np -import scipy +import os import pickle +from typing import List, Union +import numpy as np +import scipy from sklearn.linear_model import LogisticRegression from sklearn.pipeline import make_pipeline from sklearn.preprocessing import FunctionTransformer, StandardScaler +from tqdm import tqdm + +import openwakeword # Define functions to prepare data for speaker dependent verifier model @@ -112,8 +114,8 @@ def train_verifier_model(features: np.ndarray, labels: np.ndarray): def train_custom_verifier( - positive_reference_clips: str, - negative_reference_clips: str, + positive_reference_clips: List[Union[str, os.PathLike]], + negative_reference_clips: List[Union[str, os.PathLike]], output_path: str, model_name: str, **kwargs @@ -123,11 +125,11 @@ def train_custom_verifier( from a single user. Args: - positive_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files + positive_reference_clips (List[Union[str, os.PathLike]]): The path(s) to single-channel 16khz, 16-bit WAV files of the target wake word/phrase. - negative_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files + negative_reference_clips (List[Union[str, os.PathLike]]): The path(s) to single-channel 16khz, 16-bit WAV files of miscellaneous speech not containing the target wake word/phrase. - output_path (str): The location to save the trained verifier model (as a scikit-learn .joblib file) + output_path (str): The location to save the trained verifier model (as a Python pickle file (.pkl)) model_name (str): The name or path of the trained openWakeWord model that the verifier model will be based on. If only a name, it must be one of the pre-trained models included in the openWakeWord release. @@ -171,4 +173,5 @@ def train_custom_verifier( # Save logistic regression model to specified output location print("Done!") - pickle.dump(lr_model, open(output_path, "wb")) + with open(output_path, "wb") as f: + pickle.dump(lr_model, f)