-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from dvejsada/main
Proposed refactor of the custom component
- Loading branch information
Showing
10 changed files
with
244 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,22 @@ | ||
"""Custom integration for OpenAI TTS.""" | ||
"""Custom integration for OpenAI TTS.""" | ||
from __future__ import annotations | ||
|
||
from homeassistant.const import Platform | ||
from homeassistant.config_entries import ConfigEntry | ||
from homeassistant.core import HomeAssistant | ||
|
||
PLATFORMS: list[str] = [Platform.TTS] | ||
|
||
|
||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: | ||
"""Set up entities.""" | ||
|
||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) | ||
return True | ||
|
||
|
||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: | ||
"""Unload a config entry.""" | ||
|
||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""Config flow for OpenAI text-to-speech custom component.""" | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import voluptuous as vol | ||
import logging | ||
|
||
from homeassistant.config_entries import ConfigFlow | ||
from homeassistant.helpers.selector import selector | ||
from homeassistant.exceptions import HomeAssistantError | ||
|
||
from .const import CONF_API_KEY, CONF_MODEL, CONF_VOICE, CONF_SPEED, DOMAIN, MODELS, VOICES | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
async def validate_input(user_input: dict): | ||
""" Function to validate provided data""" | ||
if len(user_input[CONF_API_KEY]) != 51: | ||
raise WrongAPIKey | ||
|
||
|
||
class OpenAITTSConfigFlow(ConfigFlow, domain=DOMAIN): | ||
"""Handle a config flow .""" | ||
|
||
VERSION = 1 | ||
|
||
async def async_step_user(self, user_input: dict[str, Any] | None = None): | ||
"""Handle the initial step.""" | ||
|
||
data_schema = {vol.Required(CONF_API_KEY): str, | ||
vol.Optional(CONF_SPEED, default=1): int, | ||
CONF_MODEL: selector({ | ||
"select": { | ||
"options": MODELS, | ||
"mode": "dropdown", | ||
"sort": True, | ||
"custom_value": False | ||
} | ||
}), CONF_VOICE: selector({ | ||
"select": { | ||
"options": VOICES, | ||
"mode": "dropdown", | ||
"sort": True, | ||
"custom_value": False | ||
} | ||
}) | ||
} | ||
|
||
errors = {} | ||
|
||
if user_input is not None: | ||
try: | ||
self._async_abort_entries_match({CONF_VOICE: user_input[CONF_VOICE]}) | ||
await validate_input(user_input) | ||
return self.async_create_entry(title="OpenAI TTS", data=user_input) | ||
except WrongAPIKey: | ||
_LOGGER.exception("Wrong or no API key provided.") | ||
errors[CONF_API_KEY] = "wrong_api_key" | ||
except Exception: # pylint: disable=broad-except | ||
_LOGGER.exception("Unknown exception.") | ||
errors["base"] = "Unknown exception." | ||
|
||
return self.async_show_form(step_id="user", data_schema=vol.Schema(data_schema)) | ||
|
||
|
||
class WrongAPIKey(HomeAssistantError): | ||
"""Error to indicate no or wrong API key.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
""" Constants for OpenAI TTS custom component""" | ||
|
||
DOMAIN = "openai_tts" | ||
CONF_API_KEY = 'api_key' | ||
CONF_MODEL = 'model' | ||
CONF_VOICE = 'voice' | ||
CONF_SPEED = 'speed' | ||
MODELS = ["tts-1", "tts-1-hd"] | ||
VOICES = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] | ||
URL = "https://api.openai.com/v1/audio/speech" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
{ | ||
"domain": "openai_tts", | ||
"name": "OpenAI TTS", | ||
"config_flow": true, | ||
"codeowners": ["@sfortis"], | ||
"dependencies": [], | ||
"documentation": "https://github.com/sfortis/openai_tts/", | ||
"iot_class": "cloud_polling", | ||
"issue_tracker": "https://github.com/sfortis/openai_tts/issues", | ||
"requirements": ["requests>=2.25.1"], | ||
"version": "0.1.0" | ||
"version": "0.2.0" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import requests | ||
|
||
from .const import URL | ||
|
||
|
||
class OpenAITTSEngine: | ||
|
||
def __init__(self, api_key: str, voice: str, model: str, speed: int): | ||
self._api_key = api_key | ||
self._voice = voice | ||
self._model = model | ||
self._speed = speed | ||
self._url = URL | ||
|
||
def get_tts(self, text: str): | ||
""" Makes request to OpenAI TTS engine to convert text into audio""" | ||
headers: dict = {"Authorization": f"Bearer {self._api_key}"} | ||
data: dict = {"model": self._model, "input": text, "voice": self._voice, "speed": self._speed} | ||
return requests.post(self._url, headers=headers, json=data) | ||
|
||
@staticmethod | ||
def get_supported_langs() -> list: | ||
"""Returns list of supported languages. Note: the model determines the provides language automatically.""" | ||
return ["af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da", "nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is", "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi", "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy"] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"config": { | ||
"step": { | ||
"user": { | ||
"title": "Add text-to-speech engine", | ||
"description": "Provide configuration data. See documentation for further info.", | ||
"data": { | ||
"api_key": "Enter OpenAI API key.", | ||
"speed": "Enter speed of the speech", | ||
"model": "Select model to be used.", | ||
"voice": "Select voice." | ||
} | ||
} | ||
}, | ||
"error": { | ||
"wrong_api_key": "Connection was not authorized. Wrong or no API key provided." | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"config": { | ||
"step": { | ||
"user": { | ||
"title": "Přidej engine pro převod textu na řeč", | ||
"description": "Vlož konfigurační data. Pro detaily se podívej na dokumentaci", | ||
"data": { | ||
"api_key": "Vlož OpenAI API klíč.", | ||
"speed": "Vlož rychlost řeči.", | ||
"model": "Vyber model k použití.", | ||
"voice": "Vyber hlas." | ||
} | ||
} | ||
}, | ||
"error": { | ||
"wrong_api_key": "Nebyl poskytnut správný API klíč." | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"config": { | ||
"step": { | ||
"user": { | ||
"title": "Add text-to-speech engine", | ||
"description": "Provide configuration data. See documentation for further info.", | ||
"data": { | ||
"api_key": "Enter OpenAI API key.", | ||
"speed": "Enter speed of the speech", | ||
"model": "Select model to be used.", | ||
"voice": "Select voice." | ||
} | ||
} | ||
}, | ||
"error": { | ||
"wrong_api_key": "Connection was not authorized. Wrong or no API key provided." | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,87 +1,79 @@ | ||
""" | ||
Support for OpenAI TTS. | ||
Setting up TTS entity. | ||
""" | ||
import logging | ||
import requests | ||
import voluptuous as vol | ||
from homeassistant.components.tts import CONF_LANG, PLATFORM_SCHEMA, Provider | ||
from homeassistant.components.tts import TextToSpeechEntity | ||
from homeassistant.config_entries import ConfigEntry | ||
from homeassistant.core import HomeAssistant | ||
import homeassistant.helpers.config_validation as cv | ||
from homeassistant.helpers.entity_platform import AddEntitiesCallback | ||
from .const import CONF_API_KEY,CONF_MODEL, CONF_SPEED, CONF_VOICE, DOMAIN | ||
from .openaitts_engine import OpenAITTSEngine | ||
from homeassistant.exceptions import MaxLengthExceeded | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
CONF_API_KEY = 'api_key' | ||
DEFAULT_LANG = 'en-US' | ||
OPENAI_TTS_URL = "https://api.openai.com/v1/audio/speech" | ||
CONF_MODEL = 'model' | ||
CONF_VOICE = 'voice' | ||
CONF_SPEED = 'speed' | ||
|
||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ | ||
vol.Required(CONF_API_KEY): cv.string, | ||
vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.string, | ||
vol.Optional(CONF_MODEL, default='tts-1'): cv.string, | ||
vol.Optional(CONF_VOICE, default='shimmer'): cv.string, | ||
vol.Optional(CONF_SPEED, default=1): cv.string, | ||
}) | ||
|
||
def get_engine(hass, config, discovery_info=None): | ||
"""Set up OpenAI TTS speech component.""" | ||
api_key = config[CONF_API_KEY] | ||
language = config.get(CONF_LANG, DEFAULT_LANG) | ||
model = config.get(CONF_MODEL) | ||
voice = config.get(CONF_VOICE) | ||
speed = config.get(CONF_SPEED) | ||
return OpenAITTSProvider(hass, api_key, language, model, voice, speed) | ||
|
||
class OpenAITTSProvider(Provider): | ||
"""The OpenAI TTS API provider.""" | ||
|
||
def __init__(self, hass, api_key, lang, model, voice, speed): | ||
"""Initialize OpenAI TTS provider.""" | ||
|
||
async def async_setup_entry( | ||
hass: HomeAssistant, | ||
config_entry: ConfigEntry, | ||
async_add_entities: AddEntitiesCallback, | ||
) -> None: | ||
"""Set up OpenAI Text-to-speech platform via config entry.""" | ||
engine = OpenAITTSEngine( | ||
config_entry.data[CONF_API_KEY], | ||
config_entry.data[CONF_VOICE], | ||
config_entry.data[CONF_MODEL], | ||
config_entry.data[CONF_SPEED], | ||
) | ||
async_add_entities([OpenAITTSEntity(hass, config_entry, engine)]) | ||
|
||
|
||
class OpenAITTSEntity(TextToSpeechEntity): | ||
"""The OpenAI TTS entity.""" | ||
_attr_has_entity_name = True | ||
_attr_should_poll = False | ||
|
||
def __init__(self, hass, config, engine): | ||
"""Initialize TTS entity.""" | ||
self.hass = hass | ||
self._api_key = api_key | ||
self._language = lang | ||
self._model = model | ||
self._voice = voice | ||
self._speed = speed | ||
self._engine = engine | ||
self._config = config | ||
self._attr_unique_id = self._config.data[CONF_VOICE] | ||
|
||
@property | ||
def default_language(self): | ||
"""Return the default language.""" | ||
return self._language | ||
return "en" | ||
|
||
@property | ||
def supported_languages(self): | ||
"""Return the list of supported languages.""" | ||
# Ideally, this list should be dynamically fetched from OpenAI, if supported. | ||
return [self._language] | ||
return self._engine.get_supported_langs() | ||
|
||
@property | ||
def device_info(self): | ||
return {"identifiers": {(DOMAIN, self._attr_unique_id)}, "name": f"OpenAI {self._config.data[CONF_VOICE]}", "manufacturer": "OpenAI"} | ||
|
||
@property | ||
def name(self): | ||
"""Return name of entity""" | ||
return " engine" | ||
|
||
def get_tts_audio(self, message, language, options=None): | ||
"""Convert a given text to speech and return it as bytes.""" | ||
# Define the headers, including the Authorization header with your API key | ||
headers = { | ||
'Authorization': f'Bearer {self._api_key}' | ||
} | ||
|
||
# Define the data payload, specifying the model, input text, voice, and response format | ||
data = { | ||
'model': self._model, # Choose between 'tts-1' and 'tts-1-hd' based on your preference | ||
'voice': self._voice, # Choose the desired voice | ||
'speed': self._speed, # Voice speed | ||
'input': message, | ||
# Optional parameters can also be included, like 'speed' and 'response_format' | ||
} | ||
|
||
try: | ||
# Make the POST request to the correct endpoint for generating speech | ||
response = requests.post(OPENAI_TTS_URL, json=data, headers=headers) | ||
response.raise_for_status() # Raises an HTTPError if the HTTP request returned an unsuccessful status code | ||
if len(message) > 4096: | ||
raise MaxLengthExceeded | ||
|
||
speech = self._engine.get_tts(message) | ||
|
||
# The response should contain the audio file content | ||
return "mp3", response.content | ||
except requests.exceptions.HTTPError as http_err: | ||
_LOGGER.error("HTTP error from OpenAI: %s", http_err) | ||
except requests.exceptions.RequestException as req_err: | ||
_LOGGER.error("Request exception from OpenAI: %s", req_err) | ||
return "mp3", speech.content | ||
except Exception as e: | ||
_LOGGER.error("Unknown Error: %s", e) | ||
|
||
except MaxLengthExceeded: | ||
_LOGGER.error("Maximum length of the message exceeded") | ||
|
||
return None, None |