diff --git a/ayushma/migrations/0051_project_tts_engine.py b/ayushma/migrations/0051_project_tts_engine.py
new file mode 100644
index 00000000..5ddc3372
--- /dev/null
+++ b/ayushma/migrations/0051_project_tts_engine.py
@@ -0,0 +1,19 @@
+# Generated by Django 4.2.6 on 2024-02-11 15:23
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("ayushma", "0050_alter_chat_model_alter_project_model"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="project",
+            name="tts_engine",
+            field=models.SmallIntegerField(
+                choices=[(1, "openai"), (2, "google")], default=2
+            ),
+        ),
+    ]
diff --git a/ayushma/models/enums.py b/ayushma/models/enums.py
index 56611c53..a7a7c9cb 100644
--- a/ayushma/models/enums.py
+++ b/ayushma/models/enums.py
@@ -19,6 +19,11 @@ class STTEngine(IntegerChoices):
     SELF_HOSTED = 3
 
 
+class TTSEngine(IntegerChoices):
+    OPENAI = (1, "openai")
+    GOOGLE = (2, "google")
+
+
 class FeedBackRating(IntegerChoices):
     HALLUCINATING = 1
     WRONG = 2
diff --git a/ayushma/models/project.py b/ayushma/models/project.py
index e2e53531..5ebcc76d 100644
--- a/ayushma/models/project.py
+++ b/ayushma/models/project.py
@@ -1,7 +1,7 @@
 from django.contrib.postgres.fields import ArrayField
 from django.db import models
 
-from ayushma.models.enums import ModelType, STTEngine
+from ayushma.models.enums import ModelType, STTEngine, TTSEngine
 from ayushma.models.users import User
 from utils.models.base import BaseModel
 
@@ -16,6 +16,9 @@ class Project(BaseModel):
     stt_engine = models.IntegerField(
         choices=STTEngine.choices, default=STTEngine.WHISPER
     )
+    tts_engine = models.SmallIntegerField(
+        choices=TTSEngine.choices, default=TTSEngine.GOOGLE
+    )
     model = models.IntegerField(choices=ModelType.choices, default=ModelType.GPT_3_5)
     preset_questions = ArrayField(models.TextField(), null=True, blank=True)
     is_default = models.BooleanField(default=False)
diff --git a/ayushma/serializers/project.py b/ayushma/serializers/project.py
index 4a7d0ddf..5681b43a 100644
--- a/ayushma/serializers/project.py
+++ b/ayushma/serializers/project.py
@@ -25,6 +25,7 @@ class Meta:
             "modified_at",
             "description",
             "stt_engine",
+            "tts_engine",
             "model",
             "is_default",
             "display_preset_questions",
diff --git a/ayushma/utils/language_helpers.py b/ayushma/utils/language_helpers.py
index c67e6161..460e2093 100644
--- a/ayushma/utils/language_helpers.py
+++ b/ayushma/utils/language_helpers.py
@@ -1,9 +1,13 @@
 import re
 
+from django.conf import settings
 from google.cloud import texttospeech
 from google.cloud import translate_v2 as translate
+from openai import OpenAI
 from rest_framework.exceptions import APIException
 
+from ayushma.models.enums import TTSEngine
+
 
 def translate_text(target, text):
     try:
@@ -37,31 +41,43 @@ def sanitize_text(text):
     return sanitized_text
 
 
-def text_to_speech(text, language_code):
+def text_to_speech(text, language_code, service):
     try:
         # in en-IN neural voice is not available
         if language_code == "en-IN":
             language_code = "en-US"
-        client = texttospeech.TextToSpeechClient()
-        text = sanitize_text(text)
-        synthesis_input = texttospeech.SynthesisInput(text=text)
-
-        voice = texttospeech.VoiceSelectionParams(
-            language_code=language_code, name=language_code_voice_map[language_code]
-        )
-        audio_config = texttospeech.AudioConfig(
-            audio_encoding=texttospeech.AudioEncoding.MP3
-        )
-
-        response = client.synthesize_speech(
-            input=synthesis_input,
-            voice=voice,
-            audio_config=audio_config,
-        )
-
-        return response.audio_content
+
+        if service == TTSEngine.GOOGLE:
+            client = texttospeech.TextToSpeechClient()
+
+            synthesis_input = texttospeech.SynthesisInput(text=text)
+
+            voice = texttospeech.VoiceSelectionParams(
+                language_code=language_code, name=language_code_voice_map[language_code]
+            )
+            audio_config = texttospeech.AudioConfig(
+                audio_encoding=texttospeech.AudioEncoding.MP3
+            )
+
+            response = client.synthesize_speech(
+                input=synthesis_input,
+                voice=voice,
+                audio_config=audio_config,
+            )
+
+            return response.audio_content
+        elif service == TTSEngine.OPENAI:
+            client = OpenAI(api_key=settings.OPENAI_API_KEY)
+            response = client.audio.speech.create(
+                model="tts-1-hd",
+                voice="nova",
+                input=text,
+            )
+            return response.read()
+        else:
+            raise APIException("Service not supported")
     except Exception as e:
         print(e)
         return None
diff --git a/ayushma/utils/openaiapi.py b/ayushma/utils/openaiapi.py
index 10f7956c..2326d5ff 100644
--- a/ayushma/utils/openaiapi.py
+++ b/ayushma/utils/openaiapi.py
@@ -203,6 +203,7 @@ def handle_post_response(
     temperature,
     stats,
     language,
+    tts_engine,
     generate_audio=True,
 ):
     chat_message: ChatMessage = ChatMessage.objects.create(
@@ -225,7 +226,9 @@ def handle_post_response(
     ayushma_voice = None
     if generate_audio:
         stats["tts_start_time"] = time.time()
-        ayushma_voice = text_to_speech(translated_chat_response, user_language)
+        ayushma_voice = text_to_speech(
+            translated_chat_response, user_language, tts_engine
+        )
         stats["tts_end_time"] = time.time()
 
     url = None
@@ -324,6 +327,8 @@ def converse(
         elif message.messageType == ChatMessageType.AYUSHMA:
             chat_history.append(AIMessage(content=f"Ayushma: {message.message}"))
 
+    tts_engine = chat.project.tts_engine
+
     if not stream:
         lang_chain_helper = LangChainHelper(
             stream=False,
@@ -347,6 +352,7 @@ def converse(
             temperature,
             stats,
             language,
+            tts_engine,
             generate_audio,
         )
 
@@ -404,6 +410,7 @@ def converse(
             temperature,
             stats,
             language,
+            tts_engine,
             generate_audio,
         )
 
diff --git a/utils/pagination.py b/utils/pagination.py
index 05509fe4..c54def1c 100644
--- a/utils/pagination.py
+++ b/utils/pagination.py
@@ -13,5 +13,6 @@ def get_paginated_response(self, data):
                 "has_previous": self.offset > 0,
                 "has_next": self.offset + self.limit < self.count,
                 "results": data,
+                "offset": self.offset,
             }
         )
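
Usage sketch (not part of the diff above): how the per-project tts_engine choice introduced here is expected to flow into text_to_speech. The synthesize_reply helper, its arguments, and the "en-US" default are illustrative assumptions; only TTSEngine, text_to_speech, and chat.project.tts_engine come from the changes in this PR.

# Hypothetical helper, shown only to illustrate the new dispatch path.
from ayushma.models.enums import TTSEngine
from ayushma.utils.language_helpers import text_to_speech


def synthesize_reply(chat, reply_text, user_language="en-US"):
    # The field added by migration 0051 defaults to TTSEngine.GOOGLE (stored as 2).
    # IntegerChoices members compare equal to their stored int values, so the raw
    # column value selects the matching branch inside text_to_speech.
    tts_engine = chat.project.tts_engine
    audio = text_to_speech(reply_text, user_language, tts_engine)
    # text_to_speech catches its own exceptions and returns None on failure,
    # so a None result means "no audio available" rather than an error.
    return audio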