diff --git a/fuse/data/tokenizers/modular_tokenizer/op.py b/fuse/data/tokenizers/modular_tokenizer/op.py index 445497b45..0f2b85697 100644 --- a/fuse/data/tokenizers/modular_tokenizer/op.py +++ b/fuse/data/tokenizers/modular_tokenizer/op.py @@ -4,11 +4,14 @@ from fuse.data.tokenizers.modular_tokenizer.inject_utils import ( InjectorToModularTokenizerLib, ) +from huggingface_hub import snapshot_download, HfApi +from huggingface_hub.utils import validate_hf_hub_args, SoftTemporaryDirectory from warnings import warn +from pathlib import Path from collections import defaultdict -from typing import Tuple, Optional, Union, Any +from typing import Any, Tuple, Dict, List, Optional, Union import os import re @@ -506,3 +509,91 @@ def __call__( ) return sample_dict + + @classmethod + def from_pretrained( + cls, + identifier: str, + pad_token: str = "", + max_size: Optional[int] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[Dict] = None, + token: Optional[Union[str, bool]] = None, + cache_dir: Optional[Union[str, Path]] = None, + local_files_only: bool = False, + revision: Optional[str] = None, + ) -> "ModularTokenizerOp": + """Load pre-trained tokenizer from HF repo_id or a local dirpath. + + Args: + identifier (str): A repo_id or local dirpath. + pad_token (str, optional): A string of the pad token. Defaults to "". + max_size (Optional[int], optional): Sequences below this size will be padded, and above this size will be truncated. Defaults to None. + * For other args see `snapshot_download()` + """ + if not os.path.isdir(identifier): + # Try to download from hub + try: + # Download 'tokenizer' folder from repo + identifier = snapshot_download( + repo_id=str(identifier), + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + allow_patterns="tokenizer/", + ) + # Redirecting identifier to the downloaded folder + identifier = os.path.join(identifier, "tokenizer") + except Exception as e: + raise Exception( + f"Couldn't find the checkpoint path nor download from HF hub! {identifier}" + ) from e + + tokenizer_op = cls( + tokenizer_path=identifier, pad_token=pad_token, max_size=max_size + ) + return tokenizer_op + + def save_pretrained(self, save_directory: Union[str, Path]) -> None: + print(f"Saving @ {save_directory=}") + self._tokenizer.save(path=str(save_directory)) + + @validate_hf_hub_args + def push_to_hub( + self, + repo_id: str, + *, + commit_message: str = "Push model using huggingface_hub.", + private: bool = False, + token: Optional[str] = None, + branch: Optional[str] = None, + create_pr: Optional[bool] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + ) -> None: + api = HfApi(token=token) + repo_id = api.create_repo( + repo_id=repo_id, private=private, exist_ok=True + ).repo_id + # Push the files to the repo in a single commit + with SoftTemporaryDirectory() as tmp: + saved_path = Path(tmp) / repo_id + tokenzier_dirpath = saved_path / "tokenizer" + self.save_pretrained(tokenzier_dirpath) + return api.upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=saved_path, + commit_message=commit_message, + revision=branch, + create_pr=create_pr, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + delete_patterns=delete_patterns, + ) diff --git a/setup.cfg b/setup.cfg index aa58f1453..fc9b9cc65 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,7 +58,7 @@ exclude = [mypy] -python_version = 3.7 +python_version = 3.9 warn_return_any = True warn_unused_configs = True disallow_untyped_defs = True