LocalWhisperTranscriber (v2) #4909
Merged · 20 commits · May 22, 2023
25 changes: 25 additions & 0 deletions e2e/preview/components/test_whisper_local.py
@@ -0,0 +1,25 @@
from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber


def test_whisper_local_transcriber(preview_samples_path):
    comp = LocalWhisperTranscriber(model_name_or_path="tiny")
    docs = comp.transcribe(
        audio_files=[
            preview_samples_path / "audio" / "this is the content of the document.wav",
            str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute()),
            open(preview_samples_path / "audio" / "answer.wav", "rb"),
        ]
    )
    assert len(docs) == 3

    assert "this is the content of the document." == docs[0].content.strip().lower()
    assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

    assert "the context for this answer is here." == docs[1].content.strip().lower()
    assert (
        str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
        == docs[1].metadata["audio_file"]
    )

    assert "answer." == docs[2].content.strip().lower()
    assert "<<binary stream>>" == docs[2].metadata["audio_file"]
1 change: 1 addition & 0 deletions haystack/preview/components/__init__.py
@@ -1 +1,2 @@
from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
Empty file.
129 changes: 129 additions & 0 deletions haystack/preview/components/audio/whisper_local.py
@@ -0,0 +1,129 @@
from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence

import logging
from pathlib import Path
from dataclasses import dataclass

import torch
import whisper

from haystack.preview import component, ComponentInput, ComponentOutput, Document


logger = logging.getLogger(__name__)
WhisperLocalModel = Literal["tiny", "small", "medium", "large", "large-v2"]


@component
class LocalWhisperTranscriber:
"""
Transcribes audio files using OpenAI's Whisper's model on your local machine.

For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[github repo](https://github.com/openai/whisper).
"""

    @dataclass
    class Input(ComponentInput):
        audio_files: List[Path]
        whisper_params: Optional[Dict[str, Any]] = None

    @dataclass
    class Output(ComponentOutput):
        documents: List[Document]

    def __init__(self, model_name_or_path: WhisperLocalModel = "large", device: Optional[str] = None):
        """
        :param model_name_or_path: Name of the model to use. Set it to one of the following values:
            - `tiny`
            - `small`
            - `medium`
            - `large`
            - `large-v2`
        :param device: Name of the torch device to use for inference. If None, CPU is used.
        """
        if model_name_or_path not in get_args(WhisperLocalModel):
            raise ValueError(
                f"Model name '{model_name_or_path}' not recognized. Choose one among: "
                f"{', '.join(get_args(WhisperLocalModel))}."
            )
        self.model_name = model_name_or_path
        self.device = torch.device(device) if device else torch.device("cpu")
        self._model = None

    def warm_up(self) -> None:
        """
        Loads the model.
        """
        if not self._model:
            self._model = whisper.load_model(self.model_name, device=self.device)

    def run(self, data: Input) -> Output:
        """
        Transcribe the audio files into a list of Documents, one for each input file.

        For the supported audio formats, languages, and other parameters, see the
        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
        [GitHub repo](https://github.com/openai/whisper).

        :param data: LocalWhisperTranscriber.Input with the paths or binary streams to transcribe and any
            optional Whisper parameters.
        :returns: LocalWhisperTranscriber.Output wrapping a list of Documents, one for each file. The content of
            each document is the transcription text, while the document's metadata contains all the other values
            returned by the Whisper model, such as the alignment data. Another key called `audio_file` contains
            the path to the audio file used for the transcription.
        """
        if not data.whisper_params:
            data.whisper_params = {}
        documents = self.transcribe(data.audio_files, **data.whisper_params)
        return LocalWhisperTranscriber.Output(documents)

    def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
        """
        Transcribe the audio files into a list of Documents, one for each input file.

        For the supported audio formats, languages, and other parameters, see the
        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
        [GitHub repo](https://github.com/openai/whisper).

        :param audio_files: a list of paths or binary streams to transcribe
        :returns: a list of Documents, one for each file. The content of the document is the transcription text,
            while the document's metadata contains all the other values returned by the Whisper model, such as the
            alignment data. Another key called `audio_file` contains the path to the audio file used for the
            transcription.
        """
        transcriptions = self._raw_transcribe(audio_files=audio_files, **kwargs)
        documents = []
        for audio, transcript in zip(audio_files, transcriptions):
            content = transcript.pop("text")
            if not isinstance(audio, (str, Path)):
                audio = "<<binary stream>>"
            doc = Document(content=content, metadata={"audio_file": audio, **transcript})
            documents.append(doc)
        return documents

    def _raw_transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]:
        """
        Transcribe the given audio files. Returns the output of the model, a dictionary, for each input file.

        For the supported audio formats, languages, and other parameters, see the
        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
        [GitHub repo](https://github.com/openai/whisper).

        :param audio_files: a list of paths or binary streams to transcribe
        :returns: a list of transcriptions.
        """
        self.warm_up()
        return_segments = kwargs.pop("return_segments", False)
        transcriptions = []
        for audio_file in audio_files:
            if isinstance(audio_file, (str, Path)):
                audio_file = open(audio_file, "rb")

            # mypy complains that _model is not guaranteed to be not None. It is: check self.warm_up()
            transcription = self._model.transcribe(audio_file.name, **kwargs)  # type: ignore
            if not return_segments:
                transcription.pop("segments", None)
            transcriptions.append(transcription)

        return transcriptions
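
For context, here is a minimal usage sketch of the component added in this file, based only on the API shown in the diff above; the file name "sample.wav" is a placeholder and both the model size and device are illustrative choices.

from pathlib import Path

from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber

# Build the component; the Whisper model is only loaded in warm_up(), not in __init__.
transcriber = LocalWhisperTranscriber(model_name_or_path="tiny", device="cpu")
transcriber.warm_up()

# Either call transcribe() directly with paths or binary streams...
docs = transcriber.transcribe(audio_files=[Path("sample.wav")])

# ...or go through the Input/Output dataclasses used when the component runs in a pipeline.
output = transcriber.run(LocalWhisperTranscriber.Input(audio_files=[Path("sample.wav")]))
print(output.documents[0].content)                    # transcription text
print(output.documents[0].metadata["audio_file"])     # path (or "<<binary stream>>") of the source audio

Since run() delegates to transcribe(), both entry points should produce the same Documents.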
132 changes: 132 additions & 0 deletions test/preview/components/audio/test_whisper_local.py
@@ -0,0 +1,132 @@
from pathlib import Path
from unittest.mock import patch, MagicMock

import pytest
import torch

from haystack.preview.dataclasses import Document
from haystack.preview.components import LocalWhisperTranscriber

from test.preview.components.base import BaseTestComponent


SAMPLES_PATH = Path(__file__).parent.parent.parent / "test_files"


class Test_LocalWhisperTranscriber(BaseTestComponent):
    @pytest.mark.unit
    def test_save_load(self, tmp_path):
        self.assert_can_be_saved_and_loaded_in_pipeline(
            LocalWhisperTranscriber(model_name_or_path="large-v2"), tmp_path
        )

    @pytest.mark.unit
    def test_init(self):
        transcriber = LocalWhisperTranscriber(
            model_name_or_path="large-v2"
        )  # Doesn't matter if it's huge, the model is not loaded in init.
        assert transcriber.model_name == "large-v2"
        assert transcriber.device == torch.device("cpu")
        assert transcriber._model is None

    @pytest.mark.unit
    def test_init_wrong_model(self):
        with pytest.raises(ValueError, match="Model name 'whisper-1' not recognized"):
            LocalWhisperTranscriber(model_name_or_path="whisper-1")

    @pytest.mark.unit
    def test_warmup(self):
        with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper:
            transcriber = LocalWhisperTranscriber(model_name_or_path="large-v2")
            mocked_whisper.load_model.assert_not_called()
            transcriber.warm_up()
            mocked_whisper.load_model.assert_called_once_with("large-v2", device=torch.device(type="cpu"))

    @pytest.mark.unit
    def test_warmup_doesnt_reload(self):
        with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper:
            transcriber = LocalWhisperTranscriber(model_name_or_path="large-v2")
            transcriber.warm_up()
            transcriber.warm_up()
            mocked_whisper.load_model.assert_called_once()

    @pytest.mark.unit
    def test_run_with_path(self):
        comp = LocalWhisperTranscriber(model_name_or_path="large-v2")
        comp._model = MagicMock()
        comp._model.transcribe.return_value = {
            "text": "test transcription",
            "other_metadata": ["other", "meta", "data"],
        }
        results = comp.run(
            LocalWhisperTranscriber.Input(
                audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"]
            )
        )
        expected = Document(
            content="test transcription",
            metadata={
                "audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav",
                "other_metadata": ["other", "meta", "data"],
            },
        )
        assert isinstance(results, LocalWhisperTranscriber.Output)
        assert results.documents == [expected]

    @pytest.mark.unit
    def test_run_with_str(self):
        comp = LocalWhisperTranscriber(model_name_or_path="large-v2")
        comp._model = MagicMock()
        comp._model.transcribe.return_value = {
            "text": "test transcription",
            "other_metadata": ["other", "meta", "data"],
        }
        results = comp.run(
            LocalWhisperTranscriber.Input(
                audio_files=[str((SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute())]
            )
        )
        expected = Document(
            content="test transcription",
            metadata={
                "audio_file": str((SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute()),
                "other_metadata": ["other", "meta", "data"],
            },
        )
        assert isinstance(results, LocalWhisperTranscriber.Output)
        assert results.documents == [expected]

    @pytest.mark.unit
    def test_transcribe(self):
        comp = LocalWhisperTranscriber(model_name_or_path="large-v2")
        comp._model = MagicMock()
        comp._model.transcribe.return_value = {
            "text": "test transcription",
            "other_metadata": ["other", "meta", "data"],
        }
        results = comp.transcribe(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
        expected = Document(
            content="test transcription",
            metadata={
                "audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav",
                "other_metadata": ["other", "meta", "data"],
            },
        )
        assert results == [expected]

    @pytest.mark.unit
    def test_transcribe_stream(self):
        comp = LocalWhisperTranscriber(model_name_or_path="large-v2")
        comp._model = MagicMock()
        comp._model.transcribe.return_value = {
            "text": "test transcription",
            "other_metadata": ["other", "meta", "data"],
        }
        results = comp.transcribe(
            audio_files=[open(SAMPLES_PATH / "audio" / "this is the content of the document.wav", "rb")]
        )
        expected = Document(
            content="test transcription",
            metadata={"audio_file": "<<binary stream>>", "other_metadata": ["other", "meta", "data"]},
        )
        assert results == [expected]