diff --git a/disent/util/config.py b/disent/util/config.py deleted file mode 100644 index d6962ef4..00000000 --- a/disent/util/config.py +++ /dev/null @@ -1,73 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - -from deprecated import deprecated - - -# ========================================================================= # -# Recursive Hydra Instantiation # -# TODO: use https://github.com/facebookresearch/hydra/pull/989 # -# I think this is quicker? Just doesn't perform checks... # -# ========================================================================= # - - -@deprecated('replace with hydra 1.1') -def call_recursive(config): - # import hydra - try: - import hydra - from omegaconf import DictConfig - from omegaconf import ListConfig - except ImportError: - raise ImportError('please install hydra-core for call_recursive/instantiate_recursive support') - # recurse - def _call_recursive(config): - if isinstance(config, (dict, DictConfig)): - c = {k: _call_recursive(v) for k, v in config.items() if k != '_target_'} - if '_target_' in config: - config = hydra.utils.instantiate({'_target_': config['_target_']}, **c) - elif isinstance(config, (tuple, list, ListConfig)): - config = [_call_recursive(v) for v in config] - return config - return _call_recursive(config) - - -# alias -@deprecated('replace with hydra 1.1') -def instantiate_recursive(config): - return call_recursive(config) - - -@deprecated('replace with hydra 1.1') -def instantiate_object_if_needed(config_or_object): - if isinstance(config_or_object, dict): - return instantiate_recursive(config_or_object) - else: - return config_or_object - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/experiment/run.py b/experiment/run.py index 15ff09d9..537a73c2 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -29,7 +29,6 @@ import pytorch_lightning as pl import torch import torch.utils.data -from disent.util.config import instantiate_recursive from omegaconf import DictConfig from omegaconf import OmegaConf from pytorch_lightning.loggers import CometLogger @@ -49,6 +48,7 @@ from experiment.util.hydra_data import HydraDataModule from experiment.util.hydra_utils import make_non_strict from experiment.util.hydra_utils import merge_specializations +from experiment.util.hydra_utils import instantiate_recursive from experiment.util.run_utils import log_error_and_exit from experiment.util.run_utils import set_debug_logger from experiment.util.run_utils import set_debug_trainer diff --git a/experiment/util/hydra_data.py b/experiment/util/hydra_data.py index 634a8b0d..cf1d54b1 100644 --- a/experiment/util/hydra_data.py +++ b/experiment/util/hydra_data.py @@ -29,7 +29,7 @@ from disent.dataset import DisentDataset from disent.nn.transform import DisentDatasetTransform -from disent.util.config import instantiate_recursive +from experiment.util.hydra_utils import instantiate_recursive # ========================================================================= # diff --git a/experiment/util/hydra_utils.py b/experiment/util/hydra_utils.py index 02c0ec9f..8dfa3f33 100644 --- a/experiment/util/hydra_utils.py +++ b/experiment/util/hydra_utils.py @@ -24,14 +24,43 @@ import logging +import hydra from deprecated import deprecated from omegaconf import DictConfig +from omegaconf import ListConfig from omegaconf import OmegaConf log = logging.getLogger(__name__) +# ========================================================================= # +# Recursive Hydra Instantiation # +# TODO: use https://github.com/facebookresearch/hydra/pull/989 # +# I think this is quicker? Just doesn't perform checks... # +# ========================================================================= # + + +@deprecated('replace with hydra 1.1') +def call_recursive(config): + # recurse + def _call_recursive(config): + if isinstance(config, (dict, DictConfig)): + c = {k: _call_recursive(v) for k, v in config.items() if k != '_target_'} + if '_target_' in config: + config = hydra.utils.instantiate({'_target_': config['_target_']}, **c) + elif isinstance(config, (tuple, list, ListConfig)): + config = [_call_recursive(v) for v in config] + return config + return _call_recursive(config) + + +# alias +@deprecated('replace with hydra 1.1') +def instantiate_recursive(config): + return call_recursive(config) + + # ========================================================================= # # Better Specializations # # TODO: this might be replaced by recursive instantiation #