Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor live captions demo #49

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 24 additions & 29 deletions moonshine/demo/live_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,23 @@

CHUNK_SIZE = 512 # Silero VAD requirement with sampling rate 16000.
LOOKBACK_CHUNKS = 5
MARKER_LENGTH = 6
MAX_LINE_LENGTH = 80

# These affect live caption updating - adjust for your platform speed and model.
MAX_SPEECH_SECS = 15
MIN_REFRESH_SECS = 0.2

VERBOSE = False


class Transcriber(object):
def __init__(self, model_name, rate=16000):
if rate != 16000:
raise ValueError("Moonshine supports sampling rate 16000 Hz.")
self.model = MoonshineOnnxModel(model_name=model_name)
self.rate = rate
assets_dir = f"{os.path.join(os.path.dirname(__file__), '..', 'assets')}"
tokenizer_file = f"{assets_dir}{os.sep}tokenizer.json"
self.tokenizer = Tokenizer.from_file(str(tokenizer_file))
tokenizer_path = os.path.join(
MOONSHINE_DEMO_DIR, "..", "assets", "tokenizer.json"
)
self.tokenizer = Tokenizer.from_file(tokenizer_path)

self.inference_secs = 0
self.number_inferences = 0
Expand Down Expand Up @@ -72,32 +70,35 @@ def input_callback(data, frames, time, status):
return input_callback


def end_recording(speech, marker=""):
def end_recording(speech):
"""Transcribes, caches and prints the caption. Clears speech buffer."""
if len(marker) != MARKER_LENGTH:
raise ValueError("Unexpected marker length.")
text = transcribe(speech)
caption_cache.append(text + " " + marker)
print_captions(text + (" " + marker) if VERBOSE else "", True)
caption_cache.append(text)
print_captions(text, new_cached_caption=True)
speech *= 0.0


def print_captions(text, new_cached_caption=False):
"""Prints right justified on same line, prepending cached captions."""
print("\r" + " " * MAX_LINE_LENGTH, end="", flush=True)
if len(text) > MAX_LINE_LENGTH:
text = text[-MAX_LINE_LENGTH:]
elif text != "\n":
if len(text) < MAX_LINE_LENGTH:
for caption in caption_cache[::-1]:
text = (caption[:-MARKER_LENGTH] if not VERBOSE else caption + " ") + text
text = caption + " " + text
if len(text) > MAX_LINE_LENGTH:
break
if len(text) > MAX_LINE_LENGTH:
text = text[-MAX_LINE_LENGTH:]
if len(text) > MAX_LINE_LENGTH:
text = text[-MAX_LINE_LENGTH:]
text = " " * (MAX_LINE_LENGTH - len(text)) + text
print("\r" + text, end="", flush=True)


def soft_reset(vad_iterator):
    """Soft resets Silero VADIterator without affecting VAD model state.

    Only the iterator's bookkeeping attributes are cleared, in place:
    ``triggered`` is set to False, and ``temp_end`` and ``current_sample``
    are set to 0. No other attribute of the object is touched.

    Args:
        vad_iterator: Object exposing the ``triggered``, ``temp_end`` and
            ``current_sample`` attributes.
    """
    vad_iterator.triggered = False
    vad_iterator.temp_end = 0
    vad_iterator.current_sample = 0


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="live_captions",
Expand Down Expand Up @@ -145,7 +146,7 @@ def print_captions(text, new_cached_caption=False):
try:
while True:
chunk, status = q.get()
if VERBOSE and status:
if status:
print(status)

speech = np.concatenate((speech, chunk))
Expand All @@ -160,18 +161,15 @@ def print_captions(text, new_cached_caption=False):

if "end" in speech_dict and recording:
recording = False
end_recording(speech, "<STOP>")
end_recording(speech)

elif recording:
# Possible speech truncation can cause hallucination.

if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
recording = False
end_recording(speech, "<SNIP>")
# Soft reset without affecting VAD model state.
vad_iterator.triggered = False
vad_iterator.temp_end = 0
vad_iterator.current_sample = 0
end_recording(speech)
soft_reset(vad_iterator)

if (time.time() - start_time) > MIN_REFRESH_SECS:
print_captions(transcribe(speech))
Expand All @@ -184,7 +182,7 @@ def print_captions(text, new_cached_caption=False):
while not q.empty():
chunk, _ = q.get()
speech = np.concatenate((speech, chunk))
end_recording(speech, "<END.>")
end_recording(speech)

print(f"""

Expand All @@ -196,7 +194,4 @@ def print_captions(text, new_cached_caption=False):
model realtime factor : {(transcribe.speech_secs / transcribe.inference_secs):0.2f}x
""")
if caption_cache:
print("Cached captions.")
for caption in caption_cache:
print(caption[:-MARKER_LENGTH], end="", flush=True)
print("")
print(f"Cached captions.\n{' '.join(caption_cache)}")