diff --git a/kimm/_src/utils/model_registry.py b/kimm/_src/utils/model_registry.py index 9b70302..e98a9c8 100644 --- a/kimm/_src/utils/model_registry.py +++ b/kimm/_src/utils/model_registry.py @@ -78,8 +78,23 @@ def list_models( name: typing.Optional[str] = None, feature_extractor: typing.Optional[bool] = None, weights: typing.Optional[typing.Union[bool, str]] = None, -) -> typing.List[str]: - result_names: typing.Set = set() +): + """List the models with the given arguments. + + Args: + name: An optional `str` specifying the substring of the name of the + model to seatch for. If not specified, all models will be included. + feature_extractor: Whether to include models that support + feature extraction. Defaults to `None`, which means this + argument is not considered. + weights: An optional boolean or `str` specifying the name of the + pretrained weights. The available values are (`"imagenet"`). + Defaults to `None`, which means this argument is not considered. + + Returns: + A list of model names. + """ + result_names: typing.Set[str] = set() for info in MODEL_REGISTRY: # Add by default result_names.add(info["name"]) diff --git a/kimm/_src/utils/model_utils.py b/kimm/_src/utils/model_utils.py index 32b9831..d144f6e 100644 --- a/kimm/_src/utils/model_utils.py +++ b/kimm/_src/utils/model_utils.py @@ -4,6 +4,17 @@ @kimm_export(parent_path=["kimm.utils"]) def get_reparameterized_model(model: BaseModel): + """Get the reparameterized model. + + Internally, this function calls `get_reparameterized_model` from the + provided `model`. + + Args: + model: A `BaseModel` to convert to its reparameterized form. + + Returns: + An instance of the same class as `model` in its reparameterized form. + """ if not hasattr(model, "get_reparameterized_model"): raise ValueError( "There is no 'get_reparameterized_model' method in the model. "