Skip to content

Commit

Permalink
Fix hf hub for 0.24+ (NVIDIA#9806) (NVIDIA#9857)
Browse files Browse the repository at this point in the history
* Update Huggingface Hub support

* Update hf hub

* Update hf hub

* Apply isort and black reformatting

---------

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Somshubra Majumdar <[email protected]>
Signed-off-by: titu1994 <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Signed-off-by: Boxiang Wang <[email protected]>
  • Loading branch information
2 people authored and BoxiangW committed Jul 30, 2024
1 parent 3c233fc commit 2c589b2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 64 deletions.
88 changes: 32 additions & 56 deletions nemo/core/classes/mixins/hf_io_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from abc import ABC
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union

from huggingface_hub import HfApi, ModelCard, ModelCardData, ModelFilter
from huggingface_hub import HfApi, ModelCard, ModelCardData
from huggingface_hub import get_token as get_hf_token
from huggingface_hub.hf_api import ModelInfo
from huggingface_hub.utils import SoftTemporaryDirectory
Expand All @@ -35,31 +35,35 @@ class HuggingFaceFileIO(ABC):
"""

@classmethod
def get_hf_model_filter(cls) -> ModelFilter:
def get_hf_model_filter(cls) -> Dict[str, Any]:
"""
Generates a filter for HuggingFace models.
Additionally includes default values of some metadata about results returned by the Hub.
Additionaly includes default values of some metadata about results returned by the Hub.
Metadata:
resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False.
limit_results: Optional int, limits the number of results returned.
Returns:
A Hugging Face Hub ModelFilter object.
A dict representing the arguments passable to huggingface list_models().
"""
model_filter = ModelFilter(library='nemo')

# Attach some additional info
model_filter.resolve_card_info = False
model_filter.limit_results = None
model_filter = dict(
author=None,
library='nemo',
language=None,
model_name=None,
task=None,
tags=None,
limit=None,
full=None,
cardData=False,
)

return model_filter

@classmethod
def search_huggingface_models(
cls, model_filter: Optional[Union[ModelFilter, List[ModelFilter]]] = None
) -> List['ModelInfo']:
def search_huggingface_models(cls, model_filter: Optional[Dict[str, Any]] = None) -> Iterable['ModelInfo']:
"""
Should list all pre-trained models available via Hugging Face Hub.
Expand All @@ -75,16 +79,16 @@ def search_huggingface_models(
# You can replace <DomainSubclass> with any subclass of ModelPT.
from nemo.core import ModelPT
# Get default ModelFilter
# Get default filter dict
filt = <DomainSubclass>.get_hf_model_filter()
# Make any modifications to the filter as necessary
filt.language = [...]
filt.task = ...
filt.tags = [...]
filt['language'] = [...]
filt['task'] = ...
filt['tags'] = [...]
# Add any metadata to the filter as needed
filt.limit_results = 5
# Add any metadata to the filter as needed (kwargs to list_models)
filt['limit'] = 5
# Obtain model info
model_infos = <DomainSubclass>.search_huggingface_models(model_filter=filt)
Expand All @@ -96,10 +100,9 @@ def search_huggingface_models(
model = ModelPT.from_pretrained(card.modelId)
Args:
model_filter: Optional ModelFilter or List[ModelFilter] (from Hugging Face Hub)
model_filter: Optional Dictionary (for Hugging Face Hub kwargs)
that filters the returned list of compatible model cards, and selects all results from each filter.
Users can then use `model_card.modelId` in `from_pretrained()` to restore a NeMo Model.
If no ModelFilter is provided, uses the classes default filter as defined by `get_hf_model_filter()`.
Returns:
A list of ModelInfo entries.
Expand All @@ -108,23 +111,6 @@ def search_huggingface_models(
if model_filter is None:
model_filter = cls.get_hf_model_filter()

# If single model filter, wrap into list
if not isinstance(model_filter, Iterable):
model_filter = [model_filter]

# Inject `nemo` library filter
for mfilter in model_filter:
if isinstance(mfilter.library, str) and mfilter.library != 'nemo':
logging.warning(f"Model filter's `library` tag updated be `nemo`. Original value: {mfilter.library}")
mfilter.library = "nemo"

elif isinstance(mfilter, Iterable) and 'nemo' not in mfilter.library:
logging.warning(
f"Model filter's `library` list updated to include `nemo`. Original value: {mfilter.library}"
)
mfilter.library = list(mfilter)
mfilter.library.append('nemo')

# Check if api token exists, use if it does
hf_token = get_hf_token()

Expand All @@ -134,24 +120,11 @@ def search_huggingface_models(
# Setup extra arguments for model filtering
all_results = [] # type: List[ModelInfo]

for mfilter in model_filter:
cardData = None
limit = None

if hasattr(mfilter, 'resolve_card_info') and mfilter.resolve_card_info is True:
cardData = True

if hasattr(mfilter, 'limit_results') and mfilter.limit_results is not None:
limit = mfilter.limit_results

results = api.list_models(
filter=mfilter, token=hf_token, sort="lastModified", direction=-1, cardData=cardData, limit=limit,
) # type: Iterable[ModelInfo]

for result in results:
all_results.append(result)
results = api.list_models(
token=hf_token, sort="lastModified", direction=-1, **model_filter
) # type: Iterable[ModelInfo]

return all_results
return results

def push_to_hf_hub(
self,
Expand Down Expand Up @@ -284,7 +257,10 @@ def _get_hf_model_card(self, template: str, template_kwargs: Optional[Dict[str,
A HuggingFace ModelCard object that can be converted to a model card string.
"""
card_data = ModelCardData(
library_name='nemo', tags=['pytorch', 'NeMo'], license='cc-by-4.0', ignore_metadata_errors=True,
library_name='nemo',
tags=['pytorch', 'NeMo'],
license='cc-by-4.0',
ignore_metadata_errors=True,
)

if 'card_data' not in template_kwargs:
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
fiddle
huggingface_hub>=0.20.3,<0.24.0
huggingface_hub>=0.24
numba
numpy>=1.22
onnx>=1.7.0
Expand Down
16 changes: 9 additions & 7 deletions tests/core/test_save_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import pytest
import torch
from huggingface_hub.hf_api import ModelFilter
from omegaconf import DictConfig, OmegaConf, open_dict

from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE
Expand Down Expand Up @@ -1324,8 +1323,8 @@ class MockModelV2(MockModel):
@pytest.mark.unit
def test_hf_model_filter(self):
filt = ModelPT.get_hf_model_filter()
assert isinstance(filt, ModelFilter)
assert filt.library == 'nemo'
assert isinstance(filt, dict)
assert filt['library'] == 'nemo'

@pytest.mark.with_downloads()
@pytest.mark.unit
Expand All @@ -1334,10 +1333,12 @@ def test_hf_model_info(self):

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=None)
model_infos = [next(model_infos) for _ in range(5)]
assert len(model_infos) > 0

# check with default override results (should match above)
default_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
default_model_infos = [next(default_model_infos) for _ in range(5)]
assert len(model_infos) == len(default_model_infos)

@pytest.mark.pleasefixme()
Expand All @@ -1348,13 +1349,12 @@ def test_hf_model_info_with_card_data(self):

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
model_infos = [next(model_infos) for _ in range(5)]
assert len(model_infos) > 0
assert not hasattr(model_infos[0], 'cardData')

# check overriden defaults
filt.resolve_card_info = True
filt['cardData'] = True
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) > 0

for info in model_infos:
if hasattr(info, 'cardData'):
Expand All @@ -1368,11 +1368,13 @@ def test_hf_model_info_with_limited_results(self):

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
model_infos = [next(model_infos) for _ in range(6)]
assert len(model_infos) > 0

# check overriden defaults
filt.limit_results = 5
filt['limit'] = 5
new_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
new_model_infos = list(new_model_infos)
assert len(new_model_infos) <= 5
assert len(new_model_infos) < len(model_infos)

Expand Down

0 comments on commit 2c589b2

Please sign in to comment.