Lightweight mode in DatasetDefault which allows good GPU utilization in large datasets (e.g. 100M samples) (#177)

* multiprocessing fix and helper func

* path utilities, plus Moshiko's change to solve a hang in tests

* PR comments implemented

* path utils

* changed default deepdiff behavior to ignore NaNs in comparison, added keys(), items() and values() to our NDict, and highlighted the faulting op more clearly in pipeline op errors

* resolved issues raised by static code analysis

* removed unreachable code in paths.py

* * Added "remove_extension" to path utils
* Changed default deepdiff behavior to ignore NaNs in comparison
* Added keys(), items() and values() to our NDict (until now it returned empty iterables for those, which was incorrect)
* Tried to highlight the faulting op more clearly in pipeline op errors

* fixed a bug in head_1D_classifier

* added a lightweight mode of DatasetDefault that doesn't hold any sample_ids; fixed a typo in samplers.py and added a describe() method to NDict

* fixing statically detected issues

* added simple function caching utility

* lightweight DatasetDefault

* fixed static checkers

* fixed issues flagged by static code analysis

* code cleanup

* removed comments

* implemented PR comments

Co-authored-by: Yoel Shoshan <[email protected]>
Co-authored-by: Moshiko Raboh <[email protected]>
3 people authored Sep 29, 2022
1 parent 7d6676b commit ddcbf19
Showing 4 changed files with 97 additions and 23 deletions.
69 changes: 47 additions & 22 deletions fuse/data/datasets/dataset_default.py
@@ -25,7 +25,10 @@
 from fuse.data.pipelines.pipeline_default import PipelineDefault
 from fuse.data.datasets.caching.samples_cacher import SamplesCacher
 from fuse.utils.ndict import NDict
-from fuse.utils.multiprocessing.run_multiprocessed import run_multiprocessed, get_from_global_storage
+from fuse.utils.multiprocessing.run_multiprocessed import (
+    run_multiprocessed,
+    get_from_global_storage,
+)
 from fuse.data import get_sample_id, create_initial_sample, get_specific_sample_from_potentially_morphed
 import copy
 from collections import OrderedDict
@@ -35,14 +38,17 @@
 class DatasetDefault(DatasetBase):
     def __init__(
         self,
-        sample_ids: Sequence[Hashable],
+        sample_ids: Union[int, Sequence[Hashable]],
         static_pipeline: Optional[PipelineDefault] = None,
         dynamic_pipeline: Optional[PipelineDefault] = None,
         cacher: Optional[SamplesCacher] = None,
         allow_uncached_sample_morphing: bool = False,
     ):
         """
         :param sample_ids: list of sample_ids included in dataset.
+            Optionally, you can provide an integer that describes only the size of the dataset. This is useful in massive datasets
+            (for example 100M samples). In such case, multiple functionalities will not be supported, mainly -
+            cacher, allow_uncached_sample_morphing and get_all_sample_ids
         :param static_pipeline: static_pipeline, the output of this pipeline will be automatically cached.
         :param dynamic_pipeline: dynamic_pipeline. applied sequentially after the static_pipeline, but not automatically cached.
             changing it will NOT trigger recaching of the static_pipeline part.
@@ -53,39 +59,52 @@ def __init__(
         super().__init__()

         # store arguments
-        self._static_pipeline = static_pipeline
-        self._dynamic_pipeline = dynamic_pipeline
         self._cacher = cacher
-        self._orig_sample_ids = sample_ids
+        if isinstance(sample_ids, (int, np.integer)):
+            if allow_uncached_sample_morphing:
+                raise Exception(
+                    "allow_uncached_sample_morphing is not allowed when providing sample_ids=an integer value"
+                )
+            if cacher is not None:
+                raise Exception("providing a cacher is not allowed when providing sample_ids=an integer value")
+            self._explicit_sample_ids_mode = False
+        else:
+            self._explicit_sample_ids_mode = True
+
+        # self._orig_sample_ids = sample_ids
         self._allow_uncached_sample_morphing = allow_uncached_sample_morphing

         # verify unique names for dynamic pipelines
-        if self._dynamic_pipeline is not None and self._static_pipeline is not None:
-            if self._static_pipeline.get_name() == self._dynamic_pipeline.get_name():
+        if dynamic_pipeline is not None and static_pipeline is not None:
+            if static_pipeline.get_name() == dynamic_pipeline.get_name():
                 raise Exception(
-                    f"Detected identical name for static pipeline and dynamic pipeline ({self._static_pipeline.get_name(self._static_pipeline.get_name())}).\nThis is not allowed, please initiate the pipelines with different names."
+                    f"Detected identical name for static pipeline and dynamic pipeline ({static_pipeline.get_name(static_pipeline.get_name())}).\nThis is not allowed, please initiate the pipelines with different names."
                 )

-        if self._static_pipeline is None:
-            self._static_pipeline = PipelineDefault("dummy_static_pipeline", ops_and_kwargs=[])
-        if self._dynamic_pipeline is None:
-            self._dynamic_pipeline = PipelineDefault("dummy_dynamic_pipeline", ops_and_kwargs=[])
+        if static_pipeline is None:
+            static_pipeline = PipelineDefault("dummy_static_pipeline", ops_and_kwargs=[])
+        if dynamic_pipeline is None:
+            dynamic_pipeline = PipelineDefault("dummy_dynamic_pipeline", ops_and_kwargs=[])

-        if self._dynamic_pipeline is not None:
+        if dynamic_pipeline is not None:
             assert isinstance(
-                self._dynamic_pipeline, PipelineDefault
-            ), f"dynamic_pipeline may be None or a PipelineDefault instance. Instead got {type(self._dynamic_pipeline)}"
+                dynamic_pipeline, PipelineDefault
+            ), f"dynamic_pipeline may be None or a PipelineDefault instance. Instead got {type(dynamic_pipeline)}"

-        if self._static_pipeline is not None:
+        if static_pipeline is not None:
             assert isinstance(
-                self._static_pipeline, PipelineDefault
-            ), f"static_pipeline may be None or a PipelineDefault instance. Instead got {type(self._static_pipeline)}"
+                static_pipeline, PipelineDefault
+            ), f"static_pipeline may be None or a PipelineDefault instance. Instead got {type(static_pipeline)}"

         if self._allow_uncached_sample_morphing:
             warn(
                 "allow_uncached_sample_morphing is enabled! It is a significantly slower mode and should be used ONLY for debugging"
             )

+        self._static_pipeline = static_pipeline
+        self._dynamic_pipeline = dynamic_pipeline
+        self._orig_sample_ids = copy.deepcopy(sample_ids)
+
         self._created = False

     def create(self, num_workers: int = 0, mp_context: Optional[str] = None) -> None:
@@ -126,14 +145,17 @@ def create(self, num_workers: int = 0, mp_context: Optional[str] = None) -> None:
                    continue
                self._final_sample_ids.extend(out_sids)
         else:
-            self._final_sample_ids = copy.deepcopy(self._orig_sample_ids)
+            self._final_sample_ids = self._orig_sample_ids

         self._created = True

     def get_all_sample_ids(self):
         if not self._created:
             raise Exception("you must first call create()")

+        if not self._explicit_sample_ids_mode:
+            raise Exception("get_all_sample_ids is not supported when constructed with an integer for sample_ids")
+
         return copy.deepcopy(self._final_sample_ids)

     def __getitem__(self, item: Union[int, Hashable]) -> dict:
@@ -161,10 +183,10 @@ def getitem(
raise Exception("you must first call create()")

# get sample id
if isinstance(item, (int, np.integer)):
sample_id = self._final_sample_ids[item]
else:
if not isinstance(item, (int, np.integer)) or not self._explicit_sample_ids_mode:
sample_id = item
else:
sample_id = self._final_sample_ids[item]

# get collect marker info
collect_marker_info = self._get_collect_marker_info(collect_marker_name)
Expand Down Expand Up @@ -249,6 +271,9 @@ def __len__(self):
         if not self._created:
             raise Exception("you must first call create()")

+        if not self._explicit_sample_ids_mode:
+            return self._final_sample_ids
+
         return len(self._final_sample_ids)

     # internal methods
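For illustration, a minimal usage sketch of the new lightweight mode; the pipeline name and its empty op list are placeholders, not part of this diff:

```python
from fuse.data.datasets.dataset_default import DatasetDefault
from fuse.data.pipelines.pipeline_default import PipelineDefault

# Pass only the dataset size instead of an explicit list of sample_ids.
# Each integer index then serves directly as its own sample_id, so no
# 100M-element id list is ever held in memory.
dataset = DatasetDefault(
    sample_ids=100_000_000,  # size only; cacher and sample morphing are not allowed in this mode
    dynamic_pipeline=PipelineDefault("my_dynamic_pipeline", ops_and_kwargs=[]),
)
dataset.create()

print(len(dataset))  # 100000000
sample = dataset[42]  # the index itself is used as the sample_id
```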
41 changes: 41 additions & 0 deletions fuse/data/ops/caching_tools.py
@@ -2,6 +2,10 @@
 from typing import Callable, Any, Type, Optional, Sequence
 from inspect import stack
 import warnings
+from fuse.utils.file_io.file_io import load_pickle, save_pickle_safe
+import os
+from fuse.utils.cpu_profiling import Timer
+import hashlib


 def get_function_call_str(func, *_args, **_kwargs) -> str:
@@ -142,3 +146,40 @@ def __init__(self, blah, blah2):
     del curr_stack

     return str_desc
+
+
+# TODO: consider adding "ignore list" of args that should not participate in cache value calculation (for example - "verbose")
+def run_cached_func(cache_dir: str, func: Callable, *args, **kwargs) -> Any:
+    """
+    Will cache the function output the first time that
+    it is executed, and will load from cache on subsequent calls.
+    The cache hash value is based on the function name, the args, and the code of the function.
+    Args:
+    :param cache_dir: the directory into which caches will be stored/loaded
+    :param func: the function to run
+    :param *args: positional args to provide to the function
+    :param **kwargs: kwargs to provide to the function
+    """
+    os.makedirs(cache_dir, exist_ok=True)
+    call_str = get_function_call_str(func, *args, **kwargs)
+    call_hash = hashlib.md5(call_str.encode("utf-8")).hexdigest()
+
+    cache_full_file_path = os.path.join(cache_dir, call_hash + ".pkl.gz")
+    print(f"cache_full_file_path={cache_full_file_path}")
+
+    if os.path.isfile(cache_full_file_path):
+        with Timer(f"loading {cache_full_file_path}"):
+            ans = load_pickle(cache_full_file_path)
+        return ans
+
+    with Timer("run_cached_func::running func ..."):
+        ans = func(*args, **kwargs)
+
+    with Timer(f"saving {cache_full_file_path}"):
+        save_pickle_safe(ans, cache_full_file_path, compress=True)
+
+    return ans
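For illustration, a usage sketch of the new run_cached_func helper; expensive_preprocess and the cache directory are hypothetical:

```python
from fuse.data.ops.caching_tools import run_cached_func

def expensive_preprocess(n: int, scale: float = 1.0) -> list:
    # hypothetical stand-in for a slow, deterministic computation
    return [i * scale for i in range(n)]

# The first call runs the function and saves <md5-of-call>.pkl.gz under cache_dir;
# later calls with the same arguments and unchanged function code load from cache.
ans = run_cached_func("/tmp/fuse_cache", expensive_preprocess, 1_000_000, scale=0.5)
```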
2 changes: 1 addition & 1 deletion fuse/data/utils/samplers.py
@@ -128,7 +128,7 @@ def __init__(
         if self._mode == "exact":
             if self._batch_size % self._num_balanced_classes:
                 raise Exception(
-                    f"Error: num_balanced_class ({self._num_balanced_classes}) should devide batch_size ({self._batch_size}) in exact mode."
+                    f"Error: num_balanced_class ({self._num_balanced_classes}) should divide batch_size ({self._batch_size}) in exact mode."
                 )

         self._balanced_class_weights = [
8 changes: 8 additions & 0 deletions fuse/utils/ndict.py
@@ -311,6 +311,14 @@ def _print_tree_static(data_dict: dict, level: int = 0) -> None:
                 print("---" * level, key)
                 NDict._print_tree_static(data_dict[key], level)

+    def describe(self) -> None:
+        for k in self.keypaths():
+            print(f"{k}")
+            val = self[k]
+            print(f"\ttype={type(val)}")
+            if hasattr(val, "shape"):
+                print(f"\tshape={val.shape}")
+

 class NestedKeyError(KeyError):
     def __init__(self, key: str, d: NDict) -> None:
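For illustration, a small sketch of the new NDict.describe() in action; the keypaths and values are made up:

```python
import numpy as np
from fuse.utils.ndict import NDict

sample = NDict()
sample["data.input.img"] = np.zeros((3, 224, 224))
sample["data.gt.label"] = 1

# prints every keypath, its value type, and a shape when one is available
sample.describe()
```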
