Skip to content

Commit

Permalink
Fix hf hub for 0.24+ (#9806)
Browse files Browse the repository at this point in the history
* Update Huggingface Hub support

Signed-off-by: smajumdar <[email protected]>

* Update hf hub

Signed-off-by: smajumdar <[email protected]>

* Update hf hub

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Somshubra Majumdar <[email protected]>
Signed-off-by: titu1994 <[email protected]>
  • Loading branch information
titu1994 authored Jul 23, 2024
1 parent b901138 commit 6ff5bce
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 83 deletions.
88 changes: 32 additions & 56 deletions nemo/core/classes/mixins/hf_io_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from abc import ABC
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union

from huggingface_hub import HfApi, ModelCard, ModelCardData, ModelFilter
from huggingface_hub import HfApi, ModelCard, ModelCardData
from huggingface_hub import get_token as get_hf_token
from huggingface_hub.hf_api import ModelInfo
from huggingface_hub.utils import SoftTemporaryDirectory
Expand All @@ -35,31 +35,35 @@ class HuggingFaceFileIO(ABC):
"""

@classmethod
def get_hf_model_filter(cls) -> ModelFilter:
def get_hf_model_filter(cls) -> Dict[str, Any]:
"""
Generates a filter for HuggingFace models.
Additionally includes default values of some metadata about results returned by the Hub.
Additionally includes default values of some metadata about results returned by the Hub.
Metadata:
resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False.
limit_results: Optional int, limits the number of results returned.
Returns:
A Hugging Face Hub ModelFilter object.
A dict representing the arguments passable to huggingface list_models().
"""
model_filter = ModelFilter(library='nemo')

# Attach some additional info
model_filter.resolve_card_info = False
model_filter.limit_results = None
model_filter = dict(
author=None,
library='nemo',
language=None,
model_name=None,
task=None,
tags=None,
limit=None,
full=None,
cardData=False,
)

return model_filter

@classmethod
def search_huggingface_models(
cls, model_filter: Optional[Union[ModelFilter, List[ModelFilter]]] = None
) -> List['ModelInfo']:
def search_huggingface_models(cls, model_filter: Optional[Dict[str, Any]] = None) -> Iterable['ModelInfo']:
"""
Should list all pre-trained models available via Hugging Face Hub.
Expand All @@ -75,16 +79,16 @@ def search_huggingface_models(
# You can replace <DomainSubclass> with any subclass of ModelPT.
from nemo.core import ModelPT
# Get default ModelFilter
# Get default filter dict
filt = <DomainSubclass>.get_hf_model_filter()
# Make any modifications to the filter as necessary
filt.language = [...]
filt.task = ...
filt.tags = [...]
filt['language'] = [...]
filt['task'] = ...
filt['tags'] = [...]
# Add any metadata to the filter as needed
filt.limit_results = 5
# Add any metadata to the filter as needed (kwargs to list_models)
filt['limit'] = 5
# Obtain model info
model_infos = <DomainSubclass>.search_huggingface_models(model_filter=filt)
Expand All @@ -96,10 +100,9 @@ def search_huggingface_models(
model = ModelPT.from_pretrained(card.modelId)
Args:
model_filter: Optional ModelFilter or List[ModelFilter] (from Hugging Face Hub)
model_filter: Optional Dictionary (for Hugging Face Hub kwargs)
that filters the returned list of compatible model cards, and selects all results from each filter.
Users can then use `model_card.modelId` in `from_pretrained()` to restore a NeMo Model.
If no ModelFilter is provided, uses the classes default filter as defined by `get_hf_model_filter()`.
Returns:
An iterable of ModelInfo entries.
Expand All @@ -108,23 +111,6 @@ def search_huggingface_models(
if model_filter is None:
model_filter = cls.get_hf_model_filter()

# If single model filter, wrap into list
if not isinstance(model_filter, Iterable):
model_filter = [model_filter]

# Inject `nemo` library filter
for mfilter in model_filter:
if isinstance(mfilter.library, str) and mfilter.library != 'nemo':
logging.warning(f"Model filter's `library` tag updated be `nemo`. Original value: {mfilter.library}")
mfilter.library = "nemo"

elif isinstance(mfilter, Iterable) and 'nemo' not in mfilter.library:
logging.warning(
f"Model filter's `library` list updated to include `nemo`. Original value: {mfilter.library}"
)
mfilter.library = list(mfilter)
mfilter.library.append('nemo')

# Check if api token exists, use if it does
hf_token = get_hf_token()

Expand All @@ -134,24 +120,11 @@ def search_huggingface_models(
# Setup extra arguments for model filtering
all_results = [] # type: List[ModelInfo]

for mfilter in model_filter:
cardData = None
limit = None

if hasattr(mfilter, 'resolve_card_info') and mfilter.resolve_card_info is True:
cardData = True

if hasattr(mfilter, 'limit_results') and mfilter.limit_results is not None:
limit = mfilter.limit_results

results = api.list_models(
filter=mfilter, token=hf_token, sort="lastModified", direction=-1, cardData=cardData, limit=limit,
) # type: Iterable[ModelInfo]

for result in results:
all_results.append(result)
results = api.list_models(
token=hf_token, sort="lastModified", direction=-1, **model_filter
) # type: Iterable[ModelInfo]

return all_results
return results

def push_to_hf_hub(
self,
Expand Down Expand Up @@ -284,7 +257,10 @@ def _get_hf_model_card(self, template: str, template_kwargs: Optional[Dict[str,
A HuggingFace ModelCard object that can be converted to a model card string.
"""
card_data = ModelCardData(
library_name='nemo', tags=['pytorch', 'NeMo'], license='cc-by-4.0', ignore_metadata_errors=True,
library_name='nemo',
tags=['pytorch', 'NeMo'],
license='cc-by-4.0',
ignore_metadata_errors=True,
)

if 'card_data' not in template_kwargs:
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
fiddle
huggingface_hub>=0.20.3,<0.24.0
huggingface_hub>=0.24
numba
numpy>=1.22
onnx>=1.7.0
Expand Down
73 changes: 47 additions & 26 deletions tests/core/test_save_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import pytest
import torch
from huggingface_hub.hf_api import ModelFilter
from omegaconf import DictConfig, OmegaConf, open_dict

from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE
Expand Down Expand Up @@ -126,11 +125,15 @@ def __init__(self, cfg, trainer=None):
self.child1_model: Optional[MockModel] # annotate type for IDE autocompletion and type checking
if cfg.get("child1_model") is not None:
self.register_nemo_submodule(
"child1_model", config_field="child1_model", model=MockModel(self.cfg.child1_model),
"child1_model",
config_field="child1_model",
model=MockModel(self.cfg.child1_model),
)
elif cfg.get("child1_model_path") is not None:
self.register_nemo_submodule(
"child1_model", config_field="child1_model", model=MockModel.restore_from(self.cfg.child1_model_path),
"child1_model",
config_field="child1_model",
model=MockModel.restore_from(self.cfg.child1_model_path),
)
else:
self.child1_model = None
Expand All @@ -140,7 +143,9 @@ def __init__(self, cfg, trainer=None):
self.child2_model: Optional[MockModelWithChildren] # annotate type for IDE autocompletion and type checking
if cfg.get("child2_model") is not None:
self.register_nemo_submodule(
"child2_model", config_field="child2_model", model=MockModelWithChildren(self.cfg.child2_model),
"child2_model",
config_field="child2_model",
model=MockModelWithChildren(self.cfg.child2_model),
)
elif cfg.get("child2_model_path") is not None:
self.register_nemo_submodule(
Expand Down Expand Up @@ -169,7 +174,9 @@ def __init__(self, cfg, trainer=None):

if cfg.get("ctc_model", None) is not None:
self.register_nemo_submodule(
"ctc_model", config_field="ctc_model", model=EncDecCTCModelBPE(self.cfg.ctc_model),
"ctc_model",
config_field="ctc_model",
model=EncDecCTCModelBPE(self.cfg.ctc_model),
)
else:
# model is mandatory
Expand All @@ -196,7 +203,9 @@ def __init__(self, cfg, trainer=None):
self.child1_model: Optional[MockModel] # annotate type for IDE autocompletion and type checking
if cfg.get("child1_model_config") is not None:
self.register_nemo_submodule(
"child1_model", config_field="child1_model_config", model=MockModel(self.cfg.child1_model_config),
"child1_model",
config_field="child1_model_config",
model=MockModel(self.cfg.child1_model_config),
)
else:
self.child1_model = None
Expand Down Expand Up @@ -900,11 +909,12 @@ def test_mock_model_nested_with_resources(self, change_child_resource: bool, chi
child2_model_from_path: if child2_model_from_path is True, child2 model is restored from .nemo checkpoint,
otherwise constructed directly from config. Child1 model always loaded from checkpoint.
"""
with tempfile.NamedTemporaryFile('w') as file_child1, tempfile.NamedTemporaryFile(
'w'
) as file_child2, tempfile.NamedTemporaryFile('w') as file_child2_other, tempfile.NamedTemporaryFile(
'w'
) as file_parent:
with (
tempfile.NamedTemporaryFile('w') as file_child1,
tempfile.NamedTemporaryFile('w') as file_child2,
tempfile.NamedTemporaryFile('w') as file_child2_other,
tempfile.NamedTemporaryFile('w') as file_parent,
):
# write text data, use these files as resources
parent_data = ["*****\n"]
child1_data = ["+++++\n"]
Expand Down Expand Up @@ -988,11 +998,12 @@ def test_mock_model_nested_with_resources_multiple_passes(self):
Test nested model with 2 children: multiple save-restore passes
child models and parent model itself contain resources
"""
with tempfile.NamedTemporaryFile('w') as file_child1, tempfile.NamedTemporaryFile(
'w'
) as file_child2, tempfile.NamedTemporaryFile('w') as file_child2_other, tempfile.NamedTemporaryFile(
'w'
) as file_parent:
with (
tempfile.NamedTemporaryFile('w') as file_child1,
tempfile.NamedTemporaryFile('w') as file_child2,
tempfile.NamedTemporaryFile('w') as file_child2_other,
tempfile.NamedTemporaryFile('w') as file_parent,
):
# write text data, use these files as resources
parent_data = ["*****\n"]
child1_data = ["+++++\n"]
Expand All @@ -1019,7 +1030,12 @@ def test_mock_model_nested_with_resources_multiple_passes(self):
child2 = MockModelWithChildren(cfg=cfg_child2.model, trainer=None)
child2 = child2.to('cpu')

with tempfile.TemporaryDirectory() as tmpdir_parent1, tempfile.TemporaryDirectory() as tmpdir_parent2, tempfile.TemporaryDirectory() as tmpdir_parent3, tempfile.TemporaryDirectory() as tmpdir_parent4:
with (
tempfile.TemporaryDirectory() as tmpdir_parent1,
tempfile.TemporaryDirectory() as tmpdir_parent2,
tempfile.TemporaryDirectory() as tmpdir_parent3,
tempfile.TemporaryDirectory() as tmpdir_parent4,
):
parent_path1 = os.path.join(tmpdir_parent1, "parent.nemo")
parent_path2 = os.path.join(tmpdir_parent2, "parent.nemo")
with tempfile.TemporaryDirectory() as tmpdir_child:
Expand Down Expand Up @@ -1074,9 +1090,11 @@ def test_mock_model_nested_double_with_resources(self):
test nested model: parent -> child_with_child -> child; model and each child can be saved/restored separately
all models can contain resources
"""
with tempfile.NamedTemporaryFile('w') as file_child, tempfile.NamedTemporaryFile(
'w'
) as file_child_with_child, tempfile.NamedTemporaryFile('w') as file_parent:
with (
tempfile.NamedTemporaryFile('w') as file_child,
tempfile.NamedTemporaryFile('w') as file_child_with_child,
tempfile.NamedTemporaryFile('w') as file_parent,
):
# write text data, use these files as resources
parent_data = ["*****\n"]
child_with_child_data = ["+++++\n"]
Expand Down Expand Up @@ -1302,8 +1320,8 @@ class MockModelV2(MockModel):
@pytest.mark.unit
def test_hf_model_filter(self):
filt = ModelPT.get_hf_model_filter()
assert isinstance(filt, ModelFilter)
assert filt.library == 'nemo'
assert isinstance(filt, dict)
assert filt['library'] == 'nemo'

@pytest.mark.with_downloads()
@pytest.mark.unit
Expand All @@ -1312,10 +1330,12 @@ def test_hf_model_info(self):

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=None)
model_infos = [next(model_infos) for _ in range(5)]
assert len(model_infos) > 0

# check with default override results (should match above)
default_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
default_model_infos = [next(default_model_infos) for _ in range(5)]
assert len(model_infos) == len(default_model_infos)

@pytest.mark.pleasefixme()
Expand All @@ -1326,13 +1346,12 @@ def test_hf_model_info_with_card_data(self):

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
model_infos = [next(model_infos) for _ in range(5)]
assert len(model_infos) > 0
assert not hasattr(model_infos[0], 'cardData')

# check overridden defaults
filt.resolve_card_info = True
filt['cardData'] = True
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
assert len(model_infos) > 0

for info in model_infos:
if hasattr(info, 'cardData'):
Expand All @@ -1346,10 +1365,12 @@ def test_hf_model_info_with_limited_results(self):

# check no override results
model_infos = ModelPT.search_huggingface_models(model_filter=filt)
model_infos = [next(model_infos) for _ in range(6)]
assert len(model_infos) > 0

# check overridden defaults
filt.limit_results = 5
filt['limit'] = 5
new_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
new_model_infos = list(new_model_infos)
assert len(new_model_infos) <= 5
assert len(new_model_infos) < len(model_infos)

0 comments on commit 6ff5bce

Please sign in to comment.