feat: WhisperTranscriber (#4723)

for v2
@@ -0,0 +1,19 @@
import os
from pathlib import Path

from haystack.preview.components import WhisperTranscriber


SAMPLES_PATH = Path(__file__).parent.parent / "test_files"


def test_raw_transcribe_local():
    comp = WhisperTranscriber(model_name_or_path="tiny")
    output = comp.transcribe(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
    assert "this is the content of the document" in output[0]["text"].lower()


def test_raw_transcribe_remote():
    comp = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY"))
    output = comp.transcribe(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"])
    assert "this is the content of the document" in output[0]["text"].lower()
@@ -0,0 +1 @@
from haystack.preview.components.audio.transcriber import WhisperTranscriber
@@ -0,0 +1 @@
from haystack.preview.components.audio.transcriber import WhisperTranscriber
@@ -0,0 +1,209 @@
from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence

import os
import json
import logging
from pathlib import Path
from dataclasses import dataclass

import requests
import torch
import whisper
from tenacity import retry, wait_exponential, retry_if_not_result

from haystack.preview import component, Document
from haystack.errors import OpenAIError, OpenAIRateLimitError
from haystack import is_imported


logger = logging.getLogger(__name__)


OPENAI_TIMEOUT = float(os.environ.get("HAYSTACK_OPENAI_TIMEOUT_SEC", 30))


WhisperLocalModel = Literal["tiny", "small", "medium", "large", "large-v2"]
WhisperRemoteModel = Literal["whisper-1"]
WhisperModel = Union[WhisperLocalModel, WhisperRemoteModel]


@component
class WhisperTranscriber:
Review comment: What do you think of splitting this into two different components? One local only and one remote, something like […]

Review comment: I think it would simplify the logic quite a bit, also the testing. If we go this way I think it would be better to open two separate PRs and ditch this one.
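For illustration, a rough sketch of what that split could look like. The class names, signatures, and division of responsibilities below are placeholders to make the suggestion concrete, not code from this PR:

@component
class LocalWhisperTranscriber:
    # Hypothetical: owns only the local path (load the weights, run Whisper on this machine).
    def __init__(self, model_name_or_path: WhisperLocalModel = "large-v2", device: Optional[str] = None):
        self.model_name = model_name_or_path
        self.device = torch.device(device) if device else torch.device("cpu")
        self._model = None

    def warm_up(self):
        if not self._model:
            self._model = whisper.load_model(self.model_name, device=self.device)


@component
class RemoteWhisperTranscriber:
    # Hypothetical: owns only the OpenAI API path, so no torch or whisper imports would be needed here.
    def __init__(self, api_key: str, model_name: WhisperRemoteModel = "whisper-1"):
        self.api_key = api_key
        self.model_name = model_name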
""" | ||
Transcribes audio files using OpenAI's Whisper. This class supports two underlying implementations: | ||
|
||
- API (default): Uses the OpenAI API and requires an API key. See the | ||
[OpenAI blog post](https://beta.openai.com/docs/api-reference/whisper for more details. | ||
|
||
- Local (requires installing Whisper): Uses the local installation | ||
of [Whisper](https://github.com/openai/whisper). | ||
|
||
To use Whisper locally, install it following the instructions on the Whisper | ||
[GitHub repo](https://github.com/openai/whisper) and omit the `api_key` parameter. | ||
|
||
To use the API implementation, provide an API key. You can get one by signing up for an | ||
[OpenAI account](https://beta.openai.com/). | ||
|
||
For the supported audio formats, languages, and other parameters, see the | ||
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper | ||
[github repo](https://github.com/openai/whisper). | ||
""" | ||
|
||
    @dataclass
    class Output:
        documents: List[Document]

    def __init__(
        self,
        model_name_or_path: WhisperModel = "whisper-1",
        api_key: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """
        Transcribes a list of audio files into a list of Documents.

        :param api_key: OpenAI API key. If None, a local installation of Whisper is used.
        :param model_name_or_path: Name of the model to use. If using a local installation of Whisper, set this to one
            of the following values:
                - `tiny`
                - `small`
                - `medium`
                - `large`
                - `large-v2`
            If using the API, set this value to:
                - `whisper-1` (default)
        :param device: Device to use for inference. Only used if you're using a local installation of Whisper.
            If None, CPU is used.
        """
        if model_name_or_path not in (get_args(WhisperRemoteModel) + get_args(WhisperLocalModel)):
            raise ValueError(
                f"Model name not recognized. Choose one among: "
                f"{', '.join(get_args(WhisperRemoteModel) + get_args(WhisperLocalModel))}."
            )

        if model_name_or_path in get_args(WhisperRemoteModel) and not api_key:
            raise ValueError(
                "Provide a valid API key for OpenAI API. Alternatively, install OpenAI Whisper (see "
                "[Whisper](https://github.com/openai/whisper) for more details) "
                f"and select a model size among: {', '.join(get_args(WhisperLocalModel))}"
            )

        if model_name_or_path in get_args(WhisperLocalModel) and not is_imported("whisper"):
            raise ValueError(
                "To use a local Whisper model, install Haystack's audio extras as `pip install farm-haystack[audio]` "
                "or install Whisper yourself with `pip install openai-whisper`. You will need ffmpeg on your system "
                "in either case, see: https://github.com/openai/whisper."
            )

        if model_name_or_path in get_args(WhisperLocalModel) and api_key:
            logger.warning(
                "An API Key was provided, but a local model was selected. "
                "WhisperTranscriber will try to use the local model."
            )

        self.api_key = api_key
        self.model_name = model_name_or_path
        self.use_local_whisper = model_name_or_path in get_args(WhisperLocalModel)

        if self.use_local_whisper:
            self.device = torch.device(device) if device else torch.device("cpu")

        self._model = None
        if not self.use_local_whisper and api_key is None:
            raise ValueError(
                "Provide a valid API key for OpenAI API. Alternatively, install OpenAI Whisper (see "
                "[Whisper](https://github.com/openai/whisper) for more details)."
            )

    def warm_up(self):
        """
        If we're using a local model, load it here.
        """
        if self.use_local_whisper and not self._model:
            self._model = whisper.load_model(self.model_name, device=self.device)

    def run(self, audios: List[Path], whisper_params: Dict[str, Any]) -> Output:
        documents = self.transcribe_to_documents(audios, **whisper_params)
        return WhisperTranscriber.Output(documents)

    def transcribe_to_documents(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
        """
        Transcribe the given audio files. Returns a list of Documents.

        For the supported audio formats, languages, and other parameters, see the
        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
        [GitHub repo](https://github.com/openai/whisper).

        :param audio_files: a list of paths or binary streams to transcribe
        :returns: a list of Documents, one per input file.
        """
        transcriptions = self.transcribe(audio_files=audio_files, **kwargs)
        return [
            Document(content=transcript.pop("text"), metadata={"audio_file": audio, **transcript})
            for audio, transcript in zip(audio_files, transcriptions)
        ]

    def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]:
        """
        Transcribe the given audio files. Returns a list of dictionaries.

        For the supported audio formats, languages, and other parameters, see the
        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
        [GitHub repo](https://github.com/openai/whisper).

        :param audio_files: a list of paths or binary streams to transcribe
        :returns: a list of transcriptions, one dictionary per input file.
        """
        transcriptions = []
        for audio_file in audio_files:
            if isinstance(audio_file, (str, Path)):
                audio_file = open(audio_file, "rb")

            if self.use_local_whisper:
                transcription = self._transcribe_locally(audio_file, **kwargs)
            else:
                transcription = self._transcribe_with_api(audio_file, **kwargs)

            transcriptions.append(transcription)
        return transcriptions

    @retry(retry=retry_if_not_result(bool), wait=wait_exponential(min=1, max=10))
Review comment: I think we should use the […]
    def _transcribe_with_api(self, audio_file: BinaryIO, **kwargs) -> Dict[str, Any]:
        """
        Calls a remote Whisper model through OpenAI Whisper API.
        """
        translate = kwargs.pop("translate", False)

        response = requests.post(
            url=f"https://api.openai.com/v1/audio/{'translations' if translate else 'transcriptions'}",
            data={"model": "whisper-1", **kwargs},
            headers={"Authorization": f"Bearer {self.api_key}"},
            files=[("file", (audio_file.name, audio_file, "application/octet-stream"))],
            timeout=600,
        )

        if response.status_code != 200:
            if response.status_code == 429:
                raise OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
            raise OpenAIError(
                f"OpenAI returned an error.\n"
                f"Status code: {response.status_code}\n"
                f"Response body: {response.text}",
                status_code=response.status_code,
            )

        return json.loads(response.content)

    def _transcribe_locally(self, audio_file: BinaryIO, **kwargs) -> Dict[str, Any]:
        """
        Calls a local Whisper model.
        """
        if not self._model:
            self.warm_up()
        if not self._model:
            raise ValueError("WhisperTranscriber._transcribe_locally() can't work without a local model.")
        return_segments = kwargs.pop("return_segments", None)
        transcription = self._model.transcribe(audio_file.name, **kwargs)
        if not return_segments:
            transcription.pop("segments", None)
        return transcription
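For context, a minimal usage sketch of the component as implemented in this diff. The audio path and API key are placeholder values; everything else follows the code above rather than any released API:

from pathlib import Path

from haystack.preview.components import WhisperTranscriber

# Local mode: no API key; the Whisper weights are loaded lazily by warm_up().
local = WhisperTranscriber(model_name_or_path="tiny")
local.warm_up()
docs = local.transcribe_to_documents(audio_files=[Path("recording.wav")])  # placeholder file
print(docs[0].content)

# Remote mode: the OpenAI API is called instead; no local model is loaded.
remote = WhisperTranscriber(model_name_or_path="whisper-1", api_key="sk-...")  # placeholder key
output = remote.run(audios=[Path("recording.wav")], whisper_params={})
print(output.documents)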
@@ -0,0 +1 @@
from haystack.preview.components.audio.transcriber import WhisperTranscriber
@@ -0,0 +1,16 @@
from unittest.mock import MagicMock

import pytest
from canals.testing import BaseTestComponent as CanalsBaseTestComponent


class BaseTestComponent(CanalsBaseTestComponent):
    """
    Base tests for Haystack components.
    """

    @pytest.fixture
    def request_mock(self, monkeypatch):
        request_mock = MagicMock()
        monkeypatch.setattr("requests.request", request_mock)
        return request_mock
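For reference, a hypothetical test that would consume this fixture (not part of the PR). It assumes the fixture patches requests.request with the same mock it returns, as above:

import requests


class TestWithRequestMock(BaseTestComponent):
    @pytest.mark.unit
    def test_request_mock_is_patched_in(self, request_mock):
        # requests.request is looked up at call time, so this call hits the patched mock.
        requests.request("GET", "https://example.com")
        request_mock.assert_called_once()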
@@ -0,0 +1,138 @@
import os
import sys
from pathlib import Path
from unittest.mock import MagicMock

import pytest
import torch
import whisper
from generalimport import FakeModule

from haystack.preview.dataclasses import Document
from haystack.preview.components import WhisperTranscriber

from test.preview.components.test_component_base import BaseTestComponent


SAMPLES_PATH = Path(__file__).parent.parent / "test_files"


class TestTranscriber(BaseTestComponent):
    """
    Tests for WhisperTranscriber.
    """

    @pytest.fixture
    def components(self):
        return [WhisperTranscriber(api_key="just a test"), WhisperTranscriber(model_name_or_path="large-v2")]

    @pytest.fixture
    def mock_models(self, monkeypatch):
        def mock_transcribe(_, audio_file, **kwargs):
            return {
                "text": "test transcription",
                "other_metadata": ["other", "meta", "data"],
                "kwargs received": kwargs,
            }

        monkeypatch.setattr(WhisperTranscriber, "_transcribe_with_api", mock_transcribe)
        monkeypatch.setattr(WhisperTranscriber, "_transcribe_locally", mock_transcribe)
        monkeypatch.setattr(WhisperTranscriber, "warm_up", lambda self: None)

    @pytest.mark.unit
    def test_init_remote_unknown_model(self):
        with pytest.raises(ValueError, match="not recognized"):
            WhisperTranscriber(model_name_or_path="anything")

    @pytest.mark.unit
    def test_init_default_remote_missing_key(self):
        with pytest.raises(ValueError, match="API key"):
            WhisperTranscriber()

    @pytest.mark.unit
    def test_init_explicit_remote_missing_key(self):
        with pytest.raises(ValueError, match="API key"):
            WhisperTranscriber(model_name_or_path="whisper-1")

    @pytest.mark.unit
    def test_init_remote(self):
        transcriber = WhisperTranscriber(api_key="just a test")
        assert transcriber.model_name == "whisper-1"
        assert not transcriber.use_local_whisper
        assert not hasattr(transcriber, "device")
        assert hasattr(transcriber, "_model") and transcriber._model is None

    @pytest.mark.unit
    def test_init_local(self):
        transcriber = WhisperTranscriber(model_name_or_path="large-v2")
        assert transcriber.model_name == "large-v2"  # Doesn't matter if it's huge, the model is not loaded in init.
        assert transcriber.use_local_whisper
        assert hasattr(transcriber, "device") and transcriber.device == torch.device("cpu")
        assert hasattr(transcriber, "_model") and transcriber._model is None

    @pytest.mark.unit
    def test_init_local_with_api_key(self):
        transcriber = WhisperTranscriber(model_name_or_path="large-v2", api_key="just a test")
        assert transcriber.model_name == "large-v2"  # Doesn't matter if it's huge, the model is not loaded in init.
        assert transcriber.use_local_whisper
        assert hasattr(transcriber, "device") and transcriber.device == torch.device("cpu")
        assert hasattr(transcriber, "_model") and transcriber._model is None

    @pytest.mark.unit
    def test_init_missing_whisper_lib_local_model(self, monkeypatch):
        monkeypatch.setitem(sys.modules, "whisper", FakeModule(spec=MagicMock(), message="test"))
        with pytest.raises(ValueError, match="audio extra"):
            WhisperTranscriber(model_name_or_path="large-v2")

    @pytest.mark.unit
    def test_init_missing_whisper_lib_remote_model(self, monkeypatch):
        monkeypatch.setitem(sys.modules, "whisper", FakeModule(spec=MagicMock(), message="test"))
        # Should not fail if the lib is missing and we're using API
        WhisperTranscriber(model_name_or_path="whisper-1", api_key="doesn't matter")

    @pytest.mark.unit
    def test_warmup_remote_model(self, monkeypatch):
        load_model = MagicMock()
        monkeypatch.setattr(whisper, "load_model", load_model)
        component = WhisperTranscriber(model_name_or_path="whisper-1", api_key="doesn't matter")
        component.warm_up()
        assert not load_model.called

    @pytest.mark.unit
    def test_warmup_local_model(self, monkeypatch):
        load_model = MagicMock()
        load_model.side_effect = ["FAKE MODEL"]
        monkeypatch.setattr(whisper, "load_model", load_model)

        component = WhisperTranscriber(model_name_or_path="large-v2")
        component.warm_up()

        assert hasattr(component, "_model")
        assert component._model == "FAKE MODEL"
        load_model.assert_called_with("large-v2", device=torch.device(type="cpu"))

    @pytest.mark.unit
    def test_warmup_local_model_doesnt_reload(self, monkeypatch):
        load_model = MagicMock()
        monkeypatch.setattr(whisper, "load_model", load_model)
        component = WhisperTranscriber(model_name_or_path="large-v2")
        component.warm_up()
        component.warm_up()
        load_model.assert_called_once()

    @pytest.mark.unit
    def test_transcribe_to_documents(self, mock_models):
        comp = WhisperTranscriber(model_name_or_path="large-v2")
        output = comp.transcribe_to_documents(
            audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"]
        )
        assert output == [
            Document(
                content="test transcription",
                metadata={
                    "audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav",
                    "other_metadata": ["other", "meta", "data"],
                    "kwargs received": {},
                },
            )
        ]
Review comment: This is never used.