Skip to content

Commit

Permalink
fix: custom verifier type hints
Browse files Browse the repository at this point in the history
fix: be sure to close pickle file after opening
  • Loading branch information
mikejgray committed Jun 9, 2024
1 parent c40fe92 commit e5113c2
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions openwakeword/custom_verifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@
# limitations under the License.

# Imports
import os
from tqdm import tqdm
import collections
import openwakeword
import numpy as np
import scipy
import os
import pickle
from typing import List, Union

import numpy as np
import scipy
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, StandardScaler
from tqdm import tqdm

import openwakeword


# Define functions to prepare data for speaker dependent verifier model
Expand Down Expand Up @@ -112,8 +114,8 @@ def train_verifier_model(features: np.ndarray, labels: np.ndarray):


def train_custom_verifier(
positive_reference_clips: str,
negative_reference_clips: str,
positive_reference_clips: List[Union[str, os.PathLike]],
negative_reference_clips: List[Union[str, os.PathLike]],
output_path: str,
model_name: str,
**kwargs
Expand All @@ -123,11 +125,11 @@ def train_custom_verifier(
from a single user.
Args:
positive_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files
positive_reference_clips (List[Union[str, os.PathLike]]): The path(s) to single-channel 16khz, 16-bit WAV files
of the target wake word/phrase.
negative_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files
negative_reference_clips (List[Union[str, os.PathLike]]): The path(s) to single-channel 16khz, 16-bit WAV files
of miscellaneous speech not containing the target wake word/phrase.
output_path (str): The location to save the trained verifier model (as a scikit-learn .joblib file)
output_path (str): The location to save the trained verifier model (as a Python pickle file (.pkl))
model_name (str): The name or path of the trained openWakeWord model that the verifier model will be
based on. If only a name, it must be one of the pre-trained models included in the
openWakeWord release.
Expand Down Expand Up @@ -171,4 +173,5 @@ def train_custom_verifier(

# Save logistic regression model to specified output location
print("Done!")
pickle.dump(lr_model, open(output_path, "wb"))
with open(output_path, "wb") as f:
pickle.dump(lr_model, f)

0 comments on commit e5113c2

Please sign in to comment.