
feat: WhisperTranscriber for v2 #4723

Closed
wants to merge 17 commits
19 changes: 19 additions & 0 deletions e2e/preview/components/test_transcriber.py
@@ -0,0 +1,19 @@
import os
from pathlib import Path

from haystack.preview.components import WhisperTranscriber


SAMPLES_PATH = Path(__file__).parent.parent / "test_files"


def test_raw_transcribe_local():
comp = WhisperTranscriber(model_name_or_path="tiny")
output = comp.transcribe(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
assert "this is the content of the document" in output[0]["text"].lower()


def test_raw_transcribe_remote():
comp = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
output = comp.transcribe(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
assert "this is the content of the document" in output[0]["text"].lower()
1 change: 1 addition & 0 deletions haystack/preview/components/__init__.py
@@ -0,0 +1 @@
from haystack.preview.components.audio.transcriber import WhisperTranscriber
1 change: 1 addition & 0 deletions haystack/preview/components/audio/__init__.py
@@ -0,0 +1 @@
from haystack.preview.components.audio.transcriber import WhisperTranscriber
209 changes: 209 additions & 0 deletions haystack/preview/components/audio/transcriber.py
@@ -0,0 +1,209 @@
from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence

import os
import json
import logging
from pathlib import Path
from dataclasses import dataclass

import requests
import torch
import whisper
from tenacity import retry, wait_exponential, retry_if_not_result

from haystack.preview import component, Document
from haystack.errors import OpenAIError, OpenAIRateLimitError
from haystack import is_imported


logger = logging.getLogger(__name__)


OPENAI_TIMEOUT = float(os.environ.get("HAYSTACK_OPENAI_TIMEOUT_SEC", 30))
Reviewer comment (Contributor): This is never used.
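A possible follow-up, sketched here as an assumption rather than anything in this PR, would be to pass the constant to the API call instead of the hardcoded timeout=600 used further down:

# Hypothetical fix, not part of this PR: honor the env-configurable timeout.
response = requests.post(
    url="https://api.openai.com/v1/audio/transcriptions",
    data={"model": "whisper-1"},
    headers={"Authorization": f"Bearer {self.api_key}"},
    files=[("file", (audio_file.name, audio_file, "application/octet-stream"))],
    timeout=OPENAI_TIMEOUT,
)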



WhisperLocalModel = Literal["tiny", "small", "medium", "large", "large-v2"]
WhisperRemoteModel = Literal["whisper-1"]
WhisperModel = Union[WhisperLocalModel, WhisperRemoteModel]


@component
class WhisperTranscriber:
Reviewer comment (Contributor): What do you think of splitting this into two different components? One local only and one remote, something like WhisperLocalTranscriber and WhisperOpenAITranscriber. 🤔

Reply (Contributor): I think it would simplify the logic quite a bit, and also the testing. If we go this way, I think it would be better to open two separate PRs and ditch this one.
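A minimal sketch of how that split might look (the class names come from the reviewer's suggestion; everything else is assumed, not code from this PR):

# Hypothetical sketch of the proposed two-component split; not part of this PR.
from dataclasses import dataclass
from typing import List

from haystack.preview import Document, component


@component
class WhisperLocalTranscriber:
    """Transcribes audio using only a locally installed Whisper model."""

    @dataclass
    class Output:
        documents: List[Document]

    def __init__(self, model_name: str = "tiny", device: str = "cpu"):
        self.model_name = model_name
        self.device = device
        self._model = None  # loaded lazily in warm_up()

    def warm_up(self):
        import whisper  # the local-only dependency stays isolated here

        if not self._model:
            self._model = whisper.load_model(self.model_name, device=self.device)


@component
class WhisperOpenAITranscriber:
    """Transcribes audio through the OpenAI Whisper API only."""

    @dataclass
    class Output:
        documents: List[Document]

    def __init__(self, api_key: str, model_name: str = "whisper-1"):
        if not api_key:
            raise ValueError("Provide a valid OpenAI API key.")
        self.api_key = api_key
        self.model_name = model_name

Each class keeps a single code path, which is what would make both the logic and the tests simpler.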

"""
Transcribes audio files using OpenAI's Whisper. This class supports two underlying implementations:

- API (default): Uses the OpenAI API and requires an API key. See the
[OpenAI blog post](https://beta.openai.com/docs/api-reference/whisper) for more details.

- Local (requires installing Whisper): Uses the local installation
of [Whisper](https://github.com/openai/whisper).

To use Whisper locally, install it following the instructions on the Whisper
[GitHub repo](https://github.com/openai/whisper) and omit the `api_key` parameter.

To use the API implementation, provide an API key. You can get one by signing up for an
[OpenAI account](https://beta.openai.com/).

For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[GitHub repo](https://github.com/openai/whisper).
"""

@dataclass
class Output:
documents: List[Document]

def __init__(
self,
model_name_or_path: WhisperModel = "whisper-1",
api_key: Optional[str] = None,
device: Optional[str] = None,
):
"""
Transcribes a list of audio files into a list of Documents.

:param api_key: OpenAI API key. If None, a local installation of Whisper is used.
:param model_name_or_path: Name of the model to use. If using a local installation of Whisper, set this to one
of the following values:
- `tiny`
- `small`
- `medium`
- `large`
- `large-v2`
If using the API, set this value to:
- `whisper-1` (default)
:param device: Device to use for inference. Only used if you're using a local installation of Whisper.
If None, CPU is used.
"""
if model_name_or_path not in (get_args(WhisperRemoteModel) + get_args(WhisperLocalModel)):
raise ValueError(
f"Model name not recognized. Choose one among: "
f"{', '.join(get_args(WhisperRemoteModel) + get_args(WhisperLocalModel))}."
)

if model_name_or_path in get_args(WhisperRemoteModel) and not api_key:
raise ValueError(
"Provide a valid API key for OpenAI API. Alternatively, install OpenAI Whisper (see "
"[Whisper](https://github.com/openai/whisper) for more details) "
f"and select a model size among: {', '.join(get_args(WhisperLocalModel))}"
)

if model_name_or_path in get_args(WhisperLocalModel) and not is_imported("whisper"):
raise ValueError(
"To use a local Whisper model, install Haystack's audio extras as `pip install farm-haystack[audio]` "
"or install Whisper yourself with `pip install openai-whisper`. You will need ffmpeg on your system "
"in either case, see: https://github.com/openai/whisper."
)

if model_name_or_path in get_args(WhisperLocalModel) and api_key:
logger.warning(
"An API Key was provided, but a local model was selected. "
"WhisperTranscriber will try to use the local model."
)

self.api_key = api_key
self.model_name = model_name_or_path
self.use_local_whisper = model_name_or_path in get_args(WhisperLocalModel)

if self.use_local_whisper:
self.device = torch.device(device) if device else torch.device("cpu")

self._model = None
if not self.use_local_whisper and api_key is None:
raise ValueError(
"Provide a valid API key for OpenAI API. Alternatively, install OpenAI Whisper (see "
"[Whisper](https://github.com/openai/whisper) for more details)."
)

def warm_up(self):
"""
If we're using a local model, load it here.
"""
if self.use_local_whisper and not self._model:
self._model = whisper.load_model(self.model_name, device=self.device)

def run(self, audios: List[Path], whisper_params: Dict[str, Any]) -> Output:
documents = self.transcribe_to_documents(audios, **whisper_params)
return WhisperTranscriber.Output(documents)

def transcribe_to_documents(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
"""
Transcribe the given audio files. Returns a list of Documents.

For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[GitHub repo](https://github.com/openai/whisper).

:param audio_files: a list of paths or binary streams to transcribe
:returns: a list of transcriptions.
"""
transcriptions = self.transcribe(audio_files=audio_files, **kwargs)
return [
Document(content=transcript.pop("text"), metadata={"audio_file": audio, **transcript})
for audio, transcript in zip(audio_files, transcriptions)
]

def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]:
"""
Transcribe the given audio files. Returns a list of transcription dictionaries.

For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[GitHub repo](https://github.com/openai/whisper).

:param audio_files: a list of paths or binary streams to transcribe
:returns: a list of transcriptions.
"""
transcriptions = []
for audio_file in audio_files:
if isinstance(audio_file, (str, Path)):
audio_file = open(audio_file, "rb")

if self.use_local_whisper:
transcription = self._transcribe_locally(audio_file, **kwargs)
else:
transcription = self._transcribe_with_api(audio_file, **kwargs)

transcriptions.append(transcription)
return transcriptions

@retry(retry=retry_if_not_result(bool), wait=wait_exponential(min=1, max=10))
Reviewer comment (Contributor): I think we should use the request_with_retry util method in here.
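A hedged sketch of what that could look like, assuming a request_with_retry helper in haystack.utils that forwards its keyword arguments to requests.request (the exact import path and signature are an assumption):

# Hypothetical rewrite using the suggested helper; signature assumed.
from haystack.utils import request_with_retry

response = request_with_retry(
    method="POST",
    url=f"https://api.openai.com/v1/audio/{'translations' if translate else 'transcriptions'}",
    data={"model": "whisper-1", **kwargs},
    headers={"Authorization": f"Bearer {self.api_key}"},
    files=[("file", (audio_file.name, audio_file, "application/octet-stream"))],
    timeout=600,
)

If the helper raises on HTTP errors itself, this would also replace both the tenacity decorator and the manual status-code handling below.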

def _transcribe_with_api(self, audio_file: BinaryIO, **kwargs) -> Dict[str, Any]:
"""
Calls a remote Whisper model through OpenAI Whisper API.
"""
translate = kwargs.pop("translate", False)

response = requests.post(
url=f"https://api.openai.com/v1/audio/{'translations' if translate else 'transcriptions'}",
data={"model": "whisper-1", **kwargs},
headers={"Authorization": f"Bearer {self.api_key}"},
files=[("file", (audio_file.name, audio_file, "application/octet-stream"))],
timeout=600,
)

if response.status_code != 200:
if response.status_code == 429:
raise OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
raise OpenAIError(
f"OpenAI returned an error.\n"
f"Status code: {response.status_code}\n"
f"Response body: {response.text}",
status_code=response.status_code,
)

return json.loads(response.content)

def _transcribe_locally(self, audio_file: BinaryIO, **kwargs) -> Dict[str, Any]:
"""
Calls a local Whisper model.
"""
if not self._model:
self.warm_up()
if not self._model:
raise ValueError("WhisperTranscriber._transcribe_locally() can't work without a local model.")
return_segments = kwargs.pop("return_segments", None)
transcription = self._model.transcribe(audio_file.name, **kwargs)
if not return_segments:
transcription.pop("segments", None)
return transcription
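Taken together with the docstrings above, the component supports two modes. A minimal usage sketch (the file path and API key are placeholders):

from pathlib import Path

from haystack.preview.components import WhisperTranscriber

# Remote mode: transcription happens through the OpenAI API.
remote = WhisperTranscriber(api_key="sk-...")  # placeholder key
docs = remote.transcribe_to_documents(audio_files=[Path("sample.wav")])

# Local mode: omit the API key and pick a local model size.
local = WhisperTranscriber(model_name_or_path="tiny")
local.warm_up()  # loads the model up front; otherwise it loads on first use
docs = local.transcribe_to_documents(audio_files=[Path("sample.wav")])
print(docs[0].content)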
1 change: 1 addition & 0 deletions test/preview/components/__init__.py
@@ -0,0 +1 @@
from haystack.preview.components.audio.transcriber import WhisperTranscriber
16 changes: 16 additions & 0 deletions test/preview/components/test_component_base.py
@@ -0,0 +1,16 @@
from unittest.mock import MagicMock

import pytest
from canals.testing import BaseTestComponent as CanalsBaseTestComponent


class BaseTestComponent(CanalsBaseTestComponent):
"""
Base tests for Haystack components.
"""

@pytest.fixture
def request_mock(self, monkeypatch):
request_mock = MagicMock()
monkeypatch.setattr("requests.request", MagicMock())
return request_mock
138 changes: 138 additions & 0 deletions test/preview/components/test_transcriber_whisper.py
@@ -0,0 +1,138 @@
import os
import sys
from pathlib import Path
from unittest.mock import MagicMock

import pytest
import torch
import whisper
from generalimport import FakeModule

from haystack.preview.dataclasses import Document
from haystack.preview.components import WhisperTranscriber

from test.preview.components.test_component_base import BaseTestComponent


SAMPLES_PATH = Path(__file__).parent.parent / "test_files"


class TestTranscriber(BaseTestComponent):
"""
Tests for WhisperTranscriber.
"""

@pytest.fixture
def components(self):
return [WhisperTranscriber(api_key="just a test"), WhisperTranscriber(model_name_or_path="large-v2")]

@pytest.fixture
def mock_models(self, monkeypatch):
def mock_transcribe(_, audio_file, **kwargs):
return {
"text": "test transcription",
"other_metadata": ["other", "meta", "data"],
"kwargs received": kwargs,
}

monkeypatch.setattr(WhisperTranscriber, "_transcribe_with_api", mock_transcribe)
monkeypatch.setattr(WhisperTranscriber, "_transcribe_locally", mock_transcribe)
monkeypatch.setattr(WhisperTranscriber, "warm_up", lambda self: None)

@pytest.mark.unit
def test_init_remote_unknown_model(self):
with pytest.raises(ValueError, match="not recognized"):
WhisperTranscriber(model_name_or_path="anything")

@pytest.mark.unit
def test_init_default_remote_missing_key(self):
with pytest.raises(ValueError, match="API key"):
WhisperTranscriber()

@pytest.mark.unit
def test_init_explicit_remote_missing_key(self):
with pytest.raises(ValueError, match="API key"):
WhisperTranscriber(model_name_or_path="whisper-1")

@pytest.mark.unit
def test_init_remote(self):
transcriber = WhisperTranscriber(api_key="just a test")
assert transcriber.model_name == "whisper-1"
assert not transcriber.use_local_whisper
assert not hasattr(transcriber, "device")
assert hasattr(transcriber, "_model") and transcriber._model is None

@pytest.mark.unit
def test_init_local(self):
transcriber = WhisperTranscriber(model_name_or_path="large-v2")
assert transcriber.model_name == "large-v2" # Doesn't matter if it's huge, the model is not loaded in init.
assert transcriber.use_local_whisper
assert hasattr(transcriber, "device") and transcriber.device == torch.device("cpu")
assert hasattr(transcriber, "_model") and transcriber._model is None

@pytest.mark.unit
def test_init_local_with_api_key(self):
transcriber = WhisperTranscriber(model_name_or_path="large-v2")
assert transcriber.model_name == "large-v2" # Doesn't matter if it's huge, the model is not loaded in init.
assert transcriber.use_local_whisper
assert hasattr(transcriber, "device") and transcriber.device == torch.device("cpu")
assert hasattr(transcriber, "_model") and transcriber._model is None

@pytest.mark.unit
def test_init_missing_whisper_lib_local_model(self, monkeypatch):
monkeypatch.setitem(sys.modules, "whisper", FakeModule(spec=MagicMock(), message="test"))
with pytest.raises(ValueError, match="audio extra"):
WhisperTranscriber(model_name_or_path="large-v2")

@pytest.mark.unit
def test_init_missing_whisper_lib_remote_model(self, monkeypatch):
monkeypatch.setitem(sys.modules, "whisper", FakeModule(spec=MagicMock(), message="test"))
# Should not fail if the lib is missing and we're using API
WhisperTranscriber(model_name_or_path="whisper-1", api_key="doesn't matter")

@pytest.mark.unit
def test_warmup_remote_model(self, monkeypatch):
load_model = MagicMock()
monkeypatch.setattr(whisper, "load_model", load_model)
component = WhisperTranscriber(model_name_or_path="whisper-1", api_key="doesn't matter")
component.warm_up()
assert not load_model.called

@pytest.mark.unit
def test_warmup_local_model(self, monkeypatch):
load_model = MagicMock()
load_model.side_effect = ["FAKE MODEL"]
monkeypatch.setattr(whisper, "load_model", load_model)

component = WhisperTranscriber(model_name_or_path="large-v2")
component.warm_up()

assert hasattr(component, "_model")
assert component._model == "FAKE MODEL"
load_model.assert_called_with("large-v2", device=torch.device(type="cpu"))

@pytest.mark.unit
def test_warmup_local_model_doesnt_reload(self, monkeypatch):
load_model = MagicMock()
monkeypatch.setattr(whisper, "load_model", load_model)
component = WhisperTranscriber(model_name_or_path="large-v2")
component.warm_up()
component.warm_up()
load_model.assert_called_once()

@pytest.mark.unit
def test_transcribe_to_documents(self, mock_models):
comp = WhisperTranscriber(model_name_or_path="large-v2")
output = comp.transcribe_to_documents(
audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"]
)
assert output == [
Document(
content="test transcription",
metadata={
"audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav",
"other_metadata": ["other", "meta", "data"],
"kwargs received": {},
},
)
]