From d1e0254ef670fb5fd2c2567bf8f92b766c7d3d8f Mon Sep 17 00:00:00 2001 From: "Moshe Raboh Moshiko.Raboh@ibm.com" Date: Mon, 26 Aug 2024 11:47:35 -0400 Subject: [PATCH] clearml offline mode --- fuse/data/ops/ops_read.py | 3 ++- fuse/dl/lightning/pl_funcs.py | 5 +++++ fuse/dl/lightning/pl_module.py | 3 +++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/fuse/data/ops/ops_read.py b/fuse/data/ops/ops_read.py index 90ac79bdd..4cc8d0837 100644 --- a/fuse/data/ops/ops_read.py +++ b/fuse/data/ops/ops_read.py @@ -136,6 +136,7 @@ def __init__( :param key_index: name of value in sample_dict which will be used as the key/index :param key_column: name of the column which use as key/index. In case of None, the original dataframe index will be used to extract the values for a single sample. """ + super().__init__() # store input self._data_filename = data_filename self._columns_to_extract = columns_to_extract @@ -146,7 +147,7 @@ def __init__( self._h5 = h5py.File(self._data_filename, "r") if self._columns_to_extract is None: - self._columns_to_extract = self._h5.keys() + self._columns_to_extract = list(self._h5.keys()) self._num_samples = len(self._h5[self._columns_to_extract[0]]) diff --git a/fuse/dl/lightning/pl_funcs.py b/fuse/dl/lightning/pl_funcs.py index f33da49f8..e6c37a2cc 100644 --- a/fuse/dl/lightning/pl_funcs.py +++ b/fuse/dl/lightning/pl_funcs.py @@ -55,6 +55,7 @@ def start_clearml_logger( auto_resource_monitoring: bool = True, auto_connect_streams: Union[bool, Mapping[str, bool]] = True, deferred_init: bool = False, + offline_mode: bool = False, ) -> TaskInstance: """ Just a fuse function to quickly start the clearml logger. It sets up patches to pytorch lightning logging hooks so it doesn't need to be passed to any lightning logger. @@ -86,6 +87,10 @@ def start_clearml_logger( bool_start_logger = True if bool_start_logger: + if offline_mode: # Use the set_offline class method before initializing a Task + Task.set_offline(offline_mode=True) + os.environ["CLEARML_OFFLINE_MODE"] = "1" + task = Task.init( project_name=project_name, task_name=task_name, diff --git a/fuse/dl/lightning/pl_module.py b/fuse/dl/lightning/pl_module.py index 003a26d70..73cf3259c 100644 --- a/fuse/dl/lightning/pl_module.py +++ b/fuse/dl/lightning/pl_module.py @@ -176,6 +176,9 @@ def __init__( ## forward def forward(self, batch_dict: NDict) -> NDict: + # workaround for fsdp + if not isinstance(batch_dict, NDict): + batch_dict = NDict(batch_dict) return self._model(batch_dict) ## Step