diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py
index 533f276c0018..93913a43c1b5 100644
--- a/nemo/collections/asr/models/clustering_diarizer.py
+++ b/nemo/collections/asr/models/clustering_diarizer.py
@@ -74,10 +74,10 @@ def get_available_model_names(class_name):
 
 class ClusteringDiarizer(torch.nn.Module, Model, DiarizationMixin):
     """
-    Inference model Class for offline speaker diarization. 
-    This class handles required functionality for diarization : Speech Activity Detection, Segmentation, 
-    Extract Embeddings, Clustering, Resegmentation and Scoring. 
-    All the parameters are passed through config file 
+    Inference model Class for offline speaker diarization.
+    This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
+    Extract Embeddings, Clustering, Resegmentation and Scoring.
+    All the parameters are passed through config file
     """
 
     def __init__(self, cfg: Union[DictConfig, Any], speaker_model=None):
@@ -137,7 +137,10 @@ def _init_speaker_model(self, speaker_model=None):
         Initialize speaker embedding model with model name or path passed through config
         """
         if speaker_model is not None:
-            self._speaker_model = speaker_model
+            if self._cfg.device is None and torch.cuda.is_available():
+                self._speaker_model = speaker_model.to(torch.device('cuda'))
+            else:
+                self._speaker_model = speaker_model
         else:
             model_path = self._cfg.diarizer.speaker_embeddings.model_path
             if model_path is not None and model_path.endswith('.nemo'):
@@ -158,7 +161,6 @@ def _init_speaker_model(self, speaker_model=None):
                 self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(
                     model_name=model_path, map_location=self._cfg.device
                 )
-
         self.multiscale_args_dict = parse_scale_configs(
             self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
             self._diarizer_params.speaker_embeddings.parameters.shift_length_in_sec,
@@ -171,7 +173,9 @@ def _setup_vad_test_data(self, manifest_vad_input):
             'sample_rate': self._cfg.sample_rate,
             'batch_size': self._cfg.get('batch_size'),
             'vad_stream': True,
-            'labels': ['infer',],
+            'labels': [
+                'infer',
+            ],
             'window_length_in_sec': self._vad_window_length_in_sec,
             'shift_length_in_sec': self._vad_shift_length_in_sec,
             'trim_silence': False,
@@ -192,8 +196,8 @@ def _setup_spkr_test_data(self, manifest_file):
 
     def _run_vad(self, manifest_file):
         """
-        Run voice activity detection. 
-        Get log probability of voice activity detection and smoothes using the post processing parameters. 
+        Run voice activity detection.
+        Get log probability of voice activity detection and smoothes using the post processing parameters.
         Using generated frame level predictions generated manifest file for later speaker embedding extraction.
         input:
         manifest_file (str) : Manifest file containing path to audio file and label as infer
@@ -338,7 +342,7 @@ def _perform_speech_activity_detection(self):
     def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: int):
         """
         This method extracts speaker embeddings from segments passed through manifest_file
-        Optionally you may save the intermediate speaker embeddings for debugging or any use. 
+        Optionally you may save the intermediate speaker embeddings for debugging or any use.
         """
         logging.info("Extracting embeddings for Diarization")
         self._setup_spkr_test_data(manifest_file)
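
For context, a minimal standalone sketch of the device-placement fallback the second hunk adds to _init_speaker_model; the resolve_device helper and its cfg_device argument are hypothetical illustrations and not part of NeMo:

import torch

def resolve_device(cfg_device):
    # Hypothetical helper mirroring the new branch in _init_speaker_model:
    # when the config leaves the device unset, prefer CUDA if it is
    # available; otherwise honor the configured device, defaulting to CPU.
    if cfg_device is None and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device(cfg_device if cfg_device is not None else 'cpu')

# A pre-built speaker model passed into ClusteringDiarizer would then be
# moved with speaker_model.to(resolve_device(cfg.device)), matching the
# speaker_model.to(torch.device('cuda')) call in the diff above.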