From f02404080869243ec578788ae5aa710eea1c3bb8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 24 Jul 2024 08:11:27 -0400 Subject: [PATCH] Fix hf hub for 0.24+ (#9806) (#9857) * Update Huggingface Hub support * Update hf hub * Update hf hub * Apply isort and black reformatting --------- Signed-off-by: smajumdar Signed-off-by: Somshubra Majumdar Signed-off-by: titu1994 Co-authored-by: Somshubra Majumdar Signed-off-by: Vivian Chen --- nemo/core/classes/mixins/hf_io_mixin.py | 88 +++++++++---------------- requirements/requirements.txt | 2 +- tests/core/test_save_restore.py | 16 +++-- 3 files changed, 42 insertions(+), 64 deletions(-) diff --git a/nemo/core/classes/mixins/hf_io_mixin.py b/nemo/core/classes/mixins/hf_io_mixin.py index b101cbabe749..543d6c6fccda 100644 --- a/nemo/core/classes/mixins/hf_io_mixin.py +++ b/nemo/core/classes/mixins/hf_io_mixin.py @@ -14,9 +14,9 @@ from abc import ABC from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union -from huggingface_hub import HfApi, ModelCard, ModelCardData, ModelFilter +from huggingface_hub import HfApi, ModelCard, ModelCardData from huggingface_hub import get_token as get_hf_token from huggingface_hub.hf_api import ModelInfo from huggingface_hub.utils import SoftTemporaryDirectory @@ -35,31 +35,35 @@ class HuggingFaceFileIO(ABC): """ @classmethod - def get_hf_model_filter(cls) -> ModelFilter: + def get_hf_model_filter(cls) -> Dict[str, Any]: """ Generates a filter for HuggingFace models. - Additionally includes default values of some metadata about results returned by the Hub. + Additionally includes default values of some metadata about results returned by the Hub. Metadata: resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False. limit_results: Optional int, limits the number of results returned. 
Returns: - A Hugging Face Hub ModelFilter object. + A dict representing the arguments passable to huggingface list_models(). """ - model_filter = ModelFilter(library='nemo') - - # Attach some additional info - model_filter.resolve_card_info = False - model_filter.limit_results = None + model_filter = dict( + author=None, + library='nemo', + language=None, + model_name=None, + task=None, + tags=None, + limit=None, + full=None, + cardData=False, + ) return model_filter @classmethod - def search_huggingface_models( - cls, model_filter: Optional[Union[ModelFilter, List[ModelFilter]]] = None - ) -> List['ModelInfo']: + def search_huggingface_models(cls, model_filter: Optional[Dict[str, Any]] = None) -> Iterable['ModelInfo']: """ Should list all pre-trained models available via Hugging Face Hub. @@ -75,16 +79,16 @@ def search_huggingface_models( # You can replace with any subclass of ModelPT. from nemo.core import ModelPT - # Get default ModelFilter + # Get default filter dict filt = .get_hf_model_filter() # Make any modifications to the filter as necessary - filt.language = [...] - filt.task = ... - filt.tags = [...] + filt['language'] = [...] + filt['task'] = ... + filt['tags'] = [...] - # Add any metadata to the filter as needed - filt.limit_results = 5 + # Add any metadata to the filter as needed (kwargs to list_models) + filt['limit'] = 5 # Obtain model info model_infos = .search_huggingface_models(model_filter=filt) @@ -96,10 +100,9 @@ def search_huggingface_models( model = ModelPT.from_pretrained(card.modelId) Args: - model_filter: Optional ModelFilter or List[ModelFilter] (from Hugging Face Hub) + model_filter: Optional Dictionary (for Hugging Face Hub kwargs) that filters the returned list of compatible model cards, and selects all results from each filter. Users can then use `model_card.modelId` in `from_pretrained()` to restore a NeMo Model. - If no ModelFilter is provided, uses the classes default filter as defined by `get_hf_model_filter()`. 
Returns: A list of ModelInfo entries. @@ -108,23 +111,6 @@ def search_huggingface_models( if model_filter is None: model_filter = cls.get_hf_model_filter() - # If single model filter, wrap into list - if not isinstance(model_filter, Iterable): - model_filter = [model_filter] - - # Inject `nemo` library filter - for mfilter in model_filter: - if isinstance(mfilter.library, str) and mfilter.library != 'nemo': - logging.warning(f"Model filter's `library` tag updated be `nemo`. Original value: {mfilter.library}") - mfilter.library = "nemo" - - elif isinstance(mfilter, Iterable) and 'nemo' not in mfilter.library: - logging.warning( - f"Model filter's `library` list updated to include `nemo`. Original value: {mfilter.library}" - ) - mfilter.library = list(mfilter) - mfilter.library.append('nemo') - # Check if api token exists, use if it does hf_token = get_hf_token() @@ -134,24 +120,11 @@ def search_huggingface_models( # Setup extra arguments for model filtering all_results = [] # type: List[ModelInfo] - for mfilter in model_filter: - cardData = None - limit = None - - if hasattr(mfilter, 'resolve_card_info') and mfilter.resolve_card_info is True: - cardData = True - - if hasattr(mfilter, 'limit_results') and mfilter.limit_results is not None: - limit = mfilter.limit_results - - results = api.list_models( - filter=mfilter, token=hf_token, sort="lastModified", direction=-1, cardData=cardData, limit=limit, - ) # type: Iterable[ModelInfo] - - for result in results: - all_results.append(result) + results = api.list_models( + token=hf_token, sort="lastModified", direction=-1, **model_filter + ) # type: Iterable[ModelInfo] - return all_results + return results def push_to_hf_hub( self, @@ -284,7 +257,10 @@ def _get_hf_model_card(self, template: str, template_kwargs: Optional[Dict[str, A HuggingFace ModelCard object that can be converted to a model card string. 
""" card_data = ModelCardData( - library_name='nemo', tags=['pytorch', 'NeMo'], license='cc-by-4.0', ignore_metadata_errors=True, + library_name='nemo', + tags=['pytorch', 'NeMo'], + license='cc-by-4.0', + ignore_metadata_errors=True, ) if 'card_data' not in template_kwargs: diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 7706aa58b267..3169d31dbeed 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,5 @@ fiddle -huggingface_hub>=0.20.3,<0.24.0 +huggingface_hub>=0.24 numba numpy>=1.22 onnx>=1.7.0 diff --git a/tests/core/test_save_restore.py b/tests/core/test_save_restore.py index 57cbe94b60d7..394ced55a452 100644 --- a/tests/core/test_save_restore.py +++ b/tests/core/test_save_restore.py @@ -19,7 +19,6 @@ import pytest import torch -from huggingface_hub.hf_api import ModelFilter from omegaconf import DictConfig, OmegaConf, open_dict from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE @@ -1324,8 +1323,8 @@ class MockModelV2(MockModel): @pytest.mark.unit def test_hf_model_filter(self): filt = ModelPT.get_hf_model_filter() - assert isinstance(filt, ModelFilter) - assert filt.library == 'nemo' + assert isinstance(filt, dict) + assert filt['library'] == 'nemo' @pytest.mark.with_downloads() @pytest.mark.unit @@ -1334,10 +1333,12 @@ def test_hf_model_info(self): # check no override results model_infos = ModelPT.search_huggingface_models(model_filter=None) + model_infos = [next(model_infos) for _ in range(5)] assert len(model_infos) > 0 # check with default override results (should match above) default_model_infos = ModelPT.search_huggingface_models(model_filter=filt) + default_model_infos = [next(default_model_infos) for _ in range(5)] assert len(model_infos) == len(default_model_infos) @pytest.mark.pleasefixme() @@ -1348,13 +1349,12 @@ def test_hf_model_info_with_card_data(self): # check no override results model_infos = ModelPT.search_huggingface_models(model_filter=filt) + 
model_infos = [next(model_infos) for _ in range(5)] assert len(model_infos) > 0 - assert not hasattr(model_infos[0], 'cardData') # check overriden defaults - filt.resolve_card_info = True + filt['cardData'] = True model_infos = ModelPT.search_huggingface_models(model_filter=filt) - assert len(model_infos) > 0 for info in model_infos: if hasattr(info, 'cardData'): @@ -1368,11 +1368,13 @@ def test_hf_model_info_with_limited_results(self): # check no override results model_infos = ModelPT.search_huggingface_models(model_filter=filt) + model_infos = [next(model_infos) for _ in range(6)] assert len(model_infos) > 0 # check overriden defaults - filt.limit_results = 5 + filt['limit'] = 5 new_model_infos = ModelPT.search_huggingface_models(model_filter=filt) + new_model_infos = list(new_model_infos) assert len(new_model_infos) <= 5 assert len(new_model_infos) < len(model_infos)