live_captions.py
"""Live captions from microphone using Moonshine and SileroVAD ONNX models."""
import argparse
import os
import time
from queue import Queue
import numpy as np
from silero_vad import VADIterator, load_silero_vad
from sounddevice import InputStream
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
SAMPLING_RATE = 16000
CHUNK_SIZE = 512 # Silero VAD requirement with sampling rate 16000.
LOOKBACK_CHUNKS = 5
MAX_LINE_LENGTH = 80
# These affect live caption updating - adjust for your platform speed and model.
MAX_SPEECH_SECS = 15
MIN_REFRESH_SECS = 0.2
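
# With these settings, each chunk is 512 / 16000 = 32 ms of audio, so the
# lookback buffer holds 5 * 32 = 160 ms of pre-speech context; keeping it
# makes it less likely that word onsets are clipped when the VAD triggers.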


class Transcriber(object):
    def __init__(self, model_name, rate=16000):
        if rate != 16000:
            raise ValueError("Moonshine supports sampling rate 16000 Hz.")
        self.model = MoonshineOnnxModel(model_name=model_name)
        self.rate = rate
        self.tokenizer = load_tokenizer()

        self.inference_secs = 0
        self.number_inferences = 0
        self.speech_secs = 0
        self.__call__(np.zeros(int(rate), dtype=np.float32))  # Warmup.

    def __call__(self, speech):
        """Returns string containing Moonshine transcription of speech."""
        self.number_inferences += 1
        self.speech_secs += len(speech) / self.rate
        start_time = time.time()

        tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
        text = self.tokenizer.decode_batch(tokens)[0]

        self.inference_secs += time.time() - start_time

        return text
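

# A minimal usage sketch outside this script, assuming `audio` is a 1-D
# float32 numpy array sampled at 16 kHz (`audio` is hypothetical, not
# defined here):
#
#   transcribe = Transcriber(model_name="moonshine/tiny")
#   text = transcribe(audio)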


def create_input_callback(q):
    """Callback method for sounddevice InputStream."""

    def input_callback(data, frames, time, status):
        if status:
            print(status)
        q.put((data.copy().flatten(), status))

    return input_callback
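

# sounddevice invokes input_callback on a separate audio thread, so the
# callback does nothing but copy the chunk onto the queue; all transcription
# work stays on the main thread.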


def end_recording(speech, do_print=True):
    """Transcribes, prints and caches the caption then clears speech buffer."""
    text = transcribe(speech)
    if do_print:
        print_captions(text)
    caption_cache.append(text)
    # Zero the buffer in place (rather than rebinding the name) so the
    # caller's array is cleared too.
    speech *= 0.0


def print_captions(text):
    """Prints right justified on same line, prepending cached captions."""
    if len(text) < MAX_LINE_LENGTH:
        for caption in caption_cache[::-1]:
            text = caption + " " + text
            if len(text) > MAX_LINE_LENGTH:
                break
    if len(text) > MAX_LINE_LENGTH:
        text = text[-MAX_LINE_LENGTH:]
    else:
        text = " " * (MAX_LINE_LENGTH - len(text)) + text
    # "\r" returns the cursor to the line start: blank the line, then rewrite
    # it so the caption updates in place.
    print("\r" + (" " * MAX_LINE_LENGTH) + "\r" + text, end="", flush=True)


def soft_reset(vad_iterator):
    """Soft resets Silero VADIterator without affecting VAD model state."""
    vad_iterator.triggered = False
    vad_iterator.temp_end = 0
    vad_iterator.current_sample = 0
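

# Note: silero_vad's VADIterator.reset_states() would also reset the Silero
# model's internal recurrent state; the soft reset above only rewinds the
# iterator's bookkeeping so detection continues cleanly after a forced flush.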


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="live_captions",
        description="Live captioning demo of Moonshine models",
    )
    parser.add_argument(
        "--model_name",
        help="Model to run the demo with",
        default="moonshine/base",
        choices=["moonshine/base", "moonshine/tiny"],
    )
    args = parser.parse_args()
    model_name = args.model_name

    print(f"Loading Moonshine model '{model_name}' (using ONNX runtime) ...")
    transcribe = Transcriber(model_name=model_name, rate=SAMPLING_RATE)

    vad_model = load_silero_vad(onnx=True)
    vad_iterator = VADIterator(
        model=vad_model,
        sampling_rate=SAMPLING_RATE,
        threshold=0.5,
        min_silence_duration_ms=300,
    )

    q = Queue()
    stream = InputStream(
        samplerate=SAMPLING_RATE,
        channels=1,
        blocksize=CHUNK_SIZE,
        dtype=np.float32,
        callback=create_input_callback(q),
    )
    stream.start()

    caption_cache = []
    lookback_size = LOOKBACK_CHUNKS * CHUNK_SIZE
    speech = np.empty(0, dtype=np.float32)
    recording = False

    print("Press Ctrl+C to quit live captions.\n")
    with stream:
        print_captions("Ready...")
        try:
            while True:
                chunk, status = q.get()
                if status:
                    print(status)

                speech = np.concatenate((speech, chunk))
                if not recording:
                    speech = speech[-lookback_size:]

                speech_dict = vad_iterator(chunk)
                if speech_dict:
                    if "start" in speech_dict and not recording:
                        recording = True
                        start_time = time.time()

                    if "end" in speech_dict and recording:
                        recording = False
                        end_recording(speech)
                elif recording:
                    # Possible speech truncation can cause hallucination.
                    if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
                        recording = False
                        end_recording(speech)
                        soft_reset(vad_iterator)

                    if (time.time() - start_time) > MIN_REFRESH_SECS:
                        print_captions(transcribe(speech))
                        start_time = time.time()
        except KeyboardInterrupt:
            stream.close()

            if recording:
                # Drain any remaining audio before the final transcription.
                while not q.empty():
                    chunk, _ = q.get()
                    speech = np.concatenate((speech, chunk))
                end_recording(speech, do_print=False)
            print(f"""

             model_name :  {model_name}
       MIN_REFRESH_SECS :  {MIN_REFRESH_SECS}s
      number inferences :  {transcribe.number_inferences}
    mean inference time :  {(transcribe.inference_secs / transcribe.number_inferences):.2f}s
  model realtime factor :  {(transcribe.speech_secs / transcribe.inference_secs):0.2f}x
""")
            if caption_cache:
                print(f"Cached captions.\n{' '.join(caption_cache)}")