Skip to content

Commit

Permalink
Refactor live captions demo
Browse files Browse the repository at this point in the history
Fix printing on exit to mitigate some cached captions not printed.
Reduce verbosity.
  • Loading branch information
guynich committed Nov 1, 2024
1 parent 1d983bf commit e1fdeec
Showing 1 changed file with 24 additions and 29 deletions.
53 changes: 24 additions & 29 deletions moonshine/demo/live_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,23 @@

CHUNK_SIZE = 512 # Silero VAD requirement with sampling rate 16000.
LOOKBACK_CHUNKS = 5
MARKER_LENGTH = 6
MAX_LINE_LENGTH = 80

# These affect live caption updating - adjust for your platform speed and model.
MAX_SPEECH_SECS = 15
MIN_REFRESH_SECS = 0.2

VERBOSE = False


class Transcriber(object):
def __init__(self, model_name, rate=16000):
if rate != 16000:
raise ValueError("Moonshine supports sampling rate 16000 Hz.")
self.model = MoonshineOnnxModel(model_name=model_name)
self.rate = rate
assets_dir = f"{os.path.join(os.path.dirname(__file__), '..', 'assets')}"
tokenizer_file = f"{assets_dir}{os.sep}tokenizer.json"
self.tokenizer = Tokenizer.from_file(str(tokenizer_file))
tokenizer_path = (
os.path.join(MOONSHINE_DEMO_DIR, '..', 'assets', 'tokenizer.json')
)
self.tokenizer = Tokenizer.from_file(tokenizer_path)

self.inference_secs = 0
self.number_inferences = 0
Expand Down Expand Up @@ -72,32 +70,35 @@ def input_callback(data, frames, time, status):
return input_callback


def end_recording(speech, marker=""):
def end_recording(speech):
"""Transcribes, caches and prints the caption. Clears speech buffer."""
if len(marker) != MARKER_LENGTH:
raise ValueError("Unexpected marker length.")
text = transcribe(speech)
caption_cache.append(text + " " + marker)
print_captions(text + (" " + marker) if VERBOSE else "", True)
caption_cache.append(text)
print_captions(text, new_cached_caption=True)
speech *= 0.0


def print_captions(text, new_cached_caption=False):
"""Prints right justified on same line, prepending cached captions."""
print("\r" + " " * MAX_LINE_LENGTH, end="", flush=True)
if len(text) > MAX_LINE_LENGTH:
text = text[-MAX_LINE_LENGTH:]
elif text != "\n":
if len(text) < MAX_LINE_LENGTH:
for caption in caption_cache[::-1]:
text = (caption[:-MARKER_LENGTH] if not VERBOSE else caption + " ") + text
text = caption + " " + text
if len(text) > MAX_LINE_LENGTH:
break
if len(text) > MAX_LINE_LENGTH:
text = text[-MAX_LINE_LENGTH:]
if len(text) > MAX_LINE_LENGTH:
text = text[-MAX_LINE_LENGTH:]
text = " " * (MAX_LINE_LENGTH - len(text)) + text
print("\r" + text, end="", flush=True)


def soft_reset(vad_iterator):
"""Soft resets Silero VADIterator without affecting VAD model state."""
vad_iterator.triggered = False
vad_iterator.temp_end = 0
vad_iterator.current_sample = 0


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="live_captions",
Expand Down Expand Up @@ -145,7 +146,7 @@ def print_captions(text, new_cached_caption=False):
try:
while True:
chunk, status = q.get()
if VERBOSE and status:
if status:
print(status)

speech = np.concatenate((speech, chunk))
Expand All @@ -160,18 +161,15 @@ def print_captions(text, new_cached_caption=False):

if "end" in speech_dict and recording:
recording = False
end_recording(speech, "<STOP>")
end_recording(speech)

elif recording:
# Possible speech truncation can cause hallucination.

if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
recording = False
end_recording(speech, "<SNIP>")
# Soft reset without affecting VAD model state.
vad_iterator.triggered = False
vad_iterator.temp_end = 0
vad_iterator.current_sample = 0
end_recording(speech)
soft_reset(vad_iterator)

if (time.time() - start_time) > MIN_REFRESH_SECS:
print_captions(transcribe(speech))
Expand All @@ -184,7 +182,7 @@ def print_captions(text, new_cached_caption=False):
while not q.empty():
chunk, _ = q.get()
speech = np.concatenate((speech, chunk))
end_recording(speech, "<END.>")
end_recording(speech)

print(f"""
Expand All @@ -196,7 +194,4 @@ def print_captions(text, new_cached_caption=False):
model realtime factor : {(transcribe.speech_secs / transcribe.inference_secs):0.2f}x
""")
if caption_cache:
print("Cached captions.")
for caption in caption_cache:
print(caption[:-MARKER_LENGTH], end="", flush=True)
print("")
print(f"Cached captions.\n{' '.join(caption_cache)}")

0 comments on commit e1fdeec

Please sign in to comment.