Skip to content

Commit

Permalink
Fix VAD issue
Browse files (browse the repository at this point in the history)
  • Loading branch information
thainguyensunya committed Dec 15, 2024
1 parent 330617a commit cdeba6d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
7 changes: 3 additions & 4 deletions backend/modal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@
from fastapi import FastAPI, UploadFile, File, Form

from speech_profile_modal import ResponseItem, endpoint as speaker_identification_endpoint
from vad_modal import endpoint as vad_endpoint
from vad_modal import vad_endpoint

app = FastAPI()


# POST /v1/speaker-identification
# Thin FastAPI route: prints a trace marker, then forwards uid, audio_file and
# segments unchanged to speaker_identification_endpoint (imported from
# speech_profile_modal) and returns whatever that call returns.
# The declared return type is List[ResponseItem].
# NOTE(review): audio_file's default is the bare name `File` (not called),
# whereas segments uses the call `Form(...)` — confirm this is intentional.
@app.post('/v1/speaker-identification')
def speaker_identification(
uid: str, audio_file: UploadFile = File, segments: str = Form(...)
) -> List[ResponseItem]:
# Trace marker written to stdout each time the route is invoked.
print('speaker_identification')
return speaker_identification_endpoint(uid, audio_file, segments)


@app.post('/v1/vad')
def vad(audio_file: UploadFile = File(...)):
def vad(audio_file: UploadFile = File):
print('vad')
print(vad_endpoint)
return vad_endpoint(audio_file)
27 changes: 9 additions & 18 deletions backend/modal/vad_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import torch
from fastapi import UploadFile
from modal import App, web_endpoint, Secret, Image
from pyannote.audio import Pipeline

# Instantiate pretrained voice activity detection pipeline
Expand All @@ -13,26 +12,18 @@
use_auth_token=os.getenv('HUGGINGFACE_TOKEN')
).to(device)

app = App(name='vad')
image = (
Image.debian_slim()
.pip_install("pyannote.audio")
.pip_install("torch")
.pip_install("torchaudio")
)
# app = App(name='vad')
# image = (
# Image.debian_slim()
# .pip_install("pyannote.audio")
# .pip_install("torch")
# .pip_install("torchaudio")
# )

os.makedirs('_temp', exist_ok=True)


# @app.function(
# image=image,
# keep_warm=1,
# memory=(1024, 2048),
# cpu=4,
# secrets=[Secret.from_name('huggingface-token')],
# )
# @web_endpoint(method='POST')
def endpoint(file: UploadFile):
def vad_endpoint(file: UploadFile):
upload_id = str(uuid.uuid4())
file_path = f"_temp/{upload_id}_{file.filename}"
with open(file_path, 'wb') as f:
Expand All @@ -47,4 +38,4 @@ def endpoint(file: UploadFile):
'end': segment.end,
'duration': segment.duration,
})
return data
return data

0 comments on commit cdeba6d

Please sign in to comment.