diff --git a/moonshine/demo/live_captions.py b/moonshine/demo/live_captions.py index e8dcf39..91d14c6 100644 --- a/moonshine/demo/live_captions.py +++ b/moonshine/demo/live_captions.py @@ -23,15 +23,12 @@ CHUNK_SIZE = 512 # Silero VAD requirement with sampling rate 16000. LOOKBACK_CHUNKS = 5 -MARKER_LENGTH = 6 MAX_LINE_LENGTH = 80 # These affect live caption updating - adjust for your platform speed and model. MAX_SPEECH_SECS = 15 MIN_REFRESH_SECS = 0.2 -VERBOSE = False - class Transcriber(object): def __init__(self, model_name, rate=16000): @@ -39,9 +36,10 @@ def __init__(self, model_name, rate=16000): raise ValueError("Moonshine supports sampling rate 16000 Hz.") self.model = MoonshineOnnxModel(model_name=model_name) self.rate = rate - assets_dir = f"{os.path.join(os.path.dirname(__file__), '..', 'assets')}" - tokenizer_file = f"{assets_dir}{os.sep}tokenizer.json" - self.tokenizer = Tokenizer.from_file(str(tokenizer_file)) + tokenizer_path = os.path.join( + MOONSHINE_DEMO_DIR, "..", "assets", "tokenizer.json" + ) + self.tokenizer = Tokenizer.from_file(tokenizer_path) self.inference_secs = 0 self.number_inferences = 0 @@ -72,32 +70,35 @@ def input_callback(data, frames, time, status): return input_callback -def end_recording(speech, marker=""): +def end_recording(speech): """Transcribes, caches and prints the caption. Clears speech buffer.""" - if len(marker) != MARKER_LENGTH: - raise ValueError("Unexpected marker length.") text = transcribe(speech) - caption_cache.append(text + " " + marker) - print_captions(text + (" " + marker) if VERBOSE else "", True) + caption_cache.append(text) + print_captions(text, new_cached_caption=True) speech *= 0.0 def print_captions(text, new_cached_caption=False): """Prints right justified on same line, prepending cached captions.""" print("\r" + " " * MAX_LINE_LENGTH, end="", flush=True) - if len(text) > MAX_LINE_LENGTH: - text = text[-MAX_LINE_LENGTH:] - elif text != "\n": + if len(text) < MAX_LINE_LENGTH: for caption in caption_cache[::-1]: - text = (caption[:-MARKER_LENGTH] if not VERBOSE else caption + " ") + text + text = caption + " " + text if len(text) > MAX_LINE_LENGTH: break - if len(text) > MAX_LINE_LENGTH: - text = text[-MAX_LINE_LENGTH:] + if len(text) > MAX_LINE_LENGTH: + text = text[-MAX_LINE_LENGTH:] text = " " * (MAX_LINE_LENGTH - len(text)) + text print("\r" + text, end="", flush=True) +def soft_reset(vad_iterator): + """Soft resets Silero VADIterator without affecting VAD model state.""" + vad_iterator.triggered = False + vad_iterator.temp_end = 0 + vad_iterator.current_sample = 0 + + if __name__ == "__main__": parser = argparse.ArgumentParser( prog="live_captions", @@ -145,7 +146,7 @@ def print_captions(text, new_cached_caption=False): try: while True: chunk, status = q.get() - if VERBOSE and status: + if status: print(status) speech = np.concatenate((speech, chunk)) @@ -160,18 +161,15 @@ def print_captions(text, new_cached_caption=False): if "end" in speech_dict and recording: recording = False - end_recording(speech, "") + end_recording(speech) elif recording: # Possible speech truncation can cause hallucination. if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS: recording = False - end_recording(speech, "") - # Soft reset without affecting VAD model state. - vad_iterator.triggered = False - vad_iterator.temp_end = 0 - vad_iterator.current_sample = 0 + end_recording(speech) + soft_reset(vad_iterator) if (time.time() - start_time) > MIN_REFRESH_SECS: print_captions(transcribe(speech)) @@ -184,7 +182,7 @@ def print_captions(text, new_cached_caption=False): while not q.empty(): chunk, _ = q.get() speech = np.concatenate((speech, chunk)) - end_recording(speech, "") + end_recording(speech) print(f""" @@ -196,7 +194,4 @@ def print_captions(text, new_cached_caption=False): model realtime factor : {(transcribe.speech_secs / transcribe.inference_secs):0.2f}x """) if caption_cache: - print("Cached captions.") - for caption in caption_cache: - print(caption[:-MARKER_LENGTH], end="", flush=True) - print("") + print(f"Cached captions.\n{' '.join(caption_cache)}")