Skip to content

Commit

Permalink
No cache in smart update (#2952)
Browse the repository at this point in the history
* no cache in smart update

* update tests
  • Loading branch information
lhoestq authored Jun 26, 2024
1 parent 21cea15 commit ae00789
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 42 deletions.
40 changes: 15 additions & 25 deletions libs/libcommon/src/libcommon/orchestrator.py
Original file line number Diff line number Diff line change
@@ -7,12 +7,11 @@
from dataclasses import dataclass, field
from functools import lru_cache
from http import HTTPStatus
from pathlib import Path
from typing import Optional, Union

import pandas as pd
from huggingface_hub import DatasetCard, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, build_hf_headers, get_session
from huggingface_hub import DatasetCard, HfFileSystem
from huggingface_hub.utils import build_hf_headers, get_session

from libcommon.constants import (
CONFIG_INFO_KIND,
@@ -926,33 +925,24 @@ def get_impacted_files(self) -> set[str]:
def get_updated_yaml_fields_in_dataset_card(self) -> list[str]:
if "README.md" not in self.files_impacted_by_commit:
return []
fs = HfFileSystem(endpoint=self.hf_endpoint, token=self.hf_token)
try:
with Path(
hf_hub_download(
self.dataset,
"README.md",
repo_type="dataset",
token=self.hf_token,
revision=self.revision,
endpoint=self.hf_endpoint,
)
).open(mode="r", newline="", encoding="utf-8") as f:
with fs.open(
f"datasets/{self.dataset}/README.md", revision=self.revision, mode="r", newline="", encoding="utf-8"
) as f:
dataset_card_data_dict = DatasetCard(f.read()).data.to_dict()
except EntryNotFoundError: # catch file not found but raise on parsing error
except FileNotFoundError: # catch file not found but raise on parsing error
dataset_card_data_dict = {}
try:
with Path(
hf_hub_download(
self.dataset,
"README.md",
repo_type="dataset",
token=self.hf_token,
revision=self.old_revision,
endpoint=self.hf_endpoint,
)
).open(mode="r", newline="", encoding="utf-8") as f:
with fs.open(
f"datasets/{self.dataset}/README.md",
revision=self.old_revision,
mode="r",
newline="",
encoding="utf-8",
) as f:
old_dataset_card_data_dict = DatasetCard(f.read()).data.to_dict()
except EntryNotFoundError: # catch file not found but raise on parsing error
except FileNotFoundError: # catch file not found but raise on parsing error
old_dataset_card_data_dict = {}
return [
yaml_field
Expand Down
15 changes: 10 additions & 5 deletions libs/libcommon/tests/test_orchestrator_smart_update.py
Original file line number Diff line number Diff line change
@@ -149,7 +149,7 @@ def test_cache_revision_is_not_parent_revision_commit() -> None:
def test_empty_commit() -> None:
# Empty commit: update the revision of the cache entries
put_cache(step=STEP_DA, dataset=DATASET_NAME, revision=OTHER_REVISION_NAME)
with put_diff(EMPTY_DIFF):
with put_diff(EMPTY_DIFF), put_readme(None):
plan = get_smart_dataset_update_plan(processing_graph=PROCESSING_GRAPH_TWO_STEPS)
assert_smart_dataset_update_plan(
plan,
@@ -186,7 +186,7 @@ def test_add_initial_readme_with_config_commit() -> None:
def test_add_data() -> None:
# Add data.txt commit: raise
put_cache(step=STEP_DA, dataset=DATASET_NAME, revision=OTHER_REVISION_NAME)
with put_diff(ADD_DATA_DIFF):
with put_diff(ADD_DATA_DIFF), put_readme(None):
with pytest.raises(SmartUpdateImpossibleBecauseOfUpdatedFiles):
get_smart_dataset_update_plan(processing_graph=PROCESSING_GRAPH_TWO_STEPS)

@@ -231,7 +231,7 @@ def test_add_tag_commit() -> None:

def test_run() -> None:
put_cache(step=STEP_DA, dataset=DATASET_NAME, revision=OTHER_REVISION_NAME)
with put_diff(EMPTY_DIFF):
with put_diff(EMPTY_DIFF), put_readme(None):
tasks_stats = get_smart_dataset_update_plan(processing_graph=PROCESSING_GRAPH_TWO_STEPS).run()
assert tasks_stats.num_created_jobs == 0
assert tasks_stats.num_updated_cache_entries == 1
@@ -260,7 +260,7 @@ def test_run_with_storage_clients(storage_client: StorageClient) -> None:
storage_client._fs.touch(storage_client.get_full_path(previous_key))
assert storage_client.exists(previous_key)
put_cache(step=STEP_DA, dataset=DATASET_NAME, revision=OTHER_REVISION_NAME)
with put_diff(EMPTY_DIFF):
with put_diff(EMPTY_DIFF), put_readme(None):
tasks_stats = get_smart_dataset_update_plan(
processing_graph=PROCESSING_GRAPH_TWO_STEPS, storage_clients=[storage_client]
).run()
@@ -291,7 +291,7 @@ def test_run_with_storage_clients(storage_client: StorageClient) -> None:
@pytest.mark.parametrize("out_of_order", [False, True])
def test_run_two_commits(out_of_order: bool) -> None:
put_cache(step=STEP_DA, dataset=DATASET_NAME, revision="initial_revision")
with put_diff(ADD_TAG_DIFF, revision=REVISION_NAME), put_diff(ADD_TAG_DIFF, revision=OTHER_REVISION_NAME):
with (
put_diff(ADD_TAG_DIFF, revision=REVISION_NAME),
put_diff(ADD_TAG_DIFF, revision=OTHER_REVISION_NAME),
put_readme(None, revision=REVISION_NAME),
put_readme(None, revision=OTHER_REVISION_NAME),
):

def run_plan(revisions: tuple[str, str]) -> TasksStatistics:
old_revision, revision = revisions
Expand Down
32 changes: 20 additions & 12 deletions libs/libcommon/tests/utils.py
Original file line number Diff line number Diff line change
@@ -399,24 +399,32 @@ def mock_get_diff(self: SmartDatasetUpdatePlan) -> str:

@contextmanager
def put_readme(
readme: str,
readme: Optional[str],
dataset: str = DATASET_NAME,
revision: str = REVISION_NAME,
) -> Iterator[None]:
from huggingface_hub import HfFileSystem

mocked_revision = revision
from libcommon.orchestrator import hf_hub_download as original_hf_hub_download # type: ignore[attr-defined]
original_open = HfFileSystem.open

with NamedTemporaryFile() as temp_file:
Path(temp_file.name).write_text(readme, encoding="utf-8")

def mock_hf_hub_download(repo_id: str, filename: str, revision: str, **kwargs: Any) -> str:
if filename == "README.md" and dataset == repo_id and mocked_revision == revision:
return temp_file.name
out = original_hf_hub_download(repo_id, filename, revision=revision, **kwargs)
assert isinstance(out, str)
return out

with patch("libcommon.orchestrator.hf_hub_download", mock_hf_hub_download):
if readme is not None:
Path(temp_file.name).write_text(readme, encoding="utf-8")

def maybe_open_readme(self: HfFileSystem, path: str, mode: str, **kwargs: Any) -> Any:
revision = kwargs.pop("revision")
if path == "datasets/" + dataset + "/README.md":
if readme is not None and mocked_revision == revision:
return open(temp_file.name, mode, **kwargs)
else:
try:
return original_open(self, path, mode, revision=revision, **kwargs)
except Exception:
pass
raise FileNotFoundError(path + f" at {revision=}")

with patch("libcommon.orchestrator.HfFileSystem.open", maybe_open_readme):
yield


Expand Down

0 comments on commit ae00789

Please sign in to comment.