Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Force diarizer to use CUDA if cuda is available and if device=None. #9380

Merged
merged 3 commits into from
Jun 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions nemo/collections/asr/models/clustering_diarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def get_available_model_names(class_name):

class ClusteringDiarizer(torch.nn.Module, Model, DiarizationMixin):
"""
Inference model Class for offline speaker diarization.
This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
Extract Embeddings, Clustering, Resegmentation and Scoring.
All the parameters are passed through config file
Inference model Class for offline speaker diarization.
This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
Extract Embeddings, Clustering, Resegmentation and Scoring.
All the parameters are passed through config file
"""

def __init__(self, cfg: Union[DictConfig, Any], speaker_model=None):
Expand Down Expand Up @@ -137,7 +137,10 @@ def _init_speaker_model(self, speaker_model=None):
Initialize speaker embedding model with model name or path passed through config
"""
if speaker_model is not None:
self._speaker_model = speaker_model
if self._cfg.device is None and torch.cuda.is_available():
self._speaker_model = speaker_model.to(torch.device('cuda'))
else:
self._speaker_model = speaker_model
else:
model_path = self._cfg.diarizer.speaker_embeddings.model_path
if model_path is not None and model_path.endswith('.nemo'):
Expand All @@ -158,7 +161,6 @@ def _init_speaker_model(self, speaker_model=None):
self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(
model_name=model_path, map_location=self._cfg.device
)

self.multiscale_args_dict = parse_scale_configs(
self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
self._diarizer_params.speaker_embeddings.parameters.shift_length_in_sec,
Expand All @@ -171,7 +173,9 @@ def _setup_vad_test_data(self, manifest_vad_input):
'sample_rate': self._cfg.sample_rate,
'batch_size': self._cfg.get('batch_size'),
'vad_stream': True,
'labels': ['infer',],
'labels': [
'infer',
],
'window_length_in_sec': self._vad_window_length_in_sec,
'shift_length_in_sec': self._vad_shift_length_in_sec,
'trim_silence': False,
Expand All @@ -192,8 +196,8 @@ def _setup_spkr_test_data(self, manifest_file):

def _run_vad(self, manifest_file):
"""
Run voice activity detection.
Get log probability of voice activity detection and smoothes using the post processing parameters.
Run voice activity detection.
Get log probability of voice activity detection and smoothes using the post processing parameters.
Using generated frame level predictions generated manifest file for later speaker embedding extraction.
input:
manifest_file (str) : Manifest file containing path to audio file and label as infer
Expand Down Expand Up @@ -338,7 +342,7 @@ def _perform_speech_activity_detection(self):
def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: int):
"""
This method extracts speaker embeddings from segments passed through manifest_file
Optionally you may save the intermediate speaker embeddings for debugging or any use.
Optionally you may save the intermediate speaker embeddings for debugging or any use.
"""
logging.info("Extracting embeddings for Diarization")
self._setup_spkr_test_data(manifest_file)
Expand Down
Loading