Skip to content

Commit

Permalink
Adding NGC checkpoints for SE and SR
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Nov 28, 2024
1 parent 18fe970 commit 1c7b328
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions nemo/collections/audio/models/enhancement.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,24 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =

return {f'{tag}_loss': loss}

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
Returns:
List of available pre-trained models.
"""
results = []
model = PretrainedModelInfo(
pretrained_model_name="sr_ssl_flowmatching_16k_430m",
description="For details on this model, please refer to https://ngc.nvidia.com/catalog/models/nvidia:nemo:sr_ssl_flowmatching_16k_430m",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/sr_ssl_flowmatching_16k_430m/versions/v1/files/sr_ssl_flowmatching_16k_430m.nemo",
)
results.append(model)

return results


class SchroedingerBridgeAudioToAudioModel(AudioToAudioModel):
"""This models is using a Schrödinger Bridge process to generate
Expand Down Expand Up @@ -1235,3 +1253,28 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

return {f'{tag}_loss': loss}

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
Returns:
List of available pre-trained models.
"""
results = []
model = PretrainedModelInfo(
pretrained_model_name="se_den_sb_16k_small",
description="For details on this model, please refer to https://ngc.nvidia.com/catalog/models/nvidia:nemo:se_den_sb_16k_small",
location="https://api.ngc.nvidia.com/v2/org/nvidia/team/nemo/models/se_den_sb_16k_small/versions/v1.0/files/se_den_sb_16k_small.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="se_der_sb_16k_small",
description="For details on this model, please refer to https://ngc.nvidia.com/catalog/models/nvidia:nemo:se_der_sb_16k_small",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/se_der_sb_16k_small/versions/v1/files/se_der_sb_16k_small.nemo",
)
results.append(model)
return results

0 comments on commit 1c7b328

Please sign in to comment.