@@ -29,7 +29,7 @@
- Visit the docs for more info, or browse the releases.
+ Visit the docs for more info, or browse the releases.
Contributions are welcome!
@@ -42,10 +42,11 @@
- [Overview](#overview)
- [Features](#features)
+ * [Datasets](#datasets)
* [Frameworks](#frameworks)
* [Metrics](#metrics)
- * [Datasets](#datasets)
* [Schedules & Annealing](#schedules--annealing)
+- [Architecture](#architecture)
- [Examples](#examples)
* [Python Example](#python-example)
* [Hydra Config Example](#hydra-config-example)
@@ -88,65 +89,94 @@ Please use the following citation if you use Disent in your own research:
----------------------
-## Architecture
-
-The disent module structure:
+## Features
-- `disent.dataset`: dataset wrappers, datasets & sampling strategies
- + `disent.dataset.data`: raw datasets
- + `disent.dataset.sampling`: sampling strategies for `DisentDataset` when multiple elements are required by frameworks, eg. for triplet loss
- + `disent.dataset.transform`: common data transforms and augmentations
- + `disent.dataset.wrapper`: wrapped datasets are no longer ground-truth datasets, these may have some elements masked out. We can still unwrap these classes to obtain the original datasets for benchmarking.
-- `disent.frameworks`: frameworks, including Auto-Encoders and VAEs
- + `disent.frameworks.ae`: Auto-Encoder based frameworks
- + `disent.frameworks.vae`: Variational Auto-Encoder based frameworks
-- `disent.metrics`: metrics for evaluating disentanglement using ground truth datasets
-- `disent.model`: common encoder and decoder models used for VAE research
-- `disent.nn`: torch components for building models including layers, transforms, losses and general maths
-- `disent.schedule`: annealing schedules that can be registered to a framework
-- `disent.util`: helper classes, functions, callbacks, anything unrelated to a pytorch system/model/framework.
+Disent includes implementations of modules, metrics and
+datasets from various papers.
-**Please Note The API Is Still Unstable โ ๏ธ**
+_Note that "๐งต" means that the dataset, framework or metric was introduced by disent!_
-Disent is still under active development. Features and APIs are mostly stable but may change! A limited
-set of tests currently exist which will be expanded upon in time.
+### Datasets
-**Hydra Experiment Directories**
+Various common datasets used in disentanglement research are included, with hash
+verification and automatic chunk-size optimization of underlying hdf5 formats for
+low-memory disk-based access.
-Easily run experiments with hydra config, these files
-are not available from `pip install`.
+Data input and target dataset augmentations and transforms are supported, as well as augmentations
+on the GPU or CPU at different points in the pipeline.
-- `experiment/run.py`: entrypoint for running basic experiments with [hydra](https://github.com/facebookresearch/hydra) config
-- `experiment/config/config.yaml`: main configuration file, this is probably what you want to edit!
-- `experiment/config`: root folder for [hydra](https://github.com/facebookresearch/hydra) config files
-- `experiment/util`: various helper code for experiments
+- **Ground Truth**:
+ +
+ ๐ Cars3D
+
+
+
+ +
+ โป๏ธ dSprites
+
+
+
+ +
+ ๐บ MPI3D
+ ๐ Todo
+
+
+ +
+ ๐ SmallNORB
+
+
+
+ +
+ ๐ Shapes3D
+
+
+
+ +
+
+ ๐งต dSpritesImagenet:
+ Version of DSprite with foreground or background deterministically masked out with tiny-imagenet data.
+
+
+
-----------------------
+- **Ground Truth Synthetic**:
+ +
+
+ ๐งต XYObject:
+ A simplistic version of dSprites with a single square.
+
+
+
+
+ +
+
+ ๐งต XYObjectShaded:
+ Exact same dataset as XYObject, but ground truth factors have a different representation.
+
+
+
-## Features
+### Frameworks
-Disent includes implementations of modules, metrics and
-datasets from various papers. Please note that items marked
- with a "๐งต" are introduced in and are unique to disent!
+Disent provides the following Auto-Encoders and Variational Auto-Encoders!
-### Frameworks
- **Unsupervised**:
- + [VAE](https://arxiv.org/abs/1312.6114)
- + [Beta-VAE](https://openreview.net/forum?id=Sy2fzU9gl)
- + [DFC-VAE](https://arxiv.org/abs/1610.00291)
- + [DIP-VAE](https://arxiv.org/abs/1711.00848)
- + [InfoVAE](https://arxiv.org/abs/1706.02262)
- + [BetaTCVAE](https://arxiv.org/abs/1802.04942)
+ + AE: _Auto-Encoder_
+ + [VAE](https://arxiv.org/abs/1312.6114): Variational Auto-Encoder
+ + [Beta-VAE](https://openreview.net/forum?id=Sy2fzU9gl): VAE with Scaled Loss
+ + [DFC-VAE](https://arxiv.org/abs/1610.00291): Deep Feature Consistent VAE
+ + [DIP-VAE](https://arxiv.org/abs/1711.00848): Disentangled Inferred Prior VAE
+ + [InfoVAE](https://arxiv.org/abs/1706.02262): Information Maximizing VAE
+ + [BetaTCVAE](https://arxiv.org/abs/1802.04942): Total Correlation VAE
- **Weakly Supervised**:
- + [Ada-GVAE](https://arxiv.org/abs/2002.02886) *`AdaVae(..., average_mode='gvae')`* Usually better than the Ada-ML-VAE
- + [Ada-ML-VAE](https://arxiv.org/abs/2002.02886) *`AdaVae(..., average_mode='ml-vae')`*
+ + [Ada-GVAE](https://arxiv.org/abs/2002.02886): Adaptive GVAE, *`AdaVae.cfg(average_mode='gvae')`*, usually better than below!
+ + [Ada-ML-VAE](https://arxiv.org/abs/2002.02886): Adaptive ML-VAE, *`AdaVae.cfg(average_mode='ml-vae')`*
- **Supervised**:
- + [TVAE](https://arxiv.org/abs/1802.04403)
-
-Many popular disentanglement frameworks still need to be added, please
-submit an issue if you have a request for an additional framework.
+ + TAE: _Triplet Auto-Encoder_
+ + [TVAE](https://arxiv.org/abs/1802.04403): Triplet Variational Auto-Encoder
-todo
+๐ Todo: Many popular disentanglement frameworks still need to be added, please
+submit an issue if you have a request for an additional framework.
+ FactorVAE
+ GroupVAE
@@ -155,6 +185,9 @@ submit an issue if you have a request for an additional framework.
### Metrics
+Various metrics are provided by disent that can be used to evaluate the
+learnt representations of models that have been trained on ground-truth data.
+
- **Disentanglement**:
+ [FactorVAE Score](https://arxiv.org/abs/1802.05983)
+ [DCI](https://openreview.net/forum?id=By-7dz-AZ)
@@ -162,43 +195,14 @@ submit an issue if you have a request for an additional framework.
+ [SAP](https://arxiv.org/abs/1711.00848)
+ [Unsupervised Scores](https://github.com/google-research/disentanglement_lib)
-Some popular metrics still need to be added, please submit an issue if you wish to
-add your own, or you have a request.
-
-todo
+๐ Todo: Some popular metrics still need to be added, please submit an issue if you wish to
+add your own, or you have a request.
+ [DCIMIG](https://arxiv.org/abs/1910.05587)
+ [Modularity and Explicitness](https://arxiv.org/abs/1802.05312)
-### Datasets
-
-Various common datasets used in disentanglement research are included, with hash
-verification and automatic chunk-size optimization of underlying hdf5 formats for
-low-memory disk-based access.
-
-- **Ground Truth**:
- + Cars3D
- + dSprites
- + MPI3D
- + SmallNORB
- + Shapes3D
-
-- **Ground Truth Synthetic**:
- + ๐งต XYObject: *A simplistic version of dSprites with a single square.*
- + ๐งต XYObjectShaded: *Exact same dataset as XYObject, but ground truth factors have a different representation*
- + ๐งต DSpritesImagenet: *Version of DSprite with foreground or background deterministically masked out with tiny-imagenet data*
-
-
-
-
-
- #### Input Transforms + Input/Target Augmentations
-
- - Input based transforms are supported.
- - Input and Target CPU and GPU based augmentations are supported.
-
### Schedules & Annealing
Hyper-parameter annealing is supported through the use of schedules.
@@ -211,6 +215,43 @@ The currently implemented schedules include:
----------------------
+## Architecture
+
+The disent module structure:
+
+- `disent.dataset`: dataset wrappers, datasets & sampling strategies
+ + `disent.dataset.data`: raw datasets
+ + `disent.dataset.sampling`: sampling strategies for `DisentDataset` when multiple elements are required by frameworks, eg. for triplet loss
+ + `disent.dataset.transform`: common data transforms and augmentations
+ + `disent.dataset.wrapper`: wrapped datasets are no longer ground-truth datasets, these may have some elements masked out. We can still unwrap these classes to obtain the original datasets for benchmarking.
+- `disent.frameworks`: frameworks, including Auto-Encoders and VAEs
+ + `disent.frameworks.ae`: Auto-Encoder based frameworks
+ + `disent.frameworks.vae`: Variational Auto-Encoder based frameworks
+- `disent.metrics`: metrics for evaluating disentanglement using ground truth datasets
+- `disent.model`: common encoder and decoder models used for VAE research
+- `disent.nn`: torch components for building models including layers, transforms, losses and general maths
+- `disent.schedule`: annealing schedules that can be registered to a framework
+- `disent.util`: helper classes, functions, callbacks, anything unrelated to a pytorch system/model/framework.
+
+**โ ๏ธ The API Is _Mostly_ Stable โ ๏ธ**
+
+Disent is still under development. Features and APIs are subject to change!
+However, I will try and minimise the impact of these.
+
+A small suite of tests currently exist which will be expanded upon in time.
+
+**Hydra Experiment Directories**
+
+Easily run experiments with hydra config, these files
+are not available from `pip install`.
+
+- `experiment/run.py`: entrypoint for running basic experiments with [hydra](https://github.com/facebookresearch/hydra) config
+- `experiment/config/config.yaml`: main configuration file, this is probably what you want to edit!
+- `experiment/config`: root folder for [hydra](https://github.com/facebookresearch/hydra) config files
+- `experiment/util`: various helper code for experiments
+
+----------------------
+
## Examples
### Python Example
@@ -357,7 +398,7 @@ visualisations of latent traversals.
### Why?
-- Created as part of my Computer Science MSc scheduled for completion in 2021.
+- Created as part of my Computer Science MSc which ended early 2022.
- I needed custom high quality implementations of various VAE's.
- A pytorch version of [disentanglement_lib](https://github.com/google-research/disentanglement_lib).
- I didn't have time to wait for [Weakly-Supervised Disentanglement Without Compromises](https://arxiv.org/abs/2002.02886) to release
diff --git a/disent/dataset/__init__.py b/disent/dataset/__init__.py
index 84f37c72..e44c1deb 100644
--- a/disent/dataset/__init__.py
+++ b/disent/dataset/__init__.py
@@ -24,3 +24,4 @@
# wrapper
from disent.dataset._base import DisentDataset
+from disent.dataset._base import DisentIterDataset
diff --git a/disent/dataset/_base.py b/disent/dataset/_base.py
index fdb3f985..7ab32153 100644
--- a/disent/dataset/_base.py
+++ b/disent/dataset/_base.py
@@ -22,7 +22,10 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+import warnings
from functools import wraps
+from typing import Callable
+from typing import Iterator
from typing import Optional
from typing import Sequence
from typing import TypeVar
@@ -30,6 +33,7 @@
import numpy as np
from torch.utils.data import Dataset
+from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import default_collate
from disent.dataset.sampling import BaseDisentSampler
@@ -81,19 +85,19 @@ def wrapper(self: 'DisentDataset', *args, **kwargs):
# ========================================================================= #
-_DO_COPY = object()
+_REF_ = object()
class DisentDataset(Dataset, LengthIter):
-
def __init__(
self,
- dataset: Union[Dataset, GroundTruthData],
+ dataset: Union[Dataset, GroundTruthData], # TODO: this should be renamed to data
sampler: Optional[BaseDisentSampler] = None,
- transform=None,
- augment=None,
- return_indices: bool = False,
+ transform: Optional[callable] = None,
+ augment: Optional[callable] = None,
+ return_indices: bool = False, # doesn't really hurt performance, might as well leave enabled by default?
+ return_factors: bool = False,
):
super().__init__()
# save attributes
@@ -102,22 +106,37 @@ def __init__(
self._transform = transform
self._augment = augment
self._return_indices = return_indices
+ self._return_factors = return_factors
+ # check sampler
+ assert isinstance(self._sampler, BaseDisentSampler), f'{DisentDataset.__name__} got an invalid {BaseDisentSampler.__name__}: {type(self._sampler)}'
# initialize sampler
if not self._sampler.is_init:
self._sampler.init(dataset)
+ # warn if we are overriding a transform
+ if self._transform is not None:
+ if hasattr(dataset, '_transform') and dataset._transform:
+ warnings.warn(f'{DisentDataset.__name__} has transform specified as well as wrapped dataset: {dataset}, are you sure this is intended?')
+ # check the dataset if we are returning the factors
+ if self._return_factors:
+ assert isinstance(self._dataset, GroundTruthData), f'If `return_factors` is `True`, then the dataset must be an instance of: {GroundTruthData.__name__}, got: {type(dataset)}'
def shallow_copy(
self,
- transform=_DO_COPY,
- augment=_DO_COPY,
- return_indices=_DO_COPY,
+ dataset: Union[Dataset, GroundTruthData] =_REF_, # TODO: this should be renamed to data
+ sampler: Optional[BaseDisentSampler] = _REF_,
+ transform: Optional[callable] = _REF_,
+ augment: Optional[callable] = _REF_,
+ return_indices: bool = _REF_,
+ return_factors: bool = _REF_,
) -> 'DisentDataset':
+ # instantiate shallow dataset copy, overwriting elements if specified
return DisentDataset(
- dataset=self._dataset,
- sampler=self._sampler,
- transform=self._transform if (transform is _DO_COPY) else transform,
- augment=self._augment if (augment is _DO_COPY) else augment,
- return_indices=self._return_indices if (return_indices is _DO_COPY) else return_indices,
+ dataset = self._dataset if (dataset is _REF_) else dataset,
+ sampler = self._sampler.uninit_copy() if (sampler is _REF_) else sampler,
+ transform = self._transform if (transform is _REF_) else transform,
+ augment = self._augment if (augment is _REF_) else augment,
+ return_indices = self._return_indices if (return_indices is _REF_) else return_indices,
+ return_factors = self._return_factors if (return_factors is _REF_) else return_factors,
)
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
@@ -128,6 +147,18 @@ def shallow_copy(
def data(self) -> Dataset:
return self._dataset
+ @property
+ def sampler(self) -> BaseDisentSampler:
+ return self._sampler
+
+ @property
+ def transform(self) -> Optional[Callable[[object], object]]:
+ return self._transform
+
+ @property
+ def augment(self) -> Optional[Callable[[object], object]]:
+ return self._augment
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Ground Truth Only #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
@@ -176,15 +207,22 @@ def wrapped_gt_data(self):
return self._dataset.gt_data
@wrapped_only
- def unwrapped_disent_dataset(self) -> 'DisentDataset':
- sampler = self._sampler.uninit_copy()
- assert type(sampler) is type(self._sampler)
- return DisentDataset(
+ def unwrapped_shallow_copy(
+ self,
+ sampler: Optional[BaseDisentSampler] = _REF_,
+ transform: Optional[callable] = _REF_,
+ augment: Optional[callable] = _REF_,
+ return_indices: bool = _REF_,
+ return_factors: bool = _REF_,
+ ) -> 'DisentDataset':
+ # like shallow_copy, but unwrap the dataset instead!
+ return self.shallow_copy(
dataset=self.wrapped_data,
sampler=sampler,
- transform=self._transform,
- augment=self._augment,
- return_indices=self._return_indices,
+ transform=transform,
+ augment=augment,
+ return_indices=return_indices,
+ return_factors=return_factors,
)
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
@@ -275,6 +313,12 @@ def _dataset_get_observation(self, *idxs):
# add indices
if self._return_indices:
obs['idx'] = idxs
+ # add factors
+ if self._return_factors:
+ # >>> this is about 10% faster than below, because we do not need to do conversions!
+ obs['factors'] = tuple(np.array(np.unravel_index(idxs, self._dataset.factor_sizes)).T)
+ # >>> builtin but slower method, does some magic for more than 2 dims, could replace with faster try_njit method, but then we need numba!
+ # obs['factors1'] = tuple(self.gt_data.idx_to_pos(idxs))
# done!
return obs
@@ -285,38 +329,57 @@ def _dataset_get_observation(self, *idxs):
# TODO: default_collate should be replaced with a function
# that can handle tensors and nd.arrays, and return accordingly
- def dataset_batch_from_indices(self, indices: Sequence[int], mode: str):
+ def dataset_batch_from_indices(self, indices: Sequence[int], mode: str, collate: bool = True):
"""Get a batch of observations X from a batch of factors Y."""
- return default_collate([self.dataset_get(idx, mode=mode) for idx in indices])
+ batch = [self.dataset_get(idx, mode=mode) for idx in indices]
+ return default_collate(batch) if collate else batch
- def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False, return_indices: bool = False):
+ def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False, return_indices: bool = False, collate: bool = True, seed: Optional[int] = None):
"""Sample a batch of observations X."""
# built in np.random.choice cannot handle large values: https://github.com/numpy/numpy/issues/5299#issuecomment-497915672
- indices = random_choice_prng(len(self), size=num_samples, replace=replace)
+ indices = random_choice_prng(len(self._dataset), size=num_samples, replace=replace, seed=seed)
# return batch
- batch = self.dataset_batch_from_indices(indices, mode=mode)
+ batch = self.dataset_batch_from_indices(indices, mode=mode, collate=collate)
# return values
if return_indices:
- return batch, default_collate(indices)
+ return batch, (default_collate(indices) if collate else indices)
else:
return batch
+ def dataset_sample_elems(self, num_samples: int, mode: str, return_indices: bool = False, seed: Optional[int] = None):
+ """Sample uncollated elements with replacement, like `dataset_sample_batch`"""
+ return self.dataset_sample_batch(num_samples=num_samples, mode=mode, replace=True, return_indices=return_indices, collate=False, seed=seed)
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Batches -- Ground Truth Only #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
@groundtruth_only
- def dataset_batch_from_factors(self, factors: np.ndarray, mode: str):
+ def dataset_batch_from_factors(self, factors: np.ndarray, mode: str, collate: bool = True):
"""Get a batch of observations X from a batch of factors Y."""
indices = self.gt_data.pos_to_idx(factors)
- return self.dataset_batch_from_indices(indices, mode=mode)
+ return self.dataset_batch_from_indices(indices, mode=mode, collate=collate)
@groundtruth_only
- def dataset_sample_batch_with_factors(self, num_samples: int, mode: str):
+ def dataset_sample_batch_with_factors(self, num_samples: int, mode: str, collate: bool = True):
"""Sample a batch of observations X and factors Y."""
factors = self.gt_data.sample_factors(num_samples)
- batch = self.dataset_batch_from_factors(factors, mode=mode)
- return batch, default_collate(factors)
+ batch = self.dataset_batch_from_factors(factors, mode=mode, collate=collate)
+ return batch, (default_collate(factors) if collate else factors)
+
+
+class DisentIterDataset(IterableDataset, DisentDataset):
+
+ # make sure we cannot obtain the length directly
+ __len__ = None
+
+ def __iter__(self):
+ # this takes priority over __getitem__, otherwise __getitem__ would need to
+ # raise an IndexError if out of bounds to signal the end of iteration
+ while True:
+ # yield the entire dataset
+ # - repeating when it is done!
+ yield from (self[i] for i in range(len(self._dataset)))
# ========================================================================= #
diff --git a/disent/dataset/data/__init__.py b/disent/dataset/data/__init__.py
index 1a870327..4dab1de0 100644
--- a/disent/dataset/data/__init__.py
+++ b/disent/dataset/data/__init__.py
@@ -43,9 +43,11 @@
# groundtruth -- impl
from disent.dataset.data._groundtruth__cars3d import Cars3dData
+from disent.dataset.data._groundtruth__cars3d import Cars3d64Data # optimized version of cars3d for 64x64 images
from disent.dataset.data._groundtruth__dsprites import DSpritesData
from disent.dataset.data._groundtruth__mpi3d import Mpi3dData
from disent.dataset.data._groundtruth__norb import SmallNorbData
+from disent.dataset.data._groundtruth__norb import SmallNorb64Data
from disent.dataset.data._groundtruth__shapes3d import Shapes3dData
# groundtruth -- impl synthetic
diff --git a/disent/dataset/data/_groundtruth.py b/disent/dataset/data/_groundtruth.py
index 0c269ba4..faa86ba1 100644
--- a/disent/dataset/data/_groundtruth.py
+++ b/disent/dataset/data/_groundtruth.py
@@ -85,6 +85,15 @@ def factor_names(self) -> Tuple[str, ...]:
def factor_sizes(self) -> Tuple[int, ...]:
raise NotImplementedError()
+ def state_space_copy(self) -> StateSpace:
+ """
+ :return: Copy this ground truth dataset as a StateSpace, discarding everything else!
+ """
+ return StateSpace(
+ factor_sizes=self.factor_sizes,
+ factor_names=self.factor_names,
+ )
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Properties #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
@@ -214,7 +223,7 @@ def _mixin_disk_init(self, data_root: Optional[str] = None, prepare: bool = Fals
if data_root is None:
data_root = self.default_data_root
else:
- data_root = os.path.abspath(data_root)
+ data_root = os.path.abspath(os.path.expanduser(data_root))
# get class data folder
self._data_dir = ensure_dir_exists(os.path.join(data_root, self.name))
log.info(f'{self.name}: data_dir_share={repr(self._data_dir)}')
diff --git a/disent/dataset/data/_groundtruth__cars3d.py b/disent/dataset/data/_groundtruth__cars3d.py
index 522bc749..021c2123 100644
--- a/disent/dataset/data/_groundtruth__cars3d.py
+++ b/disent/dataset/data/_groundtruth__cars3d.py
@@ -26,13 +26,16 @@
import os
import shutil
from tempfile import TemporaryDirectory
+from typing import Dict
+from typing import Optional
+from typing import Union
import numpy as np
-from scipy.io import loadmat
-from disent.dataset.util.datafile import DataFileHashedDlGen
from disent.dataset.data._groundtruth import NumpyFileGroundTruthData
-from disent.util.inout.files import AtomicSaveFile
+from disent.dataset.util.datafile import DataFileHashed
+from disent.dataset.util.datafile import DataFileHashedDlGen
+from disent.util.inout.paths import modify_name_keep_ext
log = logging.getLogger(__name__)
@@ -52,6 +55,7 @@ def load_cars3d_folder(raw_data_dir):
2. /data/sprites
3. /data/shapes48.mat
"""
+ from scipy.io import loadmat
# load image paths
with open(os.path.join(raw_data_dir, 'cars/list.txt'), 'r') as img_names:
img_paths = [os.path.join(raw_data_dir, f'cars/{name.strip()}.mat') for name in img_names.readlines()]
@@ -72,10 +76,20 @@ def resave_cars3d_archive(orig_zipped_file, new_save_file, overwrite=False):
# extract zipfile and get path
log.info(f"Extracting into temporary directory: {temp_dir}")
shutil.unpack_archive(filename=orig_zipped_file, extract_dir=temp_dir)
- # load image paths & resave
- with AtomicSaveFile(new_save_file, overwrite=overwrite) as temp_file:
- images = load_cars3d_folder(raw_data_dir=os.path.join(temp_dir, 'data'))
- np.savez(temp_file, images=images)
+ # load images
+ images = load_cars3d_folder(raw_data_dir=os.path.join(temp_dir, 'data'))
+ # save the array
+ from disent.dataset.util.npz import save_dataset_array
+ save_dataset_array(images, new_save_file, overwrite=overwrite, save_key='images')
+
+
+def resave_cars3d_resized(orig_converted_file: str, new_resized_file: str, overwrite=False, size: int = 64):
+ # load the array
+ cars3d_array = np.load(orig_converted_file)['images']
+ assert cars3d_array.shape == (17568, 128, 128, 3)
+ # save the array
+ from disent.dataset.util.npz import save_resized_dataset_array
+ save_resized_dataset_array(cars3d_array, new_resized_file, overwrite=overwrite, size=size, save_key='images')
# ========================================================================= #
@@ -91,6 +105,35 @@ def _generate(self, inp_file: str, out_file: str):
resave_cars3d_archive(orig_zipped_file=inp_file, new_save_file=out_file, overwrite=True)
+class DataFileCars3dResized(DataFileHashed):
+
+ def __init__(
+ self,
+ cars3d_datafile: DataFileCars3d,
+ # - convert file name
+ out_hash: Optional[Union[str, Dict[str, str]]],
+ out_name: Optional[str] = None,
+ out_size: int = 64,
+ # - hash settings
+ hash_type: str = 'md5',
+ hash_mode: str = 'fast',
+ ):
+ self._out_size = out_size
+ self._cars3dfile = cars3d_datafile
+ super().__init__(
+ file_name=modify_name_keep_ext(self._cars3dfile.out_name, suffix=f'_x{out_size}') if (out_name is None) else out_name,
+ file_hash=out_hash,
+ hash_type=hash_type,
+ hash_mode=hash_mode,
+ )
+
+ def _prepare(self, out_dir: str, out_file: str):
+ log.debug('Preparing Orig Cars3d Data:')
+ cars3d_path = self._cars3dfile.prepare(out_dir)
+ log.debug('Generating Resized Cars3d Data:')
+ resave_cars3d_resized(orig_converted_file=cars3d_path, new_resized_file=out_file, overwrite=True, size=self._out_size)
+
+
# ========================================================================= #
# dataset_cars3d #
# ========================================================================= #
@@ -115,7 +158,7 @@ class Cars3dData(NumpyFileGroundTruthData):
uri='http://www.scottreed.info/files/nips2015-analogy-data.tar.gz',
uri_hash={'fast': 'fe77d39e3fa9d77c31df2262660c2a67', 'full': '4e866a7919c1beedf53964e6f7a23686'},
file_name='cars3d.npz',
- file_hash={'fast': 'ef5d86d1572ddb122b466ec700b3abf2', 'full': 'dc03319a0b9118fbe0e23d13220a745b'},
+ file_hash={'fast': '204ecb6852216e333f1b022903f9d012', 'full': '46ad66acf277897f0404e522460ba7e5'},
hash_mode='fast'
)
@@ -123,11 +166,50 @@ class Cars3dData(NumpyFileGroundTruthData):
data_key = 'images'
+# TODO: this is very slow compared to other datasets for some reason!
+# - in memory benchmark are equivalent, eg. against Shapes3D, but when we run the
+# experiment/run.py with this its about twice as slow? Why is this?
+class Cars3d64Data(Cars3dData):
+ """
+ Optimized version of Cars3dOrigData, that has already been re-sized to 64x64
+ - This can improve run times dramatically!
+ """
+
+ img_shape = (64, 64, 3)
+
+ datafile = DataFileCars3dResized(
+ cars3d_datafile=Cars3dData.datafile,
+ out_name='cars3d_x64.npz',
+ out_hash={'fast': '5a85246b6f555bc6e3576ee62bf6d19e', 'full': '2b900b3c5de6cd9b5df87bfc02f01f03'},
+ hash_mode='fast',
+ out_size=64,
+ )
+
+
# ========================================================================= #
# END #
# ========================================================================= #
if __name__ == '__main__':
- logging.basicConfig(level=logging.DEBUG)
- Cars3dData(prepare=True)
+
+ def main():
+ import torch
+ from tqdm import tqdm
+ from disent.dataset.transform import ToImgTensorF32
+
+ logging.basicConfig(level=logging.DEBUG)
+
+ # original dataset
+ data_128 = Cars3dData(prepare=True, transform=ToImgTensorF32(size=64))
+ for i in tqdm(data_128, desc='cars3d_x128 -> 64'):
+ pass
+ # resized dataset
+ data_64 = Cars3d64Data(prepare=True, transform=ToImgTensorF32(size=64))
+ for i in tqdm(data_64, desc='cars3d_x64'):
+ pass
+ # check equivalence
+ for obs_128, obs_64 in tqdm(zip(data_128, data_64), desc='equivalence'):
+ assert torch.allclose(obs_128, obs_64)
+
+ main()
diff --git a/disent/dataset/data/_groundtruth__dsprites.py b/disent/dataset/data/_groundtruth__dsprites.py
index 99ee76ae..d0867414 100644
--- a/disent/dataset/data/_groundtruth__dsprites.py
+++ b/disent/dataset/data/_groundtruth__dsprites.py
@@ -33,6 +33,7 @@
# ========================================================================= #
+# TODO: this seems to have a memory leak compared to other datasets?
class DSpritesData(Hdf5GroundTruthData):
"""
DSprites Dataset
diff --git a/disent/dataset/data/_groundtruth__dsprites_imagenet.py b/disent/dataset/data/_groundtruth__dsprites_imagenet.py
deleted file mode 100644
index 13b80134..00000000
--- a/disent/dataset/data/_groundtruth__dsprites_imagenet.py
+++ /dev/null
@@ -1,290 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-import logging
-import os
-import shutil
-from tempfile import TemporaryDirectory
-from typing import Optional
-
-import numpy as np
-import psutil
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
-from torchvision.datasets import ImageFolder
-from tqdm import tqdm
-
-from disent.dataset.data import GroundTruthData
-from disent.dataset.data._groundtruth import _DiskDataMixin
-from disent.dataset.data._groundtruth import _Hdf5DataMixin
-from disent.dataset.data._groundtruth__dsprites import DSpritesData
-from disent.dataset.transform import ToImgTensorF32
-from disent.dataset.util.datafile import DataFileHashedDlGen
-from disent.dataset.util.hdf5 import H5Builder
-from disent.dataset.util.stats import compute_data_mean_std
-from disent.util.inout.files import AtomicSaveFile
-from disent.util.iters import LengthIter
-from disent.util.math.random import random_choice_prng
-
-
-log = logging.getLogger(__name__)
-
-
-# ========================================================================= #
-# load imagenet-tiny data #
-# ========================================================================= #
-
-
-class NumpyFolder(ImageFolder):
- def __getitem__(self, idx):
- img, cls = super().__getitem__(idx)
- return np.array(img)
-
-
-def _noop(x):
- return x
-
-
-def load_imagenet_tiny_data(raw_data_dir):
- data = NumpyFolder(os.path.join(raw_data_dir, 'train'))
- data = DataLoader(data, batch_size=64, num_workers=min(16, psutil.cpu_count(logical=False)), shuffle=False, drop_last=False, collate_fn=_noop)
- # load data - this is a bit memory inefficient doing it like this instead of with a loop into a pre-allocated array
- imgs = np.concatenate(list(tqdm(data, 'loading')), axis=0)
- assert imgs.shape == (100_000, 64, 64, 3)
- return imgs
-
-
-def resave_imagenet_tiny_archive(orig_zipped_file, new_save_file, overwrite=False, h5_dataset_name: str = 'data'):
- """
- Convert a imagenet tiny archive to an hdf5 or numpy file depending on the file extension.
- Uncompressing the contents of the archive into a temporary directory in the same folder,
- loading the images, then converting.
- """
- _, ext = os.path.splitext(new_save_file)
- assert ext in {'.npz', '.h5'}, f'unsupported save extension: {repr(ext)}, must be one of: {[".npz", ".h5"]}'
- # extract zipfile into temp dir
- with TemporaryDirectory(prefix='unzip_imagenet_tiny_', dir=os.path.dirname(orig_zipped_file)) as temp_dir:
- log.info(f"Extracting into temporary directory: {temp_dir}")
- shutil.unpack_archive(filename=orig_zipped_file, extract_dir=temp_dir)
- images = load_imagenet_tiny_data(raw_data_dir=os.path.join(temp_dir, 'tiny-imagenet-200'))
- # save the data
- with AtomicSaveFile(new_save_file, overwrite=overwrite) as temp_file:
- # check the mode
- with H5Builder(temp_file, 'atomic_w') as builder:
- builder.add_dataset_from_array(
- name=h5_dataset_name,
- array=images,
- chunk_shape='batch',
- compression_lvl=4,
- attrs=None,
- show_progress=True,
- )
-
-
-# ========================================================================= #
-# cars3d data object #
-# ========================================================================= #
-
-
-class ImageNetTinyDataFile(DataFileHashedDlGen):
- """
- download the cars3d dataset and convert it to a hdf5 file.
- """
-
- dataset_name: str = 'data'
-
- def _generate(self, inp_file: str, out_file: str):
- resave_imagenet_tiny_archive(orig_zipped_file=inp_file, new_save_file=out_file, overwrite=True, h5_dataset_name=self.dataset_name)
-
-
-class ImageNetTinyData(_Hdf5DataMixin, _DiskDataMixin, Dataset, LengthIter):
-
- name = 'imagenet_tiny'
-
- datafile_imagenet_h5 = ImageNetTinyDataFile(
- uri='http://cs231n.stanford.edu/tiny-imagenet-200.zip',
- uri_hash={'fast': '4d97ff8efe3745a3bba9917d6d536559', 'full': '90528d7ca1a48142e341f4ef8d21d0de'},
- file_hash={'fast': '9c23e8ec658b1ec9f3a86afafbdbae51', 'full': '4c32b0b53f257ac04a3afb37e3a4204e'},
- uri_name='tiny-imagenet-200.zip',
- file_name='tiny-imagenet-200.h5',
- hash_mode='full'
- )
-
- datafiles = (datafile_imagenet_h5,)
-
- def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None):
- super().__init__()
- self._transform = transform
- # initialize mixin
- self._mixin_disk_init(
- data_root=data_root,
- prepare=prepare,
- )
- self._mixin_hdf5_init(
- h5_path=os.path.join(self.data_dir, self.datafile_imagenet_h5.out_name),
- h5_dataset_name=self.datafile_imagenet_h5.dataset_name,
- in_memory=in_memory,
- )
-
- def __getitem__(self, idx: int):
- obs = self._data[idx]
- if self._transform is not None:
- obs = self._transform(obs)
- return obs
-
-
-# ========================================================================= #
-# dataset_dsprites #
-# ========================================================================= #
-
-
-class DSpritesImagenetData(GroundTruthData):
- """
- DSprites that has imagenet images in the background.
- """
-
- name = 'dsprites_imagenet'
-
- # original dsprites it only (64, 64, 1) imagenet adds the colour channel
- img_shape = (64, 64, 3)
- factor_names = DSpritesData.factor_names
- factor_sizes = DSpritesData.factor_sizes
-
- def __init__(self, visibility: int = 100, mode: str = 'bg', data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None):
- super().__init__(transform=transform)
- # check visibility and convert to ratio
- assert isinstance(visibility, int), f'incorrect visibility percentage type, expected int, got: {type(visibility)}'
- assert 0 <= visibility <= 100, f'incorrect visibility percentage: {repr(visibility)}, must be in range [0, 100]. '
- self._visibility = visibility / 100
- # check mode and convert to foreground boolean
- assert mode in {'bg', 'fg'}, f'incorrect mode: {repr(mode)}, must be one of: ["bg", "fg"]'
- self._foreground = (mode == 'fg')
- # handle the datasets
- self._dsprites = DSpritesData(data_root=data_root, prepare=prepare, in_memory=in_memory, transform=None)
- self._imagenet = ImageNetTinyData(data_root=data_root, prepare=prepare, in_memory=in_memory, transform=None)
- # deterministic randomization of the imagenet order
- self._imagenet_order = random_choice_prng(
- len(self._imagenet),
- size=len(self),
- seed=42,
- )
-
- def _get_observation(self, idx):
- # we need to combine the two dataset images
- # dsprites contains only {0, 255} for values
- # we can directly use these values to mask the imagenet image
- bg = self._imagenet[self._imagenet_order[idx]]
- fg = self._dsprites[idx].repeat(3, axis=-1)
- # compute background
- # set foreground
- r = self._visibility
- if self._foreground:
- # lerp content to white, and then insert into fg regions
- # r*bg + (1-r)*255
- obs = (r*bg + ((1-r)*255)).astype('uint8')
- obs[fg <= 127] = 0
- else:
- # lerp content to black, and then insert into bg regions
- # r*bg + (1-r)*000
- obs = (r*bg).astype('uint8')
- obs[fg > 127] = 255
- # checks
- return obs
-
-
-# ========================================================================= #
-# STATS #
-# ========================================================================= #
-
-
-"""
-dsprites_fg_1.0
- vis_mean: [0.02067051643494642, 0.018688392816012946, 0.01632900510079384]
- vis_std: [0.10271307751834059, 0.09390213983525653, 0.08377594259970281]
-dsprites_fg_0.8
- vis_mean: [0.024956427531012196, 0.02336780403840578, 0.021475119672280243]
- vis_std: [0.11864125016313823, 0.11137998105649799, 0.10281424917834255]
-dsprites_fg_0.6
- vis_mean: [0.029335176871153983, 0.028145355435322966, 0.026731731769287146]
- vis_std: [0.13663242436043319, 0.13114320478634894, 0.1246542727733097]
-dsprites_fg_0.4
- vis_mean: [0.03369999506331255, 0.03290657349801835, 0.03196482946320608]
- vis_std: [0.155514074438101, 0.1518464537731621, 0.14750944591836743]
-dsprites_fg_0.2
- vis_mean: [0.038064750024334834, 0.03766780505193579, 0.03719798677641122]
- vis_std: [0.17498878664096565, 0.17315570657628318, 0.1709923319496426]
-dsprites_bg_1.0
- vis_mean: [0.5020433619489952, 0.47206398913310593, 0.42380018909780404]
- vis_std: [0.2505510666843685, 0.25007259803668697, 0.2562415603123114]
-dsprites_bg_0.8
- vis_mean: [0.40867981393820857, 0.38468564002021527, 0.34611573047508204]
- vis_std: [0.22048328737091344, 0.22102216869942384, 0.22692977053753477]
-dsprites_bg_0.6
- vis_mean: [0.31676960943447674, 0.29877166834408025, 0.2698556821388113]
- vis_std: [0.19745897110349003, 0.1986606891520453, 0.203808842880044]
-dsprites_bg_0.4
- vis_mean: [0.2248598986983768, 0.21285772298967615, 0.19359577132944206]
- vis_std: [0.1841631708032332, 0.18554895825833284, 0.1893568926398198]
-dsprites_bg_0.2
- vis_mean: [0.13294969414492142, 0.12694375140936273, 0.11733572285575933]
- vis_std: [0.18311250427586276, 0.1840916474752131, 0.18607373519458442]
-"""
-
-
-if __name__ == '__main__':
- logging.basicConfig(level=logging.INFO)
-
-
- def compute_all_stats():
- from disent.util.visualize.plot import plt_subplots_imshow
-
- def compute_stats(visibility, mode):
- # plot images
- data = DSpritesImagenetData(prepare=True, visibility=visibility, mode=mode)
- grid = np.array([data[i*24733] for i in np.arange(16)]).reshape([4, 4, *data.img_shape])
- plt_subplots_imshow(grid, show=True, title=f'{DSpritesImagenetData.name} visibility={repr(visibility)} mode={repr(mode)}')
- # compute stats
- name = f'dsprites_{mode}_{visibility}'
- data = DSpritesImagenetData(prepare=True, visibility=visibility, mode=mode, transform=ToImgTensorF32())
- mean, std = compute_data_mean_std(data, batch_size=256, num_workers=min(psutil.cpu_count(logical=False), 64), progress=True)
- print(f'{name}\n vis_mean: {mean.tolist()}\n vis_std: {std.tolist()}')
- # return stats
- return name, mean, std
-
- # compute common stats
- stats = []
- for mode in ['fg', 'bg']:
- for vis in [100, 80, 60, 40, 20]:
- stats.append(compute_stats(vis, mode))
-
- # print once at end
- for name, mean, std in stats:
- print(f'{name}\n vis_mean: {mean.tolist()}\n vis_std: {std.tolist()}')
-
- compute_all_stats()
-
-
-# ========================================================================= #
-# END #
-# ========================================================================= #
diff --git a/disent/dataset/data/_groundtruth__norb.py b/disent/dataset/data/_groundtruth__norb.py
index c96d0505..9318b0e7 100644
--- a/disent/dataset/data/_groundtruth__norb.py
+++ b/disent/dataset/data/_groundtruth__norb.py
@@ -22,17 +22,23 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-import gzip
import logging
-import os
+from typing import Dict
+from typing import NoReturn
from typing import Optional
-from typing import Sequence
from typing import Tuple
+from typing import Union
import numpy as np
+from disent.dataset.data import NumpyFileGroundTruthData
+from disent.dataset.util.datafile import DataFile
+from disent.dataset.util.datafile import DataFileHashed
from disent.dataset.util.datafile import DataFileHashedDl
-from disent.dataset.data._groundtruth import DiskGroundTruthData
+from disent.util.inout.paths import modify_name_keep_ext
+
+
+log = logging.getLogger(__name__)
# ========================================================================= #
@@ -81,6 +87,7 @@ def read_binary_matrix_bytes(bytes):
def read_binary_matrix_file(file, gzipped: bool = True):
+ import gzip
# this does not seem to copy the bytes, which saves memory
with (gzip.open if gzipped else open)(file, "rb") as f:
return read_binary_matrix_bytes(bytes=f.read())
@@ -91,11 +98,12 @@ def read_binary_matrix_file(file, gzipped: bool = True):
# ========================================================================= #
-def read_norb_dataset(dat_path: str, cat_path: str, info_path: str, gzipped=True, sort=True) -> Tuple[np.ndarray, np.ndarray]:
+def read_norb_dataset(dat_path: str, cat_path: str, info_path: str, gzipped=True, sort=True, add_channel_dim: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""
Load The Normalised Dataset
* dat:
- images (5 categories, 5 instances, 6 lightings, 9 elevations, and 18 azimuths)
+ + shape: (N, H, W, 1)
* cat:
- initial ground truth factor:
0. category of images (0 for animal, 1 for human, 2 for plane, 3 for truck, 4 for car).
@@ -119,16 +127,106 @@ def read_norb_dataset(dat_path: str, cat_path: str, info_path: str, gzipped=True
indices = np.lexsort(factors[:, [4, 3, 2, 1, 0]].T)
images = images[indices]
factors = factors[indices]
+ # add the channel dimension
+ if add_channel_dim:
+ images = images[:, :, :, None]
+ assert images.ndim == 4
+ else:
+ assert images.ndim == 3
# done!
return images, factors
+def resave_norb_archive(in_dat_path: str, in_cat_path: str, in_info_path: str, new_save_file: str, in_gzipped=True, overwrite: bool = False):
+ # load the array
+ images, factors = read_norb_dataset(dat_path=in_dat_path, cat_path=in_cat_path, info_path=in_info_path, gzipped=in_gzipped, sort=True, add_channel_dim=True)
+ assert images.shape == (24300, 96, 96, 1)
+ # save the array
+ from disent.dataset.util.npz import save_dataset_array
+ save_dataset_array(images, new_save_file, overwrite=overwrite, save_key='images')
+
+
+def resave_norb_resized(orig_converted_file: str, new_resized_file: str, overwrite=False, size: int = 64):
+ # load the array
+ norb_array = np.load(orig_converted_file)['images']
+ assert norb_array.shape == (24300, 96, 96, 1)
+ # save the array
+ from disent.dataset.util.npz import save_resized_dataset_array
+ save_resized_dataset_array(norb_array, new_resized_file, overwrite=overwrite, size=size, save_key='images')
+
+
+# ========================================================================= #
+# Data Files #
+# ========================================================================= #
+
+
+class DataFileSmallNorb(DataFileHashed):
+ """
+ download the smallnorb dataset and convert it to a numpy file.
+ """
+
+ def __init__(
+ self,
+ datafile_dat: DataFile,
+ datafile_cat: DataFile,
+ datafile_info: DataFile,
+ out_name: str,
+ out_hash: Optional[Union[str, Dict[str, str]]],
+ hash_type: str = 'md5',
+ hash_mode: str = 'fast',
+ ):
+ self._datafile_dat = datafile_dat
+ self._datafile_cat = datafile_cat
+ self._datafile_info = datafile_info
+ # initialize
+ super().__init__(file_name=out_name, file_hash=out_hash, hash_type=hash_type, hash_mode=hash_mode)
+
+ def _prepare(self, out_dir: str, out_file: str) -> NoReturn:
+ resave_norb_archive(
+ in_dat_path=self._datafile_dat.prepare(out_dir),
+ in_cat_path=self._datafile_cat.prepare(out_dir),
+ in_info_path=self._datafile_info.prepare(out_dir),
+ new_save_file=out_file,
+ in_gzipped=True,
+ overwrite=True,
+ )
+
+
+class DataFileSmallNorbResized(DataFileHashed):
+
+ def __init__(
+ self,
+ norb_datafile: DataFileSmallNorb,
+ # - convert file name
+ out_hash: Optional[Union[str, Dict[str, str]]],
+ out_name: Optional[str] = None,
+ out_size: int = 64,
+ # - hash settings
+ hash_type: str = 'md5',
+ hash_mode: str = 'fast',
+ ):
+ self._out_size = out_size
+ self._norb_datafile = norb_datafile
+ super().__init__(
+ file_name=modify_name_keep_ext(self._norb_datafile.out_name, suffix=f'_x{out_size}') if (out_name is None) else out_name,
+ file_hash=out_hash,
+ hash_type=hash_type,
+ hash_mode=hash_mode,
+ )
+
+ def _prepare(self, out_dir: str, out_file: str):
+ log.debug('Preparing Orig SmallNorb Data:')
+ norb_path = self._norb_datafile.prepare(out_dir)
+ log.debug('Generating Resized SmallNorb Data:')
+ resave_norb_resized(orig_converted_file=norb_path, new_resized_file=out_file, overwrite=True, size=self._out_size)
+
+
# ========================================================================= #
# dataset_norb #
# ========================================================================= #
-class SmallNorbData(DiskGroundTruthData):
+class SmallNorbData(NumpyFileGroundTruthData):
"""
Small NORB Dataset
- https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/
@@ -143,33 +241,56 @@ class SmallNorbData(DiskGroundTruthData):
factor_sizes = (5, 5, 9, 18, 6) # TOTAL: 24300
img_shape = (96, 96, 1)
- TRAIN_DATA_FILES = {
- 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': '66054832f9accfe74a0f4c36a75bc0a2'}),
- 'cat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', uri_hash={'fast': '348fc3ccefd651d69f500611988b5dcd', 'full': '23c8b86101fbf0904a000b43d3ed2fd9'}),
- 'info': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', uri_hash={'fast': 'f1b170c16925867c05f58608eb33ba7f', 'full': '51dee1210a742582ff607dfd94e332e3'}),
- }
-
- TEST_DATA_FILES = {
- 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', uri_hash={'fast': '9aee0b474a4fc2a2ec392b463efb8858', 'full': 'e4ad715691ed5a3a5f138751a4ceb071'}),
- 'cat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', uri_hash={'fast': '8cfae0679f5fa2df7a0aedfce90e5673', 'full': '5aa791cd7e6016cf957ce9bdb93b8603'}),
- 'info': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}),
- }
-
- def __init__(self, data_root: Optional[str] = None, prepare: bool = False, is_test=False, transform=None):
+ DATA_FILE_TRAIN = DataFileSmallNorb(
+ datafile_dat=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': '66054832f9accfe74a0f4c36a75bc0a2'}),
+ datafile_cat=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', uri_hash={'fast': '348fc3ccefd651d69f500611988b5dcd', 'full': '23c8b86101fbf0904a000b43d3ed2fd9'}),
+ datafile_info=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', uri_hash={'fast': 'f1b170c16925867c05f58608eb33ba7f', 'full': '51dee1210a742582ff607dfd94e332e3'}),
+ out_name='smallnorb_train.npz',
+ out_hash={'fast': 'a2c7de23c57b16c71b79dc2c884ecd67', 'full': '7dabafbfafa0eb9b0115452f82d1491e'},
+ hash_mode='fast',
+ )
+
+ DATA_FILE_TEST = DataFileSmallNorb(
+ datafile_dat=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', uri_hash={'fast': '9aee0b474a4fc2a2ec392b463efb8858', 'full': 'e4ad715691ed5a3a5f138751a4ceb071'}),
+ datafile_cat=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', uri_hash={'fast': '8cfae0679f5fa2df7a0aedfce90e5673', 'full': '5aa791cd7e6016cf957ce9bdb93b8603'}),
+ datafile_info=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}),
+ out_name='smallnorb_test.npz',
+ out_hash={'fast': 'ff027c01e14faea9a0d427d641d3bd8e', 'full': 'cec16d74a1b075fe4117b5167e16ceff'},
+ hash_mode='fast',
+ )
+
+ # override
+ data_key = 'images'
+
+ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, is_test: bool = False, transform=None):
self._is_test = is_test
# initialize
super().__init__(data_root=data_root, prepare=prepare, transform=transform)
- # read dataset and sort by features
- dat_path, cat_path, info_path = (os.path.join(self.data_dir, obj.out_name) for obj in self.datafiles)
- self._data, _ = read_norb_dataset(dat_path=dat_path, cat_path=cat_path, info_path=info_path)
-
- def _get_observation(self, idx):
- return self._data[idx][:, :, None] # data is missing channel dim
@property
- def datafiles(self) -> Sequence[DataFileHashedDl]:
- norb_objects = self.TEST_DATA_FILES if self._is_test else self.TRAIN_DATA_FILES
- return norb_objects['dat'], norb_objects['cat'], norb_objects['info']
+ def datafile(self) -> DataFile:
+ return self.DATA_FILE_TEST if self._is_test else self.DATA_FILE_TRAIN
+
+
+class SmallNorb64Data(SmallNorbData):
+
+ img_shape = (64, 64, 1)
+
+ DATA_FILE_TRAIN = DataFileSmallNorbResized(
+ norb_datafile=SmallNorbData.DATA_FILE_TRAIN,
+ out_name='smallnorb_train_x64.npz',
+ out_size=64,
+ out_hash={'fast': '74a3c02ea5a649313ea245a3fe271d3b', 'full': '88ce361b2198ee577e60da2be9daa0e8'},
+ hash_mode='fast',
+ )
+
+ DATA_FILE_TEST = DataFileSmallNorbResized(
+ norb_datafile=SmallNorbData.DATA_FILE_TEST,
+ out_name='smallnorb_test_x64.npz',
+ out_size=64,
+ out_hash={'fast': '37bf364479c0954ecd707ace349541ef', 'full': '6bfd93eb6454d9d24dba13cac5f1ef3e'},
+ hash_mode='fast',
+ )
# ========================================================================= #
@@ -178,5 +299,21 @@ def datafiles(self) -> Sequence[DataFileHashedDl]:
if __name__ == '__main__':
+ import torch
+ from tqdm import tqdm
+ from disent.dataset.transform import ToImgTensorF32
+
logging.basicConfig(level=logging.DEBUG)
- SmallNorbData(prepare=True)
+
+ for is_test in [False, True]:
+ # original dataset
+ data_96 = SmallNorbData(prepare=True, is_test=is_test, transform=ToImgTensorF32(size=64))
+ for i in tqdm(data_96, desc='norb_x96 -> 64'):
+ pass
+ # resized dataset
+ data_64 = SmallNorb64Data(prepare=True, is_test=is_test, transform=ToImgTensorF32(size=64))
+ for i in tqdm(data_64, desc='norb_x64'):
+ pass
+ # check equivalence
+ for obs_96, obs_64 in tqdm(zip(data_96, data_64), desc='equivalence'):
+ assert torch.allclose(obs_96, obs_64)
diff --git a/disent/dataset/data/_groundtruth__teapots3d.py b/disent/dataset/data/_groundtruth__teapots3d.py
new file mode 100644
index 00000000..afaf991f
--- /dev/null
+++ b/disent/dataset/data/_groundtruth__teapots3d.py
@@ -0,0 +1,164 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import logging
+import os
+from typing import Dict
+from typing import NoReturn
+from typing import Optional
+from typing import Union
+
+import numpy as np
+
+from disent.dataset.data import NumpyFileGroundTruthData
+from disent.dataset.util.datafile import DataFileHashed
+from disent.util.inout.paths import modify_name_keep_ext
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# teapots 3d data processing #
+# ========================================================================= #
+
+
+def resave_teapots3d_as_uint8(orig_file: str, new_file: str, overwrite: bool = False):
+ # load data into memory ~10GB
+ # -- by default this array is stored as uint32, instead of uint8
+ log.debug('loading teapots data into memory, this may take a while...')
+ imgs = np.load(orig_file)['images']
+ log.debug('loaded teapots data into memory!')
+ # checks
+ log.debug('checking teapots data...')
+ assert imgs.dtype == 'int32'
+ assert imgs.shape == (200_000, 64, 64, 3)
+ assert imgs.max() == 255
+ assert imgs.min() == 0
+ log.debug('checked teapots data!')
+ # convert the values
+ log.debug('converting teapots data to uint8...')
+ imgs = imgs.astype('uint8')
+ log.debug('converted teapots data!')
+ # save the array
+ from disent.dataset.util.npz import save_dataset_array
+ log.debug('saving convert teapots data...')
+ save_dataset_array(imgs, new_file, overwrite=overwrite, save_key='images')
+ log.debug('saved convert teapots data!')
+
+
+# ========================================================================= #
+# teapots 3d data files #
+# ========================================================================= #
+
+
+class DataFileTeapots3dInt32(DataFileHashed):
+
+ # TODO: add a version of this file that automatically unpacks the original zip file?
+
+ def _prepare(self, out_dir: str, out_file: str) -> NoReturn:
+ if not os.path.exists(out_file):
+ raise FileNotFoundError(
+ f'Please download the Teapots3D dataset to: {repr(out_file)}'
+ f'\nThe original repository is: {repr("https://github.com/cianeastwood/qedr")}'
+ f'\nThe original download link is: {repr("https://www.dropbox.com/s/woeyomxuylqu7tx/edinburgh_teapots.zip?dl=0")}'
+ )
+
+
+class DataFileTeapots3dUint8(DataFileHashed):
+
+ def __init__(
+ self,
+ teapots3d_datafile: DataFileTeapots3dInt32,
+ # - convert file name
+ out_hash: Optional[Union[str, Dict[str, str]]],
+ out_name: Optional[str] = None,
+ # - hash settings
+ hash_type: str = 'md5',
+ hash_mode: str = 'fast',
+ ):
+ self._teapots3dfile = teapots3d_datafile
+ super().__init__(
+ file_name=modify_name_keep_ext(self._teapots3dfile.out_name, suffix=f'_uint8') if (out_name is None) else out_name,
+ file_hash=out_hash,
+ hash_type=hash_type,
+ hash_mode=hash_mode,
+ )
+
+ def _prepare(self, out_dir: str, out_file: str):
+ log.debug('Preparing Orig Teapots3d Data:')
+ cars3d_path = self._teapots3dfile.prepare(out_dir)
+ log.debug('Converting Teapots3d Data to Uint8:')
+ resave_teapots3d_as_uint8(orig_file=cars3d_path, new_file=out_file, overwrite=True)
+
+
+# ========================================================================= #
+# teapots 3d dataset #
+# ========================================================================= #
+
+
+class Teapots3dData(NumpyFileGroundTruthData):
+ """
+ Teapots3D Dataset
+ - A Framework for the Quantitative Evaluation of Disentangled Representations
+ * https://openreview.net/forum?id=By-7dz-AZ
+ * https://github.com/cianeastwood/qedr
+
+ Manual Download Link:
+ - https://www.dropbox.com/s/woeyomxuylqu7tx/edinburgh_teapots.zip?dl=0
+
+ NOTE:
+ - This dataset is generated from ground-truth factors, HOWEVER, each datapoint
+ is randomly sampled. This dataset is NOT a typical grid-search over ground-truth factors
+ which means that we cannot create a StateSpace object over this dataset.
+ """
+
+ name = 'edinburgh_teapots'
+
+ factor_names = ('azimuth', 'elevation', 'red', 'green', 'blue')
+ factor_sizes = (..., ..., ..., ..., ...) # TOTAL: 200_000 -- TODO: this is invalid, we cannot actually generate a StateSpace object over this dataset!
+ img_shape = (64, 64, 3)
+
+ datafile = DataFileTeapots3dUint8(
+ teapots3d_datafile=DataFileTeapots3dInt32(
+ file_name='teapots.npz',
+ file_hash={'full': '9b58d66a382d01f4477e33520f1fa503', 'fast': '12c889e001c205d0bafa59dfff114102'},
+ ),
+ out_hash={'full': 'e64207ee443030d310500d762f0d1dfd', 'fast': '7fbca0223c27e055d35b6d5af720f108'},
+ out_name='teapots_uint8.npz',
+ )
+
+ # override
+ data_key = 'images'
+
+
+# ========================================================================= #
+# main #
+# ========================================================================= #
+
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+ data = Teapots3dData(data_root='~/Downloads', prepare=True)
diff --git a/disent/dataset/data/_raw.py b/disent/dataset/data/_raw.py
index 219cda98..5a30f45c 100644
--- a/disent/dataset/data/_raw.py
+++ b/disent/dataset/data/_raw.py
@@ -29,7 +29,7 @@
# ========================================================================= #
-# hdf5 utils #
+# array utils #
# ========================================================================= #
diff --git a/disent/dataset/sampling/__init__.py b/disent/dataset/sampling/__init__.py
index b9bcaf8f..5dd9a9c0 100644
--- a/disent/dataset/sampling/__init__.py
+++ b/disent/dataset/sampling/__init__.py
@@ -31,6 +31,7 @@
from disent.dataset.sampling._groundtruth__pair_orig import GroundTruthPairOrigSampler
from disent.dataset.sampling._groundtruth__single import GroundTruthSingleSampler
from disent.dataset.sampling._groundtruth__triplet import GroundTruthTripleSampler
+from disent.dataset.sampling._groundtruth__walk import GroundTruthRandomWalkSampler
# any dataset samplers
from disent.dataset.sampling._single import SingleSampler
diff --git a/disent/dataset/sampling/_base.py b/disent/dataset/sampling/_base.py
index d678dc6c..5bed3bb9 100644
--- a/disent/dataset/sampling/_base.py
+++ b/disent/dataset/sampling/_base.py
@@ -67,7 +67,7 @@ def is_init(self) -> bool:
def _sample_idx(self, idx: int) -> Tuple[int, ...]:
raise NotImplementedError
- def __call__(self, idx: int) -> Tuple[int, ...]:
+ def sample(self, idx: int) -> Tuple[int, ...]:
# check that we have been initialized!
if not self.is_init:
raise RuntimeError(f'{self.__class__.__name__} has not been initialized! call `sampler.init(gt_data)`')
@@ -79,6 +79,9 @@ def __call__(self, idx: int) -> Tuple[int, ...]:
# return values
return idxs
+ def __call__(self, idx: int) -> Tuple[int, ...]:
+ return self.sample(idx)
+
# ========================================================================= #
# END #
diff --git a/disent/dataset/sampling/_groundtruth__dist.py b/disent/dataset/sampling/_groundtruth__dist.py
index 8f2a89d8..67c17bec 100644
--- a/disent/dataset/sampling/_groundtruth__dist.py
+++ b/disent/dataset/sampling/_groundtruth__dist.py
@@ -22,10 +22,21 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+from fractions import Fraction
+from typing import List
+from typing import Optional
+from typing import Union
+
import numpy as np
from disent.dataset.data import GroundTruthData
from disent.dataset.sampling._base import BaseDisentSampler
+from disent.dataset.util.state_space import StateSpace
+
+
+# ========================================================================= #
+# Ground Truth Dist Sampler #
+# ========================================================================= #
class GroundTruthDistSampler(BaseDisentSampler):
@@ -63,11 +74,11 @@ def __init__(
self._sample_mode = triplet_sample_mode
self._swap_chance = triplet_swap_chance
# dataset variable
- self._data: GroundTruthData
+ self._state_space: Optional[StateSpace] = None
def _init(self, dataset):
assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}'
- self._data = dataset
+ self._state_space = dataset.state_space_copy()
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Sampling #
@@ -75,7 +86,7 @@ def _init(self, dataset):
def _sample_idx(self, idx):
# sample indices
- indices = (idx, *np.random.randint(0, len(self._data), size=self._num_samples-1))
+ indices = (idx, *np.random.randint(0, len(self._state_space), size=self._num_samples-1))
# sort based on mode
if self._num_samples == 3:
a_i, p_i, n_i = self._swap_triple(indices)
@@ -89,31 +100,26 @@ def _sample_idx(self, idx):
def _swap_triple(self, indices):
a_i, p_i, n_i = indices
- a_f, p_f, n_f = self._data.idx_to_pos(indices)
- a_d, p_d, n_d = a_f, p_f, n_f
- # dists vars
- if self._scaled:
- # range of positions is [0, f_size - 1], to scale between 0 and 1 we need to
- # divide by (f_size - 1), but if the factor size is 1, we can't divide by zero
- # so we make the minimum 1.0
- scale = np.maximum(1, np.array(self._data.factor_sizes) - 1)
- a_d = a_d / scale
- p_d = p_d / scale
- n_d = n_d / scale
+ a_f, p_f, n_f = self._state_space.idx_to_pos(indices)
+ # get the scale for everything
+ # - range of positions is [0, f_size - 1], to scale between 0 and 1 we need to
+ # divide by (f_size - 1), but if the factor size is 1, we can't divide by zero
+ # so we make the minimum 1
+ scale = np.maximum(1, self._state_space.factor_sizes - 1) if (self._scaled) else None
+ # SWAP: manhattan
+ if self._sample_mode == 'manhattan':
+ if factor_dist(a_f, p_f, scale=scale) > factor_dist(a_f, n_f, scale=scale):
+ return a_i, n_i, p_i
# SWAP: factors
- if self._sample_mode == 'factors':
+ elif self._sample_mode == 'factors':
if factor_diff(a_f, p_f) > factor_diff(a_f, n_f):
return a_i, n_i, p_i
- # SWAP: manhattan
- elif self._sample_mode == 'manhattan':
- if factor_dist(a_d, p_d) > factor_dist(a_d, n_d):
- return a_i, n_i, p_i
# SWAP: combined
elif self._sample_mode == 'combined':
if factor_diff(a_f, p_f) > factor_diff(a_f, n_f):
return a_i, n_i, p_i
elif factor_diff(a_f, p_f) == factor_diff(a_f, n_f):
- if factor_dist(a_d, p_d) > factor_dist(a_d, n_d):
+ if factor_dist(a_f, p_f, scale=scale) > factor_dist(a_f, n_f, scale=scale):
return a_i, n_i, p_i
# SWAP: random
elif self._sample_mode != 'random':
@@ -122,9 +128,114 @@ def _swap_triple(self, indices):
return indices
-def factor_diff(f0, f1):
+def factor_diff(f0: np.ndarray, f1: np.ndarray) -> int:
+ # input types should be np.int64
+ assert f0.dtype == f1.dtype == 'int64'
+ # compute distances!
return np.sum(f0 != f1)
-def factor_dist(f0, f1):
- return np.sum(np.abs(f0 - f1))
+# NOTE: scaling here should always be the same as `disentangle_loss`
+def factor_dist(f0: np.ndarray, f1: np.ndarray, scale: np.ndarray = None) -> Union[Fraction, int]:
+ # compute distances!
+ if scale is None:
+ # input types should all be np.int64
+ assert f0.dtype == f1.dtype == 'int64', f'invalid dtypes, f0: {f0.dtype}, f1: {f1.dtype}'
+ # we can simply sum if everything is already an integer
+ return np.sum(np.abs(f0 - f1))
+ else:
+ # input types should all be np.int64
+ assert f0.dtype == f1.dtype == scale.dtype == 'int64'
+ # Division results in precision errors! We cannot simply sum divided values. We instead
+ # store values as arbitrary precision rational numbers in the form of fractions This means
+ # we do not lose precision while summing, and avoid comparison errors!
+ # - https://shlegeris.com/2018/10/23/sqrt.html
+ # - https://cstheory.stackexchange.com/a/4010
+ # 1. first we need to convert numbers to python arbitrary precision values:
+ f0: List[int] = f0.tolist()
+ f1: List[int] = f1.tolist()
+ scale: List[int] = scale.tolist()
+ # 2. we need to sum values in the form of fractions
+ total = Fraction(0)
+ for y0, y1, s in zip(f0, f1, scale):
+ total += Fraction(abs(y0 - y1), s)
+ return total
+
+
+# ========================================================================= #
+# Investigation: #
+# ========================================================================= #
+
+
+if __name__ == '__main__':
+
+ def main():
+ from disent.dataset import DisentDataset
+ from disent.dataset.data import XYObjectData
+ from disent.dataset.data import XYObjectShadedData
+ from disent.dataset.data import Cars3d64Data
+ from disent.dataset.data import Shapes3dData
+ from disent.dataset.data import DSpritesData
+ from disent.dataset.data import SmallNorb64Data
+ from disent.util.seeds import TempNumpySeed
+ from tqdm import tqdm
+
+ repeats = 1000
+ samples = 100
+
+ # RESULTS - manhattan:
+ # cars3d: orig_vs_divs=30.066%, orig_vs_frac=30.066%, divs_vs_frac=0.000%
+ # 3dshapes: orig_vs_divs=12.902%, orig_vs_frac=12.878%, divs_vs_frac=0.096%
+ # dsprites: orig_vs_divs=24.035%, orig_vs_frac=24.032%, divs_vs_frac=0.003%
+ # smallnorb: orig_vs_divs=18.601%, orig_vs_frac=18.598%, divs_vs_frac=0.005%
+ # xy_squares_minimal: orig_vs_divs= 1.389%, orig_vs_frac= 0.000%, divs_vs_frac=1.389%
+ # xy_object: orig_vs_divs=15.520%, orig_vs_frac=15.511%, divs_vs_frac=0.029%
+ # xy_object: orig_vs_divs=23.973%, orig_vs_frac=23.957%, divs_vs_frac=0.082%
+ # RESULTS - combined:
+ # cars3d: orig_vs_divs=15.428%, orig_vs_frac=15.428%, divs_vs_frac=0.000%
+ # 3dshapes: orig_vs_divs=4.982%, orig_vs_frac= 4.968%, divs_vs_frac=0.050%
+ # dsprites: orig_vs_divs=8.366%, orig_vs_frac= 8.363%, divs_vs_frac=0.003%
+ # smallnorb: orig_vs_divs=7.359%, orig_vs_frac= 7.359%, divs_vs_frac=0.000%
+ # xy_squares_minimal: orig_vs_divs=0.610%, orig_vs_frac= 0.000%, divs_vs_frac=0.610%
+ # xy_object: orig_vs_divs=7.622%, orig_vs_frac= 7.614%, divs_vs_frac=0.020%
+ # xy_object: orig_vs_divs=8.741%, orig_vs_frac= 8.733%, divs_vs_frac=0.046%
+ for mode in ['manhattan', 'combined']:
+ for data_cls in [
+ Cars3d64Data,
+ Shapes3dData,
+ DSpritesData,
+ SmallNorb64Data,
+ XYObjectData,
+ XYObjectShadedData,
+ ]:
+ data = data_cls()
+ dataset_orig = DisentDataset(data, sampler=GroundTruthDistSampler(3, f'{mode}'))
+ dataset_frac = DisentDataset(data, sampler=GroundTruthDistSampler(3, f'{mode}_scaled'))
+ dataset_divs = DisentDataset(data, sampler=GroundTruthDistSampler(3, f'{mode}_scaled_INVALID'))
+ # calculate the average number of mismatches between sampling methods!
+ all_wrong_frac = [] # frac vs orig
+ all_wrong_divs = [] # divs vs orig
+ all_wrong_diff = [] # frac vs divs
+ with TempNumpySeed(777):
+ progress = tqdm(range(repeats), desc=f'{mode} {data.name}')
+ for i in progress:
+ batch_seed = np.random.randint(0, 2**32)
+ with TempNumpySeed(batch_seed): idxs_orig = np.array([dataset_orig.sampler.sample(np.random.randint(0, len(dataset_orig))) for _ in range(samples)])
+ with TempNumpySeed(batch_seed): idxs_frac = np.array([dataset_frac.sampler.sample(np.random.randint(0, len(dataset_frac))) for _ in range(samples)])
+ with TempNumpySeed(batch_seed): idxs_divs = np.array([dataset_divs.sampler.sample(np.random.randint(0, len(dataset_divs))) for _ in range(samples)])
+ # check number of miss_matches
+ all_wrong_frac.append(np.sum(np.any(idxs_orig != idxs_frac, axis=-1)) / samples * 100)
+ all_wrong_divs.append(np.sum(np.any(idxs_orig != idxs_divs, axis=-1)) / samples * 100)
+ all_wrong_diff.append(np.sum(np.any(idxs_frac != idxs_divs, axis=-1)) / samples * 100)
+ # update progress bar
+ progress.set_postfix({
+ 'orig_vs_divs': f'{np.mean(all_wrong_divs):5.3f}%',
+ 'orig_vs_frac': f'{np.mean(all_wrong_frac):5.3f}%',
+ 'divs_vs_frac': f'{np.mean(all_wrong_diff):5.3f}%',
+ })
+ main()
+
+
+# ========================================================================= #
+# END: #
+# ========================================================================= #
diff --git a/disent/dataset/sampling/_groundtruth__pair.py b/disent/dataset/sampling/_groundtruth__pair.py
index 753bc1aa..9b7dcc32 100644
--- a/disent/dataset/sampling/_groundtruth__pair.py
+++ b/disent/dataset/sampling/_groundtruth__pair.py
@@ -22,10 +22,13 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+from typing import Optional
+
import numpy as np
from disent.dataset.data import GroundTruthData
from disent.dataset.sampling._base import BaseDisentSampler
from disent.dataset.sampling._groundtruth__triplet import normalise_range_pair, FactorSizeError
+from disent.dataset.util.state_space import StateSpace
from disent.util.math.random import sample_radius
@@ -60,15 +63,15 @@ def __init__(
self.p_k_range = p_k_range
self.p_radius_range = p_radius_range
# dataset variable
- self._data: GroundTruthData
+ self._state_space: Optional[StateSpace]
def _init(self, dataset):
assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}'
- self._data = dataset
+ self._state_space = dataset.state_space_copy()
# DIFFERING FACTORS
- self.p_k_min, self.p_k_max = self._min_max_from_range(p_range=self.p_k_range, max_values=self._data.num_factors)
+ self.p_k_min, self.p_k_max = self._min_max_from_range(p_range=self.p_k_range, max_values=self._state_space.num_factors)
# RADIUS SAMPLING
- self.p_radius_min, self.p_radius_max = self._min_max_from_range(p_range=self.p_radius_range, max_values=self._data.factor_sizes)
+ self.p_radius_min, self.p_radius_max = self._min_max_from_range(p_range=self.p_radius_range, max_values=self._state_space.factor_sizes)
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# CORE #
@@ -77,8 +80,8 @@ def _init(self, dataset):
def _sample_idx(self, idx):
f0, f1 = self.datapoint_sample_factors_pair(idx)
return (
- self._data.pos_to_idx(f0),
- self._data.pos_to_idx(f1),
+ self._state_space.pos_to_idx(f0),
+ self._state_space.pos_to_idx(f1),
)
def datapoint_sample_factors_pair(self, idx):
@@ -102,7 +105,7 @@ def datapoint_sample_factors_pair(self, idx):
p_k = self._sample_num_factors()
p_shared_indices = self._sample_shared_indices(p_k)
# SAMPLE FACTORS - sample, resample and replace shared factors with originals
- anchor_factors = self._data.idx_to_pos(idx)
+ anchor_factors = self._state_space.idx_to_pos(idx)
positive_factors = self._resample_factors(anchor_factors)
positive_factors[p_shared_indices] = anchor_factors[p_shared_indices]
return anchor_factors, positive_factors
@@ -125,11 +128,11 @@ def _sample_num_factors(self):
return p_k
def _sample_shared_indices(self, p_k):
- p_shared_indices = np.random.choice(self._data.num_factors, size=self._data.num_factors-p_k, replace=False)
+ p_shared_indices = np.random.choice(self._state_space.num_factors, size=self._state_space.num_factors-p_k, replace=False)
return p_shared_indices
def _resample_factors(self, anchor_factors):
- positive_factors = sample_radius(anchor_factors, low=0, high=self._data.factor_sizes, r_low=self.p_radius_min, r_high=self.p_radius_max + 1)
+ positive_factors = sample_radius(anchor_factors, low=0, high=self._state_space.factor_sizes, r_low=self.p_radius_min, r_high=self.p_radius_max + 1)
return positive_factors
diff --git a/disent/dataset/sampling/_groundtruth__pair_orig.py b/disent/dataset/sampling/_groundtruth__pair_orig.py
index 02530fb0..76ec9a91 100644
--- a/disent/dataset/sampling/_groundtruth__pair_orig.py
+++ b/disent/dataset/sampling/_groundtruth__pair_orig.py
@@ -22,9 +22,12 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+from typing import Optional
+
import numpy as np
from disent.dataset.data import GroundTruthData
from disent.dataset.sampling._base import BaseDisentSampler
+from disent.dataset.util.state_space import StateSpace
class GroundTruthPairOrigSampler(BaseDisentSampler):
@@ -47,11 +50,11 @@ def __init__(
# DIFFERING FACTORS
self.p_k = p_k
# dataset variable
- self._data: GroundTruthData
+ self._state_space: Optional[StateSpace] = None
def _init(self, dataset):
assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}'
- self._data = dataset
+ self._state_space = dataset.state_space_copy()
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# CORE #
@@ -60,8 +63,8 @@ def _init(self, dataset):
def _sample_idx(self, idx):
f0, f1 = self.datapoint_sample_factors_pair(idx)
return (
- self._data.pos_to_idx(f0),
- self._data.pos_to_idx(f1),
+ self._state_space.pos_to_idx(f0),
+ self._state_space.pos_to_idx(f1),
)
def datapoint_sample_factors_pair(self, idx):
@@ -70,14 +73,14 @@ def datapoint_sample_factors_pair(self, idx):
Except deterministic for the first item in the pair, based off of idx.
"""
# randomly sample the first observation -- In our case we just use the idx
- sampled_factors = self._data.idx_to_pos(idx)
+ sampled_factors = self._state_space.idx_to_pos(idx)
# sample the next observation with k differing factors
- next_factors, k = _sample_k_differing(sampled_factors, self._data, k=self.p_k)
+ next_factors, k = _sample_k_differing(sampled_factors, self._state_space, k=self.p_k)
# return the samples
return sampled_factors, next_factors
-def _sample_k_differing(factors, ground_truth_data: GroundTruthData, k=1):
+def _sample_k_differing(factors, state_space: StateSpace, k=1):
"""
Resample the factors used for the corresponding item in a pair.
- Based on simple_dynamics() from:
@@ -88,7 +91,7 @@ def _sample_k_differing(factors, ground_truth_data: GroundTruthData, k=1):
assert factors.ndim == 1
# sample k
if k <= 0:
- k = np.random.randint(1, ground_truth_data.num_factors)
+ k = np.random.randint(1, state_space.num_factors)
# randomly choose 1 or k
# TODO: This is in disentanglement lib, HOWEVER is this not a mistake?
# A bug report has been submitted to disentanglement_lib for clarity:
@@ -98,20 +101,20 @@ def _sample_k_differing(factors, ground_truth_data: GroundTruthData, k=1):
index_list = np.random.choice(len(factors), k, replace=False)
# randomly update factors
for index in index_list:
- factors[index] = np.random.choice(ground_truth_data.factor_sizes[index])
+ factors[index] = np.random.choice(state_space.factor_sizes[index])
# return!
return factors, k
-def _sample_weak_pair_factors(gt_data: GroundTruthData): # pragma: no cover
+def _sample_weak_pair_factors(state_space: StateSpace): # pragma: no cover
"""
Sample a weakly supervised pair from the given GroundTruthData.
- Based on weak_dataset_generator() from:
https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/methods/weak/train_weak_lib.py
"""
# randomly sample the first observation
- sampled_factors = gt_data.sample_factors(1)
+ sampled_factors = state_space.sample_factors(1)
# sample the next observation with k differing factors
- next_factors, k = _sample_k_differing(sampled_factors, gt_data, k=1)
+ next_factors, k = _sample_k_differing(sampled_factors, state_space, k=1)
# return the samples
return sampled_factors, next_factors
diff --git a/disent/dataset/sampling/_groundtruth__single.py b/disent/dataset/sampling/_groundtruth__single.py
index 885bc5ca..8198e9b0 100644
--- a/disent/dataset/sampling/_groundtruth__single.py
+++ b/disent/dataset/sampling/_groundtruth__single.py
@@ -23,8 +23,11 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
import logging
+from typing import Optional
+
from disent.dataset.data import GroundTruthData
from disent.dataset.sampling._base import BaseDisentSampler
+from disent.dataset.util.state_space import StateSpace
log = logging.getLogger(__name__)
@@ -42,10 +45,11 @@ def uninit_copy(self) -> 'GroundTruthSingleSampler':
def __init__(self):
super().__init__(num_samples=1)
+ self._state_space: Optional[StateSpace] = None # TODO: not actually needed
def _init(self, dataset):
assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}'
- self._data = dataset
+ self._state_space = dataset.state_space_copy()
def _sample_idx(self, idx):
return (idx,)
diff --git a/disent/dataset/sampling/_groundtruth__triplet.py b/disent/dataset/sampling/_groundtruth__triplet.py
index 231221a4..e87c6be7 100644
--- a/disent/dataset/sampling/_groundtruth__triplet.py
+++ b/disent/dataset/sampling/_groundtruth__triplet.py
@@ -23,12 +23,14 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
import logging
+from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
from disent.dataset.data import GroundTruthData
from disent.dataset.sampling._base import BaseDisentSampler
+from disent.dataset.util.state_space import StateSpace
from disent.util.math.random import sample_radius
@@ -90,16 +92,16 @@ def __init__(
if swap_chance is not None:
assert 0 <= swap_chance <= 1, f'{swap_chance=} must be in range 0 to 1.'
# dataset variable
- self._data: GroundTruthData
+ self._state_space: Optional[StateSpace]
def _init(self, dataset):
assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}'
- self._data = dataset
+ self._state_space = dataset.state_space_copy()
# DIFFERING FACTORS
self.p_k_min, self.p_k_max, self.n_k_min, self.n_k_max = self._min_max_from_range(
p_range=self.p_k_range,
n_range=self.n_k_range,
- max_values=self._data.num_factors,
+ max_values=self._state_space.num_factors,
n_sample_mode=self.n_k_sample_mode,
is_radius=False
)
@@ -107,7 +109,7 @@ def _init(self, dataset):
self.p_radius_min, self.p_radius_max, self.n_radius_min, self.n_radius_max = self._min_max_from_range(
p_range=self.p_radius_range,
n_range=self.n_radius_range,
- max_values=self._data.factor_sizes,
+ max_values=self._state_space.factor_sizes,
n_sample_mode=self.n_radius_sample_mode,
is_radius=True
)
@@ -119,9 +121,9 @@ def _init(self, dataset):
def _sample_idx(self, idx):
f0, f1, f2 = self.datapoint_sample_factors_triplet(idx)
return (
- self._data.pos_to_idx(f0),
- self._data.pos_to_idx(f1),
- self._data.pos_to_idx(f2),
+ self._state_space.pos_to_idx(f0),
+ self._state_space.pos_to_idx(f1),
+ self._state_space.pos_to_idx(f2),
)
def datapoint_sample_factors_triplet(self, idx):
@@ -129,7 +131,7 @@ def datapoint_sample_factors_triplet(self, idx):
p_k, n_k = self._sample_num_factors()
p_shared_indices, n_shared_indices = self._sample_shared_indices(p_k, n_k)
# SAMPLE FACTORS - sample, resample and replace shared factors with originals
- anchor_factors = self._data.idx_to_pos(idx)
+ anchor_factors = self._state_space.idx_to_pos(idx)
positive_factors, negative_factors = self._resample_factors(anchor_factors)
positive_factors[p_shared_indices] = anchor_factors[p_shared_indices]
negative_factors[n_shared_indices] = anchor_factors[n_shared_indices]
@@ -189,7 +191,7 @@ def _sample_num_factors(self):
p_k = np.random.randint(self.p_k_min, self.p_k_max + 1)
# sample for negative
if self.n_k_sample_mode == 'offset':
- n_k = np.random.randint(p_k + self.n_k_min, min(p_k + self.n_k_max, self._data.num_factors) + 1)
+ n_k = np.random.randint(p_k + self.n_k_min, min(p_k + self.n_k_max, self._state_space.num_factors) + 1)
elif self.n_k_sample_mode == 'bounded_below':
n_k = np.random.randint(max(p_k, self.n_k_min), self.n_k_max + 1)
elif self.n_k_sample_mode == 'random':
@@ -200,18 +202,18 @@ def _sample_num_factors(self):
return p_k, n_k
def _sample_shared_indices(self, p_k, n_k):
- p_shared_indices = np.random.choice(self._data.num_factors, size=self._data.num_factors-p_k, replace=False)
+ p_shared_indices = np.random.choice(self._state_space.num_factors, size=self._state_space.num_factors-p_k, replace=False)
# sample for negative
if self.n_k_is_shared:
- n_shared_indices = p_shared_indices[:self._data.num_factors-n_k]
+ n_shared_indices = p_shared_indices[:self._state_space.num_factors-n_k]
else:
- n_shared_indices = np.random.choice(self._data.num_factors, size=self._data.num_factors-n_k, replace=False)
+ n_shared_indices = np.random.choice(self._state_space.num_factors, size=self._state_space.num_factors-n_k, replace=False)
# we're done!
return p_shared_indices, n_shared_indices
def _resample_factors(self, anchor_factors):
# sample positive
- positive_factors = sample_radius(anchor_factors, low=0, high=self._data.factor_sizes, r_low=self.p_radius_min, r_high=self.p_radius_max + 1)
+ positive_factors = sample_radius(anchor_factors, low=0, high=self._state_space.factor_sizes, r_low=self.p_radius_min, r_high=self.p_radius_max + 1)
# negative arguments
if self.n_radius_sample_mode == 'offset':
sampled_radius = np.abs(anchor_factors - positive_factors)
@@ -227,7 +229,7 @@ def _resample_factors(self, anchor_factors):
else:
raise KeyError(f'Unknown mode: {self.n_radius_sample_mode=}')
# sample negative
- negative_factors = sample_radius(anchor_factors, low=0, high=self._data.factor_sizes, r_low=n_r_low, r_high=n_r_high)
+ negative_factors = sample_radius(anchor_factors, low=0, high=self._state_space.factor_sizes, r_low=n_r_low, r_high=n_r_high)
# we're done!
return positive_factors, negative_factors
@@ -239,14 +241,14 @@ def _swap_factors(self, anchor_factors, positive_factors, negative_factors):
p_dist = np.sum(np.abs(anchor_factors - positive_factors))
n_dist = np.sum(np.abs(anchor_factors - negative_factors))
elif self._swap_metric == 'manhattan_norm':
- p_dist = np.sum(np.abs((anchor_factors - positive_factors) / np.subtract(self._data.factor_sizes, 1)))
- n_dist = np.sum(np.abs((anchor_factors - negative_factors) / np.subtract(self._data.factor_sizes, 1)))
+ p_dist = np.sum(np.abs((anchor_factors - positive_factors) / np.subtract(self._state_space.factor_sizes, 1)))
+ n_dist = np.sum(np.abs((anchor_factors - negative_factors) / np.subtract(self._state_space.factor_sizes, 1)))
elif self._swap_metric == 'euclidean':
p_dist = np.linalg.norm(anchor_factors - positive_factors)
n_dist = np.linalg.norm(anchor_factors - negative_factors)
elif self._swap_metric == 'euclidean_norm':
- p_dist = np.linalg.norm((anchor_factors - positive_factors) / np.subtract(self._data.factor_sizes, 1))
- n_dist = np.linalg.norm((anchor_factors - negative_factors) / np.subtract(self._data.factor_sizes, 1))
+ p_dist = np.linalg.norm((anchor_factors - positive_factors) / np.subtract(self._state_space.factor_sizes, 1))
+ n_dist = np.linalg.norm((anchor_factors - negative_factors) / np.subtract(self._state_space.factor_sizes, 1))
else:
raise KeyError
# perform swap
diff --git a/disent/dataset/sampling/_groundtruth__walk.py b/disent/dataset/sampling/_groundtruth__walk.py
new file mode 100644
index 00000000..37f5c297
--- /dev/null
+++ b/disent/dataset/sampling/_groundtruth__walk.py
@@ -0,0 +1,128 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+from typing import NoReturn
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+
+import numpy as np
+
+from disent.dataset.data import GroundTruthData
+from disent.dataset.sampling import BaseDisentSampler
+from disent.dataset.util.state_space import StateSpace
+from disent.util.jit import try_njit
+
+
+# ========================================================================= #
+# Pretend We Are Walking Ground-Truth Factors Randomly #
+# ========================================================================= #
+
+
+class GroundTruthRandomWalkSampler(BaseDisentSampler):
+
+ def uninit_copy(self) -> 'GroundTruthRandomWalkSampler':
+ return GroundTruthRandomWalkSampler(
+ num_samples=self._num_samples,
+ p_dist_max=self._p_dist_max,
+ n_dist_max=self._n_dist_max,
+ )
+
+ def __init__(
+ self,
+ num_samples: int = 3,
+ p_dist_max: int = 8,
+ n_dist_max: int = 32,
+ ):
+ super().__init__(num_samples=num_samples)
+ # checks
+ assert num_samples in {1, 2, 3}, f'num_samples ({repr(num_samples)}) must be 1, 2 or 3'
+ # save hparams
+ self._num_samples = num_samples
+ self._p_dist_max = p_dist_max
+ self._n_dist_max = n_dist_max
+ # dataset variable
+ self._state_space: Optional[StateSpace] = None
+
+ def _init(self, dataset: GroundTruthData):
+ assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}'
+ self._state_space = dataset.state_space_copy()
+
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+ # Sampling #
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+ def _sample_idx(self, idx) -> Tuple[int, ...]:
+ if self._num_samples == 1:
+ return (idx,)
+ elif self._num_samples == 2:
+ p_dist = np.random.randint(1, self._p_dist_max + 1)
+ pos = _random_walk(idx, p_dist, self._state_space.factor_sizes)
+ return (idx, pos)
+ elif self._num_samples == 3:
+ p_dist = np.random.randint(1, self._p_dist_max + 1)
+ n_dist = np.random.randint(1, self._n_dist_max + 1)
+ pos = _random_walk(idx, p_dist, self._state_space.factor_sizes)
+ neg = _random_walk(pos, n_dist, self._state_space.factor_sizes)
+ return (idx, pos, neg)
+ else:
+ raise RuntimeError
+
+
+# ========================================================================= #
+# Helper #
+# ========================================================================= #
+
+
+def _random_walk(idx: int, dist: int, factor_sizes: np.ndarray) -> int:
+ # random walk
+ pos = np.array(np.unravel_index(idx, factor_sizes), dtype=int) # much faster than StateSpace.idx_to_pos, we don't need checks!
+ for _ in range(dist):
+ _walk_nearby_inplace(pos, factor_sizes)
+ idx = np.ravel_multi_index(pos, factor_sizes) # much faster than StateSpace.pos_to_idx, we don't need checks!
+ # done!
+ return int(idx)
+
+
+@try_njit()
+def _walk_nearby_inplace(pos: np.ndarray, factor_sizes: Sequence[int]) -> NoReturn:
+ # try to shift any single factor by 1 or -1
+ while True:
+ f_idx = np.random.randint(0, len(factor_sizes))
+ cur = pos[f_idx]
+ # walk random factor value
+ if np.random.random() < 0.5:
+ nxt = max(cur - 1, 0)
+ else:
+ nxt = min(cur + 1, factor_sizes[f_idx] - 1)
+ # exit if different
+ if cur != nxt:
+ break
+ # update the position
+ pos[f_idx] = nxt
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/dataset/transform/_augment.py b/disent/dataset/transform/_augment.py
index 69082a1d..7cfb9748 100644
--- a/disent/dataset/transform/_augment.py
+++ b/disent/dataset/transform/_augment.py
@@ -24,6 +24,7 @@
import os
import re
+import warnings
from numbers import Number
from typing import List
from typing import Tuple
@@ -32,12 +33,13 @@
import numpy as np
import torch
-import disent
from disent.nn.modules import DisentModule
from disent.nn.functional import torch_box_kernel_2d
from disent.nn.functional import torch_conv2d_channel_wise_fft
from disent.nn.functional import torch_gaussian_kernel_2d
+import disent.registry as R
+
# ========================================================================= #
# Transforms #
@@ -178,16 +180,22 @@ def _make_kernel(self, shape, device):
# ========================================================================= #
+_NO_ARG = object()
+
+
class FftKernel(DisentModule):
"""
2D Convolve an image
"""
- def __init__(self, kernel: Union[torch.Tensor, str], normalize: bool = True):
+ def __init__(self, kernel: Union[torch.Tensor, str], normalize_mode: str = _NO_ARG):
super().__init__()
+ # deprecation error
+ if normalize_mode is _NO_ARG:
+ raise ValueError(f'default argument for normalize_mode was "sum", this has been deprecated and will change to "none" in future. Please manually override this value!')
# load & save the kernel -- no gradients allowed
self._kernel: torch.Tensor
- self.register_buffer('_kernel', get_kernel(kernel, normalize=normalize), persistent=True)
+ self.register_buffer('_kernel', get_kernel(kernel, normalize_mode=normalize_mode), persistent=True)
self._kernel.requires_grad = False
def forward(self, obs):
@@ -209,11 +217,30 @@ def forward(self, obs):
# ========================================================================= #
-def _normalise_kernel(kernel: torch.Tensor, normalize: bool) -> torch.Tensor:
- if normalize:
- with torch.no_grad():
- return kernel / kernel.sum()
- return kernel
+
+@torch.no_grad()
+def _scale_kernel(kernel: torch.Tensor, mode: Union[bool, str] = 'abssum'):
+ # old normalize mode
+ if isinstance(mode, bool):
+ raise ValueError(f'boolean arguments to `scale_kernel` are deprecated, convert True to "sum" and False to "none", got: {repr(mode)}')
+ # handle the normalize mode
+ if mode == 'sum':
+ return kernel / kernel.sum()
+ elif mode == 'abssum':
+ return kernel / torch.abs(kernel).sum()
+ elif mode == 'possum':
+ return kernel / torch.abs(kernel)[kernel > 0].sum()
+ elif mode == 'negsum':
+ return kernel / torch.abs(kernel)[kernel < 0].sum()
+ elif mode == 'maxsum':
+ return kernel / torch.maximum(
+ torch.abs(kernel)[kernel > 0].sum(),
+ torch.abs(kernel)[kernel < 0].sum(),
+ )
+ elif mode == 'none':
+ return kernel
+ else:
+ raise KeyError(f'invalid scale mode: {repr(mode)}')
def _check_kernel(kernel: torch.Tensor) -> torch.Tensor:
@@ -227,39 +254,10 @@ def _check_kernel(kernel: torch.Tensor) -> torch.Tensor:
return kernel
-_KERNELS = {
- # kernels that do not require arguments, just general factory functions
- # name: class/fn -- with no required args
-}
-
-
-_ARG_KERNELS = [
- # (REGEX, EXAMPLE, FACTORY_FUNC)
- # - factory function takes at min one arg: fn(reduction) with one arg after that per regex capture group
- # - regex expressions are tested in order, expressions should be mutually exclusive or ordered such that more specialized versions occur first.
- (re.compile(r'^(box)_r(\d+)$'), 'box_r31', lambda kern, radius: torch_box_kernel_2d(radius=int(radius))[None, ...]),
- (re.compile(r'^(gau)_r(\d+)$'), 'gau_r31', lambda kern, radius: torch_gaussian_kernel_2d(sigma=int(radius) / 4.0, truncate=4.0)[None, None, ...]),
-]
-
-
# NOTE: this function compliments make_reconstruction_loss in frameworks/helper/reconstructions.py
-def _make_kernel(name: str) -> torch.Tensor:
- if name in _KERNELS:
- # search normal losses!
- return _KERNELS[name]()
- else:
- # regex search kernels, and call with args!
- for r, _, fn in _ARG_KERNELS:
- result = r.search(name)
- if result is not None:
- return fn(*result.groups())
- # we couldn't find anything
- raise KeyError(f'Invalid kernel name: {repr(name)} Examples of argument based kernels include: {[example for _, example, _ in _ARG_KERNELS]}')
-
-
-def make_kernel(name: str, normalize: bool = False):
- kernel = _make_kernel(name)
- kernel = _normalise_kernel(kernel, normalize=normalize)
+def make_kernel(name: str, normalize_mode: str = 'none'):
+ kernel = R.KERNELS[name]
+ kernel = _scale_kernel(kernel, mode=normalize_mode)
kernel = _check_kernel(kernel)
return kernel
@@ -267,21 +265,36 @@ def make_kernel(name: str, normalize: bool = False):
def _get_kernel(name_or_path: str) -> torch.Tensor:
if '/' not in name_or_path:
try:
- return _make_kernel(name_or_path)
+ return R.KERNELS[name_or_path]
except KeyError:
pass
if os.path.isfile(name_or_path):
return torch.load(name_or_path)
- raise KeyError(f'Invalid kernel path or name: {repr(name_or_path)} Examples of argument based kernels include: {[example for _, example, _ in _ARG_KERNELS]}, otherwise specify a valid path to a kernel file save with torch.')
+ raise KeyError(f'Invalid kernel path or name: {repr(name_or_path)} Examples of argument based kernels include: {R.KERNELS.regex_examples}, otherwise specify a valid path to a kernel file save with torch.')
-def get_kernel(kernel: Union[str, torch.Tensor], normalize: bool = False):
+def get_kernel(kernel: Union[str, torch.Tensor], normalize_mode: str = 'none'):
kernel = _get_kernel(kernel) if isinstance(kernel, str) else torch.clone(kernel)
- kernel = _normalise_kernel(kernel, normalize=normalize)
+ kernel = _scale_kernel(kernel, mode=normalize_mode)
kernel = _check_kernel(kernel)
return kernel
+# ========================================================================= #
+# Registered Kernels #
+# ========================================================================= #
+
+
+# we register this in disent.registry
+def _make_box_kernel(radius: str):
+ return torch_box_kernel_2d(radius=int(radius))[None, ...]
+
+
+# we register this in disent.registry
+def _make_gaussian_kernel(radius: str):
+ return torch_gaussian_kernel_2d(sigma=int(radius) / 4.0, truncate=4.0)[None, None, ...]
+
+
# ========================================================================= #
# END #
# ========================================================================= #
diff --git a/disent/dataset/transform/_augment_disent.py b/disent/dataset/transform/_augment_disent.py
index 5cc305f0..0ed03100 100644
--- a/disent/dataset/transform/_augment_disent.py
+++ b/disent/dataset/transform/_augment_disent.py
@@ -22,6 +22,10 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+from typing import Callable
+from typing import Optional
+import torch
+
# ========================================================================= #
# Augment #
@@ -34,7 +38,11 @@ class DisentDatasetTransform(object):
datasets from: disent.dataset.groundtruth
"""
- def __init__(self, transform=None, transform_targ=None):
+ def __init__(
+ self,
+ transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ transform_targ: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ ):
self.transform = transform
self.transform_targ = transform_targ
diff --git a/disent/dataset/transform/functional.py b/disent/dataset/transform/functional.py
index e3ebf447..6bc24ffa 100644
--- a/disent/dataset/transform/functional.py
+++ b/disent/dataset/transform/functional.py
@@ -88,6 +88,12 @@ def check_tensor(
# ========================================================================= #
+def _is_size_different(obs: Obs, size: SizeType):
+ h, w = (size, size) if isinstance(size, int) else size
+ H, W = (obs.height, obs.width) if isinstance(obs, Image) else obs.shape[:2]
+ return (H != h) or (W != w)
+
+
def to_img_tensor_u8(
obs: Obs,
size: Optional[SizeType] = None,
@@ -101,7 +107,7 @@ def to_img_tensor_u8(
3. move channels to first dim (H, W, C) -> (C, H, W)
"""
# resize image
- if size is not None:
+ if (size is not None) and _is_size_different(obs, size):
if not isinstance(obs, Image):
obs = F_tv.to_pil_image(obs)
obs = F_tv.resize(obs, size=size)
@@ -139,11 +145,27 @@ def to_img_tensor_f32(
5. normalize using mean and std, values might thus be outside of the range [0, 1]
"""
# resize image
- if size is not None:
+ if (size is not None) and _is_size_different(obs, size):
if not isinstance(obs, Image):
obs = F_tv.to_pil_image(obs)
obs = F_tv.resize(obs, size=size)
# transform to tensor, add missing dims & move channel dim to front
+ # TODO: this should be replaced with custom logic, this is quite slow...
+ # - benchmarks show that doing conversions as numpy first, and then using torch.from_numpy is faster!
+ # - eg. SmallNorb64Data
+ # `to_tensor(item) # 27547.81it/s
+ # `torch.from_numpy(np.transpose(item.astype('float32') / 255, [2, 0, 1])) # 53987.90it/s
+ # `torch.from_numpy((item.astype('float32') / 255).transpose([2, 0, 1])) # 66544.00it/s
+ # `torch.from_numpy(item.astype('float32').transpose([2, 0, 1]) / 255) # 66511.62it/s
+ # `torch.from_numpy(item.transpose([2, 0, 1]).astype('float32') / 255) # 66133.27it/s
+ # - eg. Cars3d64Data
+ # `to_tensor(item) # 13810.46it/s
+ # `torch.from_numpy(np.transpose(item.astype('float32') / 255, [2, 0, 1])) # 32258.03it/s
+ # `torch.from_numpy((item.astype('float32') / 255).transpose([2, 0, 1])) # 37861.46it/s
+ # `torch.from_numpy(item.astype('float32').transpose([2, 0, 1]) / 255) # 33034.21it/s
+ # `torch.from_numpy(item.transpose([2, 0, 1]).astype('float32') / 255) # 32883.32it/s
+ # - INVESTIGATE: if transpose is used, and then from_numpy is called, that references the original memory? It
+ # might then be slower to convolve this data? Speed benefits could be negated? A copy might be better?
obs = F_tv.to_tensor(obs)
# checks
assert obs.ndim == 3, f'obs has does not have 3 dimensions, got: {obs.ndim} for shape: {obs.shape}'
diff --git a/disent/dataset/util/datafile.py b/disent/dataset/util/datafile.py
index 7e00b892..3c8430b5 100644
--- a/disent/dataset/util/datafile.py
+++ b/disent/dataset/util/datafile.py
@@ -27,6 +27,7 @@
from typing import Callable
from typing import Dict
from typing import final
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -95,7 +96,7 @@ def wrapped(out_file):
self._prepare(out_dir=out_dir, out_file=out_file)
return wrapped()
- def _prepare(self, out_dir: str, out_file: str) -> str:
+ def _prepare(self, out_dir: str, out_file: str) -> NoReturn:
# TODO: maybe raise a FileNotFoundError or a HashError instead?
raise NotImplementedError
diff --git a/disent/dataset/util/hdf5.py b/disent/dataset/util/hdf5.py
index eec1058b..0d4f17fe 100644
--- a/disent/dataset/util/hdf5.py
+++ b/disent/dataset/util/hdf5.py
@@ -385,6 +385,7 @@ def add_dataset_from_gt_data(
from disent.dataset import DisentDataset
from disent.dataset.data import GroundTruthData
# get dataset
+ # TODO: we should not automatically handle this extraction... The transform could be missing on the dataset?
if isinstance(data, DisentDataset): gt_data = data.gt_data
elif isinstance(data, GroundTruthData): gt_data = data
else: raise TypeError(f'invalid data type: {type(data)}, must be {DisentDataset} or {GroundTruthData}')
diff --git a/disent/dataset/util/npz.py b/disent/dataset/util/npz.py
new file mode 100644
index 00000000..c7e83995
--- /dev/null
+++ b/disent/dataset/util/npz.py
@@ -0,0 +1,77 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import numpy as np
+from tqdm import tqdm
+from disent.util.inout.files import AtomicSaveFile
+
+
+# ========================================================================= #
+# Save Numpy Files #
+# ========================================================================= #
+
+
+def save_dataset_array(array: np.ndarray, out_file: str, overwrite: bool = False, save_key: str = 'images'):
+ assert array.ndim == 4, f'invalid array shape, got: {array.shape}, must be: (N, H, W, C)'
+ assert array.dtype == 'uint8', f'invalid array dtype, got: {array.dtype}, must be: "uint8"'
+ # save the data
+ with AtomicSaveFile(out_file, overwrite=overwrite) as temp_file:
+ np.savez_compressed(temp_file, **{save_key: array})
+
+
+def save_resized_dataset_array(array: np.ndarray, out_file: str, size: int = 64, overwrite: bool = False, save_key: str = 'images', progress: bool = True):
+ import torchvision.transforms.functional as F_tv
+ from disent.dataset.data import ArrayDataset
+ # checks
+ assert out_file.endswith('.npz'), f'The output file must end with the extension: ".npz", got: {repr(out_file)}'
+ # Get the transform -- copied from: ToImgTensorF32 / ToImgTensorU8
+ def transform(obs):
+ H, W, C = obs.shape
+ obs = F_tv.to_pil_image(obs)
+ obs = F_tv.resize(obs, size=[size, size])
+ obs = np.array(obs)
+ # add removed dimension!
+ if obs.ndim == 2:
+ obs = obs[:, :, None]
+ assert obs.shape == (size, size, C)
+ return obs
+ # load the converted cars3d data ?x128x128x3
+ assert array.ndim == 4, f'invalid array shape, got: {array.shape}, must be: (N, H, W, C)'
+ assert array.dtype == 'uint8', f'invalid array dtype, got: {array.dtype}, must be: "uint8"'
+ N, H, W, C = array.shape
+ data = ArrayDataset(array, transform=transform)
+ # save the data
+ with AtomicSaveFile(out_file, overwrite=overwrite) as temp_file:
+ # resize the cars3d data
+ idxs = tqdm(range(N), desc = 'converting') if progress else range(N)
+ converted = np.zeros([N, size, size, C], dtype='uint8')
+ for i in idxs:
+ converted[i, ...] = data[i]
+ # save the data
+ np.savez_compressed(temp_file, **{save_key: converted})
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/dataset/util/state_space.py b/disent/dataset/util/state_space.py
index 5c72168f..e68d0d90 100644
--- a/disent/dataset/util/state_space.py
+++ b/disent/dataset/util/state_space.py
@@ -54,9 +54,12 @@ class StateSpace(LengthIter):
def __init__(self, factor_sizes: Sequence[int], factor_names: Optional[Sequence[str]] = None):
super().__init__()
- # dimension
+ # dimension: [read only]
self.__factor_sizes = np.array(factor_sizes)
self.__factor_sizes.flags.writeable = False
+ # multipliers: [read only]
+ self.__factor_multipliers = _dims_multipliers(self.__factor_sizes)
+ self.__factor_multipliers.flags.writeable = False
# total permutations
self.__size = int(np.prod(factor_sizes))
# factor names
@@ -96,6 +99,21 @@ def factor_names(self) -> Tuple[str, ...]:
"""A list of names of factors handled by this state space"""
return self.__factor_names
+ @property
+ def factor_multipliers(self) -> np.ndarray:
+ """
+ The cumulative product of the factor_sizes used to convert indices to positions, and positions to indices.
+ - The highest values is at the front, the lowest is at the end always being 1.
+ - The size of this vector is: num_factors + 1
+
+ Formulas:
+ * Use broadcasting to get positions:
+ pos = (idx[..., None] % muls[:-1]) // muls[1:]
+ * Use broadcasting to get indices
+ idx = np.sum(pos * muls[1:], axis=-1)
+ """
+ return self.__factor_multipliers
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Factor Helpers #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
@@ -137,6 +155,8 @@ def pos_to_idx(self, positions) -> np.ndarray:
Convert a position to an index (or convert a list of positions to a list of indices)
- positions are lists of integers, with each element < their corresponding factor size
- indices are integers < size
+
+ TODO: can factor_multipliers be used to speed this up?
"""
positions = np.moveaxis(positions, source=-1, destination=0)
return np.ravel_multi_index(positions, self.__factor_sizes)
@@ -146,6 +166,8 @@ def idx_to_pos(self, indices) -> np.ndarray:
Convert an index to a position (or convert a list of indices to a list of positions)
- indices are integers < size
- positions are lists of integers, with each element < their corresponding factor size
+
+ TODO: can factor_multipliers be used to speed this up?
"""
positions = np.array(np.unravel_index(indices, self.__factor_sizes))
return np.moveaxis(positions, source=0, destination=-1)
@@ -249,14 +271,14 @@ def _get_f_idx_and_factors_and_size(self, f_idx: int = None, base_factors=None,
# return everything
return f_idx, base_factors, num
- def sample_random_factor_traversal(self, f_idx: int = None, base_factors=None, num: int = None, mode='interval') -> np.ndarray:
+ def sample_random_factor_traversal(self, f_idx: int = None, base_factors=None, num: int = None, mode: str = 'interval', start_index: int = 0) -> np.ndarray:
"""
Sample a single random factor traversal along the
given factor index, starting from some random base sample.
"""
f_idx, base_factors, num = self._get_f_idx_and_factors_and_size(f_idx=f_idx, base_factors=base_factors, num=num)
# generate traversal
- base_factors[:, f_idx] = get_idx_traversal(self.factor_sizes[f_idx], num_frames=num, mode=mode)
+ base_factors[:, f_idx] = get_idx_traversal(self.factor_sizes[f_idx], num_frames=num, mode=mode, start_index=start_index)
# return factors (num_frames, num_factors)
return base_factors
@@ -267,7 +289,7 @@ def sample_random_factor_traversal(self, f_idx: int = None, base_factors=None, n
@lru_cache()
-def _get_step_size(factor_sizes, f_idx):
+def _get_step_size(factor_sizes, f_idx: int):
# check values
assert f_idx >= 0
assert f_idx < len(factor_sizes)
@@ -279,6 +301,20 @@ def _get_step_size(factor_sizes, f_idx):
return int(np.ravel_multi_index(pos, factor_sizes))
+def _dims_multipliers(factor_sizes: np.ndarray) -> np.ndarray:
+ factor_sizes = np.array(factor_sizes)
+ assert factor_sizes.ndim == 1
+ return np.append(np.cumprod(factor_sizes[::-1])[::-1], 1)
+
+
+# @try_njit
+# def _idx_to_pos(idxs, dims_mul):
+# factors = np.expand_dims(np.array(idxs, dtype='int'), axis=-1)
+# factors = factors % dims_mul[:-1]
+# factors //= dims_mul[1:]
+# return factors
+
+
# ========================================================================= #
# Hidden State Space #
# ========================================================================= #
diff --git a/disent/dataset/util/stats.py b/disent/dataset/util/stats.py
index 08be5c04..0ba2985f 100644
--- a/disent/dataset/util/stats.py
+++ b/disent/dataset/util/stats.py
@@ -29,8 +29,6 @@
import torch
from torch.utils.data import DataLoader
-from disent.util.function import wrapped_partial
-
# ========================================================================= #
# COMPUTE DATASET STATS #
@@ -86,15 +84,17 @@ def compute_data_mean_std(
if __name__ == '__main__':
- def main(progress=False):
+ def main(progress=True, num_workers=0, batch_size=256): # try changing workers to zero on MacOS
from disent.dataset import data
from disent.dataset.transform import ToImgTensorF32
for data_cls in [
# groundtruth -- impl
data.Cars3dData,
- data.DSpritesData,
+ data.Cars3d64Data,
data.SmallNorbData,
+ data.SmallNorb64Data,
+ data.DSpritesData,
data.Shapes3dData,
# groundtruth -- impl synthetic
data.XYObjectData,
@@ -113,7 +113,7 @@ def main(progress=False):
# Most common standardized way of computing the mean and std over observations
# resized to 64px in size of dtype float32 in the range [0, 1].
data = data_cls(transform=ToImgTensorF32(size=64), **kwargs)
- mean, std = compute_data_mean_std(data, progress=progress)
+ mean, std = compute_data_mean_std(data, progress=progress, num_workers=num_workers, batch_size=batch_size)
# results!
print(f'{data.__class__.__name__} - {data.name} - {kwargs}:\n mean: {mean.tolist()}\n std: {std.tolist()}')
diff --git a/disent/frameworks/_framework.py b/disent/frameworks/_framework.py
index 3b51557d..b1d4eeb3 100644
--- a/disent/frameworks/_framework.py
+++ b/disent/frameworks/_framework.py
@@ -22,6 +22,7 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+import logging
from dataclasses import asdict
from dataclasses import dataclass
from dataclasses import fields
@@ -32,15 +33,14 @@
from typing import final
from typing import Optional
from typing import Tuple
-from typing import Type
from typing import Union
-import logging
import torch
from disent import registry
-from disent.schedule import Schedule
from disent.nn.modules import DisentLightningModule
+from disent.schedule import Schedule
+from disent.util.imports import import_obj
log = logging.getLogger(__name__)
@@ -83,7 +83,7 @@ class DisentFramework(DisentConfigurable, DisentLightningModule):
@dataclass
class cfg(DisentConfigurable.cfg):
# optimizer config
- optimizer: Union[str, Type[torch.optim.Optimizer]] = 'adam'
+ optimizer: Union[str] = 'adam' # name in the registry, eg. `adam` OR the path to an optimizer eg. `torch.optim.Adam`
optimizer_kwargs: Optional[Dict[str, Union[str, float, int]]] = None
def __init__(
@@ -94,32 +94,54 @@ def __init__(
):
# save the config values to the class
super().__init__(cfg=cfg)
- # get the optimizer
- if isinstance(self.cfg.optimizer, str):
- if self.cfg.optimizer not in registry.OPTIMIZERS:
- raise KeyError(f'invalid optimizer: {repr(self.cfg.optimizer)}, valid optimizers are: {sorted(registry.OPTIMIZER)}, otherwise pass a torch.optim.Optimizer class instead.')
- self.cfg.optimizer = registry.OPTIMIZERS[self.cfg.optimizer]
- # check the optimizer values
- assert callable(self.cfg.optimizer)
- assert isinstance(self.cfg.optimizer_kwargs, dict) or (self.cfg.optimizer_kwargs is None), f'invalid optimizer_kwargs type, got: {type(self.cfg.optimizer_kwargs)}'
- # set default values for optimizer
- if self.cfg.optimizer_kwargs is None:
- self.cfg.optimizer_kwargs = dict()
- if 'lr' not in self.cfg.optimizer_kwargs:
- self.cfg.optimizer_kwargs['lr'] = 1e-3
- log.info('lr not specified in `optimizer_kwargs`, setting to default value of `1e-3`')
+ # check the optimizer
+ self.cfg.optimizer = self._check_optimizer(self.cfg.optimizer)
+ self.cfg.optimizer_kwargs = self._check_optimizer_kwargs(self.cfg.optimizer_kwargs)
# batch augmentations may not be implemented as dataset
# transforms so we can apply these on the GPU instead
- assert callable(batch_augment) or (batch_augment is None)
+ assert callable(batch_augment) or (batch_augment is None), f'invalid batch_augment: {repr(batch_augment)}, must be callable or `None`'
self._batch_augment = batch_augment
# schedules
# - maybe add support for schedules in the config?
self._registered_schedules = set()
self._active_schedules: Dict[str, Tuple[Any, Schedule]] = {}
+ @staticmethod
+ def _check_optimizer(optimizer: str):
+ if not isinstance(optimizer, str):
+ raise TypeError(f'invalid optimizer: {repr(optimizer)}, must be a `str`')
+ # check that the optimizer has been registered
+ # otherwise check that the optimizer class can be imported instead
+ if optimizer not in registry.OPTIMIZERS:
+ try:
+ import_obj(optimizer)
+ except ImportError:
+ raise KeyError(f'invalid optimizer: {repr(optimizer)}, valid optimizers are: {sorted(registry.OPTIMIZERS)}, or an import path to an optimizer, eg. `torch.optim.Adam`')
+ # return the updated values!
+ return optimizer
+
+ @staticmethod
+ def _check_optimizer_kwargs(optimizer_kwargs: Optional[dict]):
+ # check the optimizer kwargs
+ assert isinstance(optimizer_kwargs, dict) or (optimizer_kwargs is None), f'invalid optimizer_kwargs type, got: {type(optimizer_kwargs)}'
+ # get default kwargs OR copy
+ optimizer_kwargs = dict() if (optimizer_kwargs is None) else dict(optimizer_kwargs)
+ # set default values
+ if 'lr' not in optimizer_kwargs:
+ optimizer_kwargs['lr'] = 1e-3
+ log.info('lr not specified in `optimizer_kwargs`, setting to default value of `1e-3`')
+ # return the updated values
+ return optimizer_kwargs
+
@final
def configure_optimizers(self):
- optimizer_cls = self.cfg.optimizer
+ # get the optimizer
+ # 1. first check if the name has been registered
+ # 2. then check if the name can be imported
+ if self.cfg.optimizer in registry.OPTIMIZERS:
+ optimizer_cls = registry.OPTIMIZERS[self.cfg.optimizer]
+ else:
+ optimizer_cls = import_obj(self.cfg.optimizer)
# check that we can call the optimizer
if not callable(optimizer_cls):
raise TypeError(f'unsupported optimizer type: {type(optimizer_cls)}')
@@ -133,31 +155,24 @@ def configure_optimizers(self):
@final
def _compute_loss_step(self, batch, batch_idx, update_schedules: bool):
- # augment batch with GPU support
- if self._batch_augment is not None:
- batch = self._batch_augment(batch)
- # update the config values based on registered schedules
- if update_schedules:
- # TODO: how do we handle this in the case of the validation and test step? I think this
- # might still give the wrong results as this is based on the trainer.global_step which
- # may be incremented by these steps.
- self._update_config_from_schedules()
- # compute loss
- loss, logs_dict = self.do_training_step(batch, batch_idx)
- # check returned values
- assert 'loss' not in logs_dict
- self._assert_valid_loss(loss)
- # log returned values
- logs_dict['loss'] = loss
- self.log_dict(logs_dict)
- # return loss
- return loss
-
- @final
- def training_step(self, batch, batch_idx):
- """This is a pytorch-lightning function that should return the computed loss"""
try:
- return self._compute_loss_step(batch, batch_idx, update_schedules=True)
+ # augment batch with GPU support
+ if self._batch_augment is not None:
+ batch = self._batch_augment(batch)
+ # update the config values based on registered schedules
+ if update_schedules:
+ # TODO: how do we handle this in the case of the validation and test step? I think this
+ # might still give the wrong results as this is based on the trainer.global_step which
+ # may be incremented by these steps.
+ self._update_config_from_schedules()
+ # compute loss
+ # TODO: move logging into child frameworks?
+ loss = self.do_training_step(batch, batch_idx)
+ # check loss values
+ self._assert_valid_loss(loss)
+ self.log('loss', float(loss), prog_bar=True)
+ # return loss
+ return loss
except Exception as e: # pragma: no cover
# call in all the child processes for the best chance of clearing this...
# remove callbacks from trainer so we aren't stuck running forever!
@@ -180,11 +195,27 @@ def test_step(self, batch, batch_idx):
"""
return self._compute_loss_step(batch, batch_idx, update_schedules=False)
+ @final
+ def training_step(self, batch, batch_idx):
+ """This is a pytorch-lightning function that should return the computed loss"""
+ return self._compute_loss_step(batch, batch_idx, update_schedules=True)
+
+ def validation_step(self, batch, batch_idx):
+ """
+ TODO: how do we handle the schedule in this case?
+ """
+ return self._compute_loss_step(batch, batch_idx, update_schedules=False)
+
+ def test_step(self, batch, batch_idx):
+ """
+ TODO: how do we handle the schedule in this case?
+ """
+ return self._compute_loss_step(batch, batch_idx, update_schedules=False)
+
@final
def _assert_valid_loss(self, loss):
- if self.trainer.terminate_on_nan:
- if torch.isnan(loss) or torch.isinf(loss):
- raise ValueError('The returned loss is nan or inf')
+ if torch.isnan(loss) or torch.isinf(loss):
+ raise ValueError('The returned loss is nan or inf')
if loss > 1e+20:
raise ValueError(f'The returned loss: {loss:.2e} is out of bounds: > {1e+20:.0e}')
@@ -192,7 +223,7 @@ def forward(self, batch) -> torch.Tensor: # pragma: no cover
"""this function should return the single final output of the model, including the final activation"""
raise NotImplementedError
- def do_training_step(self, batch, batch_idx) -> Tuple[torch.Tensor, Dict[str, Union[Number, torch.Tensor]]]: # pragma: no cover
+ def do_training_step(self, batch, batch_idx) -> torch.Tensor: # pragma: no cover
"""
should return a dictionary of items to log with the key 'train_loss'
as the variable to minimize
diff --git a/disent/frameworks/ae/_ae_mixin.py b/disent/frameworks/ae/_ae_mixin.py
new file mode 100644
index 00000000..7b1584d0
--- /dev/null
+++ b/disent/frameworks/ae/_ae_mixin.py
@@ -0,0 +1,147 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2021 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import logging
+from dataclasses import dataclass
+from typing import Dict
+from typing import final
+from typing import Tuple
+
+import torch
+
+from disent.frameworks import DisentFramework
+from disent.frameworks.helper.reconstructions import make_reconstruction_loss
+from disent.frameworks.helper.reconstructions import ReconLossHandler
+from disent.model import AutoEncoder
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# framework_vae #
+# ========================================================================= #
+
+
+class _AeAndVaeMixin(DisentFramework):
+ """
+ Base class containing common logic for both the Ae and Vae classes.
+ -- This is private because handling of both classes needs to be conducted differently.
+ An instance of a Vae should not be an instance of an Ae and vice-versa
+ """
+
+ @dataclass
+ class cfg(DisentFramework.cfg):
+ recon_loss: str = 'mse'
+ # multiple reduction modes exist for the various loss components.
+ # - 'sum': sum over the entire batch
+ # - 'mean': mean over the entire batch
+ # - 'mean_sum': sum each observation, returning the mean sum over the batch
+ loss_reduction: str = 'mean'
+ # disable various components
+ detach_decoder: bool = False
+ disable_rec_loss: bool = False
+ disable_aug_loss: bool = False
+
+ # --------------------------------------------------------------------- #
+ # AE/VAE Attributes #
+ # --------------------------------------------------------------------- #
+
+ @property
+ def REQUIRED_Z_MULTIPLIER(self) -> int:
+ raise NotImplementedError
+
+ @property
+ def REQUIRED_OBS(self) -> int:
+ raise NotImplementedError
+
+ @final
+ @property
+ def recon_handler(self) -> ReconLossHandler:
+ return self.__recon_handler
+
+ # --------------------------------------------------------------------- #
+ # AE/VAE Init #
+ # --------------------------------------------------------------------- #
+
+ # attributes provided by this class and initialised in _init_ae_mixin
+ _model: AutoEncoder
+ __recon_handler: ReconLossHandler
+
+ def _init_ae_mixin(self, model: AutoEncoder):
+ # vae model
+ self._model = model
+ # check the model
+ assert isinstance(self._model, AutoEncoder), f'model must be an instance of {AutoEncoder.__name__}, got: {type(model)}'
+ assert self._model.z_multiplier == self.REQUIRED_Z_MULTIPLIER, f'model z_multiplier is {repr(self._model.z_multiplier)} but {self.__class__.__name__} requires that it is: {repr(self.REQUIRED_Z_MULTIPLIER)}'
+ # recon loss & activation fn
+ self.__recon_handler: ReconLossHandler = make_reconstruction_loss(self.cfg.recon_loss, reduction=self.cfg.loss_reduction)
+
+ # --------------------------------------------------------------------- #
+ # AE/VAE Training Step Helper #
+ # --------------------------------------------------------------------- #
+
+ @final
+ def _get_xs_and_targs(self, batch: Dict[str, Tuple[torch.Tensor, ...]], batch_idx) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
+ xs_targ = batch['x_targ']
+ if 'x' not in batch:
+ # TODO: re-enable this warning but only ever print once!
+ # warnings.warn('dataset does not have input: x -> x_targ using target as input: x_targ -> x_targ')
+ xs = xs_targ
+ else:
+ xs = batch['x']
+ # check that we have the correct number of inputs
+ if (len(xs) != self.REQUIRED_OBS) or (len(xs_targ) != self.REQUIRED_OBS):
+ log.warning(f'batch len(xs)={len(xs)} and len(xs_targ)={len(xs_targ)} observation count mismatch, requires: {self.REQUIRED_OBS}')
+ # done
+ return xs, xs_targ
+
+ # --------------------------------------------------------------------- #
+ # AE/VAE Model Utility Functions (Visualisation) #
+ # --------------------------------------------------------------------- #
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """Get the deterministic latent representation (useful for visualisation)"""
+ raise NotImplementedError
+
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
+ """Decode latent vector z into reconstruction x_recon (useful for visualisation)"""
+ raise NotImplementedError
+
+ def forward(self, batch: torch.Tensor) -> torch.Tensor:
+ """Feed through the full deterministic model (useful for visualisation)"""
+ raise NotImplementedError
+
+ # --------------------------------------------------------------------- #
+ # AE/VAE Model Utility Functions (Training) #
+ # --------------------------------------------------------------------- #
+
+ def decode_partial(self, z: torch.Tensor) -> torch.Tensor:
+ """Decode latent vector z into partial reconstructions that exclude the final activation if there is one."""
+ raise NotImplementedError
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py
index 8e7d2541..f8f4a204 100644
--- a/disent/frameworks/ae/_unsupervised__ae.py
+++ b/disent/frameworks/ae/_unsupervised__ae.py
@@ -23,7 +23,6 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
import logging
-import warnings
from dataclasses import dataclass
from numbers import Number
from typing import Any
@@ -35,23 +34,18 @@
import torch
-from disent.frameworks import DisentFramework
-from disent.frameworks.helper.reconstructions import make_reconstruction_loss
-from disent.frameworks.helper.reconstructions import ReconLossHandler
+from disent.frameworks.ae._ae_mixin import _AeAndVaeMixin
from disent.frameworks.helper.util import detach_all
from disent.model import AutoEncoder
from disent.util.iters import map_all
-log = logging.getLogger(__name__)
-
-
# ========================================================================= #
# framework_vae #
# ========================================================================= #
-class Ae(DisentFramework):
+class Ae(_AeAndVaeMixin):
"""
Basic Auto Encoder
------------------
@@ -73,56 +67,23 @@ class Ae(DisentFramework):
* `compute_ave_recon_loss`
"""
+ # override
REQUIRED_Z_MULTIPLIER = 1
REQUIRED_OBS = 1
@dataclass
- class cfg(DisentFramework.cfg):
- recon_loss: str = 'mse'
- # multiple reduction modes exist for the various loss components.
- # - 'sum': sum over the entire batch
- # - 'mean': mean over the entire batch
- # - 'mean_sum': sum each observation, returning the mean sum over the batch
- loss_reduction: str = 'mean'
- # disable various components
- disable_decoder: bool = False
- disable_rec_loss: bool = False
- disable_aug_loss: bool = False
+ class cfg(_AeAndVaeMixin.cfg):
+ pass
def __init__(self, model: AutoEncoder, cfg: cfg = None, batch_augment=None):
super().__init__(cfg=cfg, batch_augment=batch_augment)
- # vae model
- self._model = model
- # check the model
- assert isinstance(self._model, AutoEncoder)
- assert self._model.z_multiplier == self.REQUIRED_Z_MULTIPLIER, f'model z_multiplier is {repr(self._model.z_multiplier)} but {self.__class__.__name__} requires that it is: {repr(self.REQUIRED_Z_MULTIPLIER)}'
- # recon loss & activation fn
- self.__recon_handler: ReconLossHandler = make_reconstruction_loss(self.cfg.recon_loss, reduction=self.cfg.loss_reduction)
-
- @final
- @property
- def recon_handler(self) -> ReconLossHandler:
- return self.__recon_handler
+ # initialise the auto-encoder mixin (recon handler, model, enc, dec, etc.)
+ self._init_ae_mixin(model=model)
# --------------------------------------------------------------------- #
# AE Training Step -- Overridable #
# --------------------------------------------------------------------- #
- @final
- def _get_xs_and_targs(self, batch: Dict[str, Tuple[torch.Tensor, ...]], batch_idx) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
- xs_targ = batch['x_targ']
- if 'x' not in batch:
- # TODO: re-enable this warning but only ever print once!
- # warnings.warn('dataset does not have input: x -> x_targ using target as input: x_targ -> x_targ')
- xs = xs_targ
- else:
- xs = batch['x']
- # check that we have the correct number of inputs
- if (len(xs) != self.REQUIRED_OBS) or (len(xs_targ) != self.REQUIRED_OBS):
- log.warning(f'batch len(xs)={len(xs)} and len(xs_targ)={len(xs_targ)} observation count mismatch, requires: {self.REQUIRED_OBS}')
- # done
- return xs, xs_targ
-
@final
def do_training_step(self, batch, batch_idx):
xs, xs_targ = self._get_xs_and_targs(batch, batch_idx)
@@ -131,10 +92,10 @@ def do_training_step(self, batch, batch_idx):
# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
# latent variables
zs = map_all(self.encode, xs)
- # intercept latent variables
+ # [HOOK] intercept latent variables
zs, logs_intercept_zs = self.hook_ae_intercept_zs(zs)
# reconstruct without the final activation
- xs_partial_recon = map_all(self.decode_partial, detach_all(zs, if_=self.cfg.disable_decoder))
+ xs_partial_recon = map_all(self.decode_partial, detach_all(zs, if_=self.cfg.detach_decoder))
# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
# LOSS
@@ -149,14 +110,21 @@ def do_training_step(self, batch, batch_idx):
if not self.cfg.disable_aug_loss: loss += aug_loss
# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
- # return values
- return loss, {
+ # log general
+ self.log_dict({
**logs_intercept_zs,
**logs_recon,
**logs_aug,
- 'recon_loss': recon_loss,
- 'aug_loss': aug_loss,
- }
+ })
+
+ # log progress bar
+ self.log_dict({
+ 'recon_loss': float(recon_loss),
+ 'aug_loss': float(aug_loss),
+ }, prog_bar=True)
+
+ # return values
+ return loss
# --------------------------------------------------------------------- #
# Overrideable #
diff --git a/disent/frameworks/helper/latent_distributions.py b/disent/frameworks/helper/latent_distributions.py
index ad7ac0a9..cabf61c4 100644
--- a/disent/frameworks/helper/latent_distributions.py
+++ b/disent/frameworks/helper/latent_distributions.py
@@ -32,6 +32,7 @@
from torch.distributions import Laplace
from torch.distributions import Normal
+import disent.registry as R
from disent.frameworks.helper.util import compute_ave_loss
from disent.nn.loss.kl import kl_loss
from disent.nn.loss.reduction import loss_reduction
@@ -161,17 +162,8 @@ def encoding_to_dists(self, raw_z: Tuple[torch.Tensor, ...]) -> Tuple[Laplace, L
# ========================================================================= #
-_LATENT_HANDLERS = {
- 'normal': LatentDistsHandlerNormal,
- 'laplace': LatentDistsHandlerLaplace,
-}
-
-
def make_latent_distribution(name: str, kl_mode: str, reduction: str) -> LatentDistsHandler:
- try:
- cls = _LATENT_HANDLERS[name]
- except KeyError:
- raise KeyError(f'unknown vae distribution name: {repr(name)}, must be one of: {sorted(_LATENT_HANDLERS.keys())}')
+ cls = R.LATENT_HANDLERS[name]
# make instance
return cls(kl_mode=kl_mode, reduction=reduction)
diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py
index 723d764f..7da710bb 100644
--- a/disent/frameworks/helper/reconstructions.py
+++ b/disent/frameworks/helper/reconstructions.py
@@ -22,23 +22,21 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-import re
import warnings
from typing import final
-from typing import List
from typing import Sequence
-from typing import Tuple
from typing import Union
import torch
import torch.nn.functional as F
-from disent import registry
+import disent.registry as R
+from disent.dataset.transform import FftKernel
from disent.frameworks.helper.util import compute_ave_loss
from disent.nn.loss.reduction import batch_loss_reduction
from disent.nn.loss.reduction import loss_reduction
from disent.nn.modules import DisentModule
-from disent.dataset.transform import FftKernel
+from disent.util.deprecate import deprecated
# ========================================================================= #
@@ -239,17 +237,30 @@ def compute_unreduced_loss(self, x_recon, x_targ):
# ========================================================================= #
+_NO_ARG = object()
+
+
class AugmentedReconLossHandler(ReconLossHandler):
- def __init__(self, recon_loss_handler: ReconLossHandler, kernel: Union[str, torch.Tensor], wrap_weight=1.0, aug_weight=1.0):
+ def __init__(
+ self,
+ recon_loss_handler: ReconLossHandler,
+ kernel: Union[str, torch.Tensor],
+ wrap_weight: float = 1.0,
+ aug_weight: float = 1.0,
+ normalize_mode: str = _NO_ARG
+ ):
super().__init__(reduction=recon_loss_handler._reduction)
# save variables
self._recon_loss_handler = recon_loss_handler
# must be a recon loss handler, but cannot nest augmented handlers
assert isinstance(recon_loss_handler, ReconLossHandler)
assert not isinstance(recon_loss_handler, AugmentedReconLossHandler)
+ # deprecation error
+ if normalize_mode is _NO_ARG:
+ raise ValueError(f'default argument for normalize_mode was "sum", this has been deprecated and will change to "none" in future. Please manually override this value!')
# load the kernel
- self._kernel = FftKernel(kernel=kernel, normalize=True)
+ self._kernel = FftKernel(kernel=kernel, normalize_mode=normalize_mode)
# kernel weighting
assert 0 <= wrap_weight, f'loss_weight must be in the range [0, inf) but received: {repr(wrap_weight)}'
assert 0 <= aug_weight, f'kern_weight must be in the range [0, inf) but received: {repr(aug_weight)}'
@@ -273,30 +284,31 @@ def compute_unreduced_loss_from_partial(self, x_partial_recon: torch.Tensor, x_t
# ========================================================================= #
# Registry & Factory #
+# TODO: add ability to register parameterized reconstruction losses
# ========================================================================= #
-# TODO: add ability to register parameterized reconstruction losses
-_ARG_RECON_LOSSES: List[Tuple[re.Pattern, str, callable]] = [
- # (REGEX, EXAMPLE, FACTORY_FUNC)
- # - factory function takes at min one arg: fn(reduction) with one arg after that per regex capture group
- # - regex expressions are tested in order, expressions should be mutually exclusive or ordered such that more specialized versions occur first.
-]
+def _make_aug_recon_loss_l_w_n(loss: str, kern: str, loss_weight: str, kernel_weight: str, normalize_mode: str):
+ def _loss(reduction: str):
+ return AugmentedReconLossHandler(make_reconstruction_loss(loss, reduction=reduction), kernel=kern, wrap_weight=float(loss_weight), aug_weight=float(kernel_weight), normalize_mode=normalize_mode)
+ return _loss
+
+
+def _make_aug_recon_loss_l1_w1_n(loss: str, kern: str, normalize_mode: str):
+ def _loss(reduction: str):
+ return AugmentedReconLossHandler(make_reconstruction_loss(loss, reduction=reduction), kernel=kern, wrap_weight=1.0, aug_weight=1.0, normalize_mode=normalize_mode)
+ return _loss
+
+
+def _make_aug_recon_loss_l1_w1_nnone(loss: str, kern: str):
+ def _loss(reduction: str):
+ return AugmentedReconLossHandler(make_reconstruction_loss(loss, reduction=reduction), kernel=kern, wrap_weight=1.0, aug_weight=1.0, normalize_mode='none')
+ return _loss
# NOTE: this function compliments make_kernel in transform/_augment.py
def make_reconstruction_loss(name: str, reduction: str) -> ReconLossHandler:
- if name in registry.RECON_LOSSES:
- # search normal losses!
- return registry.RECON_LOSSES[name](reduction)
- else:
- # regex search losses, and call with args!
- for r, _, fn in _ARG_RECON_LOSSES:
- result = r.search(name)
- if result is not None:
- return fn(reduction, *result.groups())
- # we couldn't find anything
- raise KeyError(f'Invalid vae reconstruction loss: {repr(name)} Valid losses include: {sorted(registry.RECON_LOSSES)}, examples of additional argument based losses include: {[example for _, example, _ in _ARG_RECON_LOSSES]}')
+ return R.RECON_LOSSES[name](reduction=reduction)
# ========================================================================= #
diff --git a/disent/frameworks/vae/_unsupervised__vae.py b/disent/frameworks/vae/_unsupervised__vae.py
index 366bd4e9..1ad85b33 100644
--- a/disent/frameworks/vae/_unsupervised__vae.py
+++ b/disent/frameworks/vae/_unsupervised__vae.py
@@ -27,21 +27,17 @@
from typing import Any
from typing import Dict
from typing import final
-from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import torch
from torch.distributions import Distribution
-from torch.distributions import Laplace
-from torch.distributions import Normal
-from disent.frameworks.ae._unsupervised__ae import Ae
+from disent.frameworks.ae._ae_mixin import _AeAndVaeMixin
from disent.frameworks.helper.latent_distributions import LatentDistsHandler
from disent.frameworks.helper.latent_distributions import make_latent_distribution
from disent.frameworks.helper.util import detach_all
-
from disent.util.iters import map_all
@@ -50,7 +46,7 @@
# ========================================================================= #
-class Vae(Ae):
+class Vae(_AeAndVaeMixin):
"""
Variational Auto Encoder
https://arxiv.org/abs/1312.6114
@@ -88,22 +84,23 @@ class Vae(Ae):
Vae(hook_compute_ave_aug_loss=TripletLoss(), required_obs=3)
"""
- # override required z from AE
+ # overrides
REQUIRED_Z_MULTIPLIER = 2
REQUIRED_OBS = 1
@dataclass
- class cfg(Ae.cfg):
+ class cfg(_AeAndVaeMixin.cfg):
# latent distribution settings
latent_distribution: str = 'normal'
kl_loss_mode: str = 'direct'
# disable various components
disable_reg_loss: bool = False
- disable_posterior_scale: Optional[float] = None
def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
# required_z_multiplier
- super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
+ super().__init__(cfg=cfg, batch_augment=batch_augment)
+ # initialise the auto-encoder mixin (recon handler, model, enc, dec, etc.)
+ self._init_ae_mixin(model=model)
# vae distribution
self.__latents_handler = make_latent_distribution(self.cfg.latent_distribution, kl_mode=self.cfg.kl_loss_mode, reduction=self.cfg.loss_reduction)
@@ -124,14 +121,12 @@ def do_training_step(self, batch, batch_idx):
# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
# latent distribution parameterizations
ds_posterior, ds_prior = map_all(self.encode_dists, xs, collect_returned=True)
- # [HOOK] disable learnt scale values
- ds_posterior, ds_prior = self._hook_intercept_ds_disable_scale(ds_posterior, ds_prior)
# [HOOK] intercept latent parameterizations
ds_posterior, ds_prior, logs_intercept_ds = self.hook_intercept_ds(ds_posterior, ds_prior)
# sample from dists
zs_sampled = tuple(d.rsample() for d in ds_posterior)
# reconstruct without the final activation
- xs_partial_recon = map_all(self.decode_partial, detach_all(zs_sampled, if_=self.cfg.disable_decoder))
+ xs_partial_recon = map_all(self.decode_partial, detach_all(zs_sampled, if_=self.cfg.detach_decoder))
# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
# LOSS
@@ -149,45 +144,23 @@ def do_training_step(self, batch, batch_idx):
if not self.cfg.disable_reg_loss: loss += reg_loss
# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
- # return values
- return loss, {
+ # log general
+ self.log_dict({
**logs_intercept_ds,
**logs_recon,
**logs_reg,
**logs_aug,
- 'recon_loss': recon_loss,
- 'reg_loss': reg_loss,
- 'aug_loss': aug_loss,
- # ratios
- 'ratio_reg': (reg_loss / loss) if (loss != 0) else 0,
- 'ratio_rec': (recon_loss / loss) if (loss != 0) else 0,
- 'ratio_aug': (aug_loss / loss) if (loss != 0) else 0,
- }
+ })
- # --------------------------------------------------------------------- #
- # Delete AE Hooks #
- # --------------------------------------------------------------------- #
-
- @final
- def hook_ae_intercept_zs(self, zs: Sequence[torch.Tensor]) -> Tuple[Sequence[torch.Tensor], Dict[str, Any]]:
- raise RuntimeError('This function should never be used or overridden by VAE methods!') # pragma: no cover
+ # log progress bar
+ self.log_dict({
+ 'recon_loss': float(recon_loss),
+ 'reg_loss': float(reg_loss),
+ 'aug_loss': float(aug_loss),
+ }, prog_bar=True)
- @final
- def hook_ae_compute_ave_aug_loss(self, zs: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]:
- raise RuntimeError('This function should never be used or overridden by VAE methods!') # pragma: no cover
-
- # --------------------------------------------------------------------- #
- # Private Hooks #
- # --------------------------------------------------------------------- #
-
- def _hook_intercept_ds_disable_scale(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution]):
- # disable posterior scales
- if self.cfg.disable_posterior_scale is not None:
- for d_posterior in ds_posterior:
- assert isinstance(d_posterior, (Normal, Laplace))
- d_posterior.scale = torch.full_like(d_posterior.scale, fill_value=self.cfg.disable_posterior_scale)
- # return modified values
- return ds_posterior, ds_prior
+ # return values
+ return loss
# --------------------------------------------------------------------- #
# Overrideable Hooks #
@@ -199,6 +172,14 @@ def hook_intercept_ds(self, ds_posterior: Sequence[Distribution], ds_prior: Sequ
def hook_compute_ave_aug_loss(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution], zs_sampled: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]:
return 0, {}
+ def compute_ave_recon_loss(self, xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]:
+ # compute reconstruction loss
+ pixel_loss = self.recon_handler.compute_ave_loss_from_partial(xs_partial_recon, xs_targ)
+ # return logs
+ return pixel_loss, {
+ 'pixel_loss': pixel_loss
+ }
+
def compute_ave_reg_loss(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution], zs_sampled: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]:
# compute regularization loss (kl divergence)
kl_loss = self.latents_handler.compute_ave_kl_loss(ds_posterior, ds_prior, zs_sampled)
@@ -208,7 +189,7 @@ def compute_ave_reg_loss(self, ds_posterior: Sequence[Distribution], ds_prior: S
}
# --------------------------------------------------------------------- #
- # VAE - Encoding - Overrides AE #
+ # VAE Model Utility Functions (Visualisation) #
# --------------------------------------------------------------------- #
@final
@@ -218,6 +199,20 @@ def encode(self, x: torch.Tensor) -> torch.Tensor:
z = self.latents_handler.encoding_to_representation(z_raw)
return z
+ @final
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
+ """Decode latent vector z into reconstruction x_recon (useful for visualisation)"""
+ return self.recon_handler.activate(self._model.decode(z))
+
+ @final
+ def forward(self, batch: torch.Tensor) -> torch.Tensor:
+ """Feed through the full deterministic model (useful for visualisation)"""
+ return self.decode(self.encode(batch))
+
+ # --------------------------------------------------------------------- #
+ # VAE Model Utility Functions (Training) #
+ # --------------------------------------------------------------------- #
+
@final
def encode_dists(self, x: torch.Tensor) -> Tuple[Distribution, Distribution]:
"""Get parametrisations of the latent distributions, which are sampled from during training."""
@@ -225,6 +220,11 @@ def encode_dists(self, x: torch.Tensor) -> Tuple[Distribution, Distribution]:
z_posterior, z_prior = self.latents_handler.encoding_to_dists(z_raw)
return z_posterior, z_prior
+ @final
+ def decode_partial(self, z: torch.Tensor) -> torch.Tensor:
+ """Decode latent vector z into partial reconstructions that exclude the final activation if there is one."""
+ return self._model.decode(z)
+
# ========================================================================= #
# END #
diff --git a/disent/metrics/__init__.py b/disent/metrics/__init__.py
index 98099e76..5cc8994a 100644
--- a/disent/metrics/__init__.py
+++ b/disent/metrics/__init__.py
@@ -33,24 +33,3 @@
# ========================================================================= #
# Fast Metric Settings #
# ========================================================================= #
-
-
-# helper imports
-from disent.util.function import wrapped_partial as _wrapped_partial
-
-
-FAST_METRICS = {
- 'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'),
- 'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds
- 'mig': _wrapped_partial(metric_mig, num_train=2000),
- 'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000),
- 'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000),
-}
-
-DEFAULT_METRICS = {
- 'dci': metric_dci,
- 'factor_vae': metric_factor_vae,
- 'mig': metric_mig,
- 'sap': metric_sap,
- 'unsupervised': metric_unsupervised,
-}
diff --git a/disent/metrics/_dci.py b/disent/metrics/_dci.py
index 2c48bcd5..7138bb17 100644
--- a/disent/metrics/_dci.py
+++ b/disent/metrics/_dci.py
@@ -35,6 +35,8 @@
import scipy
import scipy.stats
+from disent.metrics.utils import make_metric
+
log = logging.getLogger(__name__)
@@ -44,6 +46,7 @@
# ========================================================================= #
+@make_metric('dci', fast_kwargs=dict(num_train=1000, num_test=500))
def metric_dci(
dataset: DisentDataset,
representation_function: callable,
@@ -87,10 +90,10 @@ def _compute_dci(mus_train, ys_train, mus_test, ys_test, boost_mode='sklearn', s
assert importance_matrix.shape[0] == mus_train.shape[0]
assert importance_matrix.shape[1] == ys_train.shape[0]
return {
- "dci.informativeness_train": train_err,
- "dci.informativeness_test": test_err,
- "dci.disentanglement": _disentanglement(importance_matrix),
- "dci.completeness": _completeness(importance_matrix),
+ "dci.informativeness_train": train_err, # "dci.explicitness" -- Measuring Disentanglement: A Review of Metrics
+ "dci.informativeness_test": test_err, # "dci.explicitness" -- Measuring Disentanglement: A Review of Metrics
+ "dci.disentanglement": _disentanglement(importance_matrix), # "dci.modularity" -- Measuring Disentanglement: A Review of Metrics
+ "dci.completeness": _completeness(importance_matrix), # "dci.compactness" -- Measuring Disentanglement: A Review of Metrics
}
diff --git a/disent/metrics/_factor_vae.py b/disent/metrics/_factor_vae.py
index 66ffc67a..5576a4f2 100644
--- a/disent/metrics/_factor_vae.py
+++ b/disent/metrics/_factor_vae.py
@@ -32,6 +32,7 @@
from disent.dataset import DisentDataset
from disent.metrics import utils
+from disent.metrics.utils import make_metric
from disent.util import to_numpy
@@ -43,6 +44,7 @@
# ========================================================================= #
+@make_metric('factor_vae', fast_kwargs=dict(num_train=700, num_eval=350, num_variance_estimate=1000)) # may not be accurate, but it just takes waay too long otherwise 20+ seconds
def metric_factor_vae(
dataset: DisentDataset,
representation_function: callable,
@@ -124,8 +126,8 @@ def metric_factor_vae(
eval_accuracy = np.sum(eval_votes[classifier, other_index]) * 1. / np.sum(eval_votes)
return {
- "factor_vae.train_accuracy": train_accuracy,
- "factor_vae.eval_accuracy": eval_accuracy,
+ "factor_vae.train_accuracy": train_accuracy, # "z-min variance" -- Measuring Disentanglement: A Review of Metrics
+ "factor_vae.eval_accuracy": eval_accuracy, # "z-min variance" -- Measuring Disentanglement: A Review of Metrics
"factor_vae.num_active_dims": len(active_dims),
}
diff --git a/disent/metrics/_mig.py b/disent/metrics/_mig.py
index 3e0b1a87..4ccd0550 100644
--- a/disent/metrics/_mig.py
+++ b/disent/metrics/_mig.py
@@ -32,6 +32,7 @@
from disent.dataset import DisentDataset
from disent.metrics import utils
+from disent.metrics.utils import make_metric
log = logging.getLogger(__name__)
@@ -42,6 +43,7 @@
# ========================================================================= #
+@make_metric('mig', fast_kwargs=dict(num_train=2000))
def metric_mig(
dataset: DisentDataset,
representation_function,
@@ -76,5 +78,5 @@ def _compute_mig(mus_train, ys_train):
entropy = utils.discrete_entropy(ys_train)
sorted_m = np.sort(m, axis=0)[::-1]
return {
- "mig.discrete_score": np.mean(np.divide(sorted_m[0, :] - sorted_m[1, :], entropy[:]))
+ "mig.discrete_score": np.mean(np.divide(sorted_m[0, :] - sorted_m[1, :], entropy[:])) # "modularity: MIG" -- Measuring Disentanglement: A Review of Metrics
}
diff --git a/disent/metrics/_sap.py b/disent/metrics/_sap.py
index 36fad48d..09722d06 100644
--- a/disent/metrics/_sap.py
+++ b/disent/metrics/_sap.py
@@ -33,6 +33,7 @@
from disent.dataset import DisentDataset
from disent.metrics import utils
+from disent.metrics.utils import make_metric
log = logging.getLogger(__name__)
@@ -43,6 +44,7 @@
# ========================================================================= #
+@make_metric('sap', fast_kwargs=dict(num_train=2000, num_test=1000))
def metric_sap(
dataset: DisentDataset,
representation_function,
@@ -80,7 +82,7 @@ def _compute_sap(mus, ys, mus_test, ys_test, continuous_factors):
sap_score = _compute_avg_diff_top_two(score_matrix)
log.debug("SAP score: %.2g", sap_score)
return {
- "sap.score": sap_score
+ "sap.score": sap_score # "compactness: SAP" -- Measuring Disentanglement: A Review of Metrics
}
diff --git a/disent/metrics/_unsupervised.py b/disent/metrics/_unsupervised.py
index c872f110..e351e2ef 100644
--- a/disent/metrics/_unsupervised.py
+++ b/disent/metrics/_unsupervised.py
@@ -31,6 +31,8 @@
from disent.dataset import DisentDataset
from disent.metrics import utils
+from disent.metrics.utils import make_metric
+
log = logging.getLogger(__name__)
@@ -40,6 +42,7 @@
# ========================================================================= #
+@make_metric('unsupervised', fast_kwargs=dict(num_train=2000))
def metric_unsupervised(
dataset: DisentDataset,
representation_function,
diff --git a/disent/metrics/utils.py b/disent/metrics/utils.py
index e8c6b2fe..9bd28569 100644
--- a/disent/metrics/utils.py
+++ b/disent/metrics/utils.py
@@ -23,16 +23,94 @@
Utility functions that are useful for the different metrics.
"""
+from numbers import Number
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Generic
+from typing import Optional
+from typing import Protocol
+from typing import TypeVar
+from typing import Union
+
import numpy as np
import sklearn
from tqdm import tqdm
from disent.dataset import DisentDataset
from disent.util import to_numpy
+from disent.util.function import wrapped_partial
+
+
+# ========================================================================= #
+# Metric Wrapper #
+# ========================================================================= #
+
+
+T = TypeVar('T')
+
+
+class Metric(Generic[T]):
+
+ def __init__(
+ self,
+ name: str,
+ metric_fn: T, # Callable[[...], Dict[str, Number]]
+ default_kwargs: Optional[Dict[str, Any]] = None,
+ fast_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ self._name = name
+ self._orig_fn = metric_fn
+ self._metric_fn_default = wrapped_partial(self._orig_fn, **(default_kwargs if default_kwargs else {}))
+ self._metric_fn_fast = wrapped_partial(self._orig_fn, **(fast_kwargs if fast_kwargs else {}))
+
+ # How do we get a type hint for `__call__` so that its signature matches `T`?
+ def __call__(self, *args, **kwargs) -> Dict[str, Number]:
+ return self._metric_fn_default(*args, **kwargs)
+
+ @property
+ def compute(self) -> T:
+ return self._metric_fn_default
+
+ @property
+ def compute_fast(self) -> T:
+ return self._metric_fn_fast
+
+ @property
+ def unwrap(self) -> T:
+ return self._orig_fn
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ def __str__(self):
+ return f'metric-{self.name}'
+
+
+def make_metric(
+ name: str,
+ default_kwargs: Optional[Dict[str, Any]] = None,
+ fast_kwargs: Optional[Dict[str, Any]] = None,
+) -> Callable[[T], Union[Metric[T], T]]:
+ """
+ Metrics should be decorated using this function to set defaults!
+ Two versions of the metric should exist.
+ 1. Recommended settings
+ - This should give reliable results, but may be very slow, multiple minutes to half an
+ hour or more for some metrics depending on the underlying model, data and ground-truth factors.
+ 2. Faster settings
+ - This should give a decent results, but should be decently fast, a few seconds/minutes at most.
+ This is not used for testing
+ """
+ # `Union[Metric[T], T]` is hack to get type hint on `__call__`
+ def _wrap_fn_as_metric(metric_fn: T) -> Union[Metric[T], T]:
+ return Metric(name=name, metric_fn=metric_fn, default_kwargs=default_kwargs, fast_kwargs=fast_kwargs)
+ return _wrap_fn_as_metric
# ========================================================================= #
-# utils #
+# utils #
# ========================================================================= #
diff --git a/disent/nn/functional/__init__.py b/disent/nn/functional/__init__.py
index 4ac90f16..8af43526 100644
--- a/disent/nn/functional/__init__.py
+++ b/disent/nn/functional/__init__.py
@@ -47,6 +47,12 @@
from disent.nn.functional._mean import torch_mean_geometric
from disent.nn.functional._mean import torch_mean_harmonic
+from disent.nn.functional._norm import torch_norm
+from disent.nn.functional._norm import torch_dist
+from disent.nn.functional._norm import torch_norm_euclidean
+from disent.nn.functional._norm import torch_norm_manhattan
+from disent.nn.functional._norm import torch_dist_hamming
+
from disent.nn.functional._other import torch_normalize
from disent.nn.functional._other import torch_nan_to_num
from disent.nn.functional._other import torch_unsqueeze_l
diff --git a/disent/nn/functional/_mean.py b/disent/nn/functional/_mean.py
index 00bdf0a0..6493994e 100644
--- a/disent/nn/functional/_mean.py
+++ b/disent/nn/functional/_mean.py
@@ -41,12 +41,14 @@
_NEG_INF = float('-inf')
_GENERALIZED_MEAN_MAP = {
+ 'inf': _POS_INF,
'maximum': _POS_INF,
'quadratic': 2,
'arithmetic': 1,
'geometric': 0,
'harmonic': -1,
'minimum': _NEG_INF,
+ '-inf': _NEG_INF,
}
@@ -55,7 +57,7 @@
# ========================================================================= #
-def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[int, str] = 1, keepdim: bool = False):
+def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[float, str] = 1, keepdim: bool = False):
"""
Compute the generalised mean.
- p is the power
@@ -69,9 +71,9 @@ def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[
p = _GENERALIZED_MEAN_MAP[p]
# compute the specific extreme cases
if p == _POS_INF:
- return torch.max(xs, dim=dim, keepdim=keepdim).values if (dim is not None) else torch.max(xs, keepdim=keepdim)
+ return torch.amax(xs, dim=dim, keepdim=keepdim)
elif p == _NEG_INF:
- return torch.min(xs, dim=dim, keepdim=keepdim).values if (dim is not None) else torch.min(xs, keepdim=keepdim)
+ return torch.amin(xs, dim=dim, keepdim=keepdim)
# compute the number of elements being averaged
if dim is None:
dim = list(range(xs.ndim))
diff --git a/disent/nn/functional/_norm.py b/disent/nn/functional/_norm.py
new file mode 100644
index 00000000..6c972482
--- /dev/null
+++ b/disent/nn/functional/_norm.py
@@ -0,0 +1,129 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2021 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import warnings
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+
+# ========================================================================= #
+# Types #
+# ========================================================================= #
+
+
+_DimTypeHint = Optional[Union[int, List[int]]]
+
+_POS_INF = float('inf')
+_NEG_INF = float('-inf')
+
+_P_NORM_MAP = {
+ 'inf': _POS_INF,
+ 'maximum': _POS_INF,
+ 'euclidean': 2,
+ 'manhattan': 1,
+ # p < 1 is not a valid norm! but allow it if we set `unbounded_p = True`
+ 'hamming': 0,
+ 'minimum': _NEG_INF,
+ '-inf': _NEG_INF,
+}
+
+
+# ========================================================================= #
+# p-Norm Functions #
+# -- very similar to the generalised mean #
+# ========================================================================= #
+
+
+def torch_dist(xs: torch.Tensor, dim: _DimTypeHint = -1, p: Union[float, str] = 1, keepdim: bool = False):
+ """
+ Like torch_norm, but allows arbitrary p values that may
+ result in the returned values not being a valid norm.
+ - norm's require p >= 1
+ """
+ if isinstance(p, str):
+ p = _P_NORM_MAP[p]
+ # get absolute values
+ xs = torch.abs(xs)
+ # compute the specific extreme cases
+ # -- its kind of odd that the p-norm and generalised mean converge to the
+ # same values, just from different directions!
+ if p == _POS_INF:
+ return torch.amax(xs, dim=dim, keepdim=keepdim)
+ elif p == _NEG_INF:
+ return torch.amin(xs, dim=dim, keepdim=keepdim)
+ # get the dimensions
+ if dim is None:
+ dim = list(range(xs.ndim))
+ # warn if the type is wrong
+ if p != 1:
+ if xs.dtype != torch.float64:
+ warnings.warn(f'Input tensor to p-norm might not have the required precision, type is {xs.dtype} not {torch.float64}.')
+ # compute the specific cases
+ if p == 0:
+ # hamming distance -- number of non-zero entries of the vector
+ return torch.sum(xs != 0, dim=dim, keepdim=keepdim, dtype=xs.dtype)
+ if p == 1:
+ # manhattan distance
+ return torch.sum(xs, dim=dim, keepdim=keepdim)
+ else:
+ # p-norm (if p==2, then euclidean distance)
+ return torch.sum(xs ** p, dim=dim, keepdim=keepdim) ** (1/p)
+
+
+def torch_norm(xs: torch.Tensor, dim: _DimTypeHint = -1, p: Union[float, str] = 1, keepdim: bool = False):
+ """
+ Compute the generalised p-norm over the given dimension of a vector!
+ - p values must be >= 1
+
+ Closely related to the generalised mean.
+ - p-norm: (sum(|x|^p)) ^ (1/p)
+ - gen-mean: (1/n * sum(x ^ p)) ^ (1/p)
+ """
+ if isinstance(p, str):
+ p = _P_NORM_MAP[p]
+ # check values
+ if p < 1:
+ raise ValueError(f'p-norm cannot have a p value less than 1, got: {repr(p)}, to bypass this error set `unbounded_p=True`.')
+ # return norm
+ return torch_dist(xs=xs, dim=dim, p=p, keepdim=keepdim)
+
+
+def torch_norm_euclidean(xs, dim: _DimTypeHint = -1, keepdim: bool = False):
+ return torch_dist(xs, dim=dim, p='euclidean', keepdim=keepdim)
+
+
+def torch_norm_manhattan(xs, dim: _DimTypeHint = -1, keepdim: bool = False):
+ return torch_dist(xs, dim=dim, p='manhattan', keepdim=keepdim)
+
+
+def torch_dist_hamming(xs, dim: _DimTypeHint = -1, keepdim: bool = False):
+ return torch_dist(xs, dim=dim, p='hamming', keepdim=keepdim)
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/nn/functional/_pca.py b/disent/nn/functional/_pca.py
index 33212f41..4cb40d76 100644
--- a/disent/nn/functional/_pca.py
+++ b/disent/nn/functional/_pca.py
@@ -30,14 +30,16 @@
# ========================================================================= #
-def torch_pca_eig(X, center=True, scale=False):
+def torch_pca_eig(X, center=True, scale=False, zero_negatives=False):
"""
perform PCA over X
- X is of size (num_points, vec_size)
NOTE: unlike PCA_svd, the number of vectors/values returned is always: vec_size
+
+ WARNING: this may be incorrect!
"""
- n, _ = X.shape
+ n, m = X.shape
# center points along axes
if center:
X = X - X.mean(dim=0)
@@ -47,10 +49,14 @@ def torch_pca_eig(X, center=True, scale=False):
scaling = torch.sqrt(1 / torch.diagonal(covariance))
covariance = torch.mm(torch.diagflat(scaling), covariance)
# compute eigen values and eigen vectors
- eigenvalues, eigenvectors = torch.eig(covariance, True)
+ eigenvalues, eigenvectors = torch.linalg.eig(covariance)
# sort components by decreasing variance
- components = eigenvectors.T
- explained_variance = eigenvalues[:, 0]
+ components = torch.real(eigenvectors.T) # TODO: handle imaginary numbers!
+ explained_variance = torch.real(eigenvalues) # TODO: handle imaginary numbers!
+ # handle n < m -- numerical stability issues return negative values!
+ # maybe this should just zero out the negatives instead, they don't contribute!?
+ explained_variance = torch.abs(explained_variance)
+ # return sorted
idxs = torch.argsort(explained_variance, descending=True)
return components[idxs], explained_variance[idxs]
@@ -75,7 +81,10 @@ def torch_pca_svd(X, center=True):
return components, explained_variance
-def torch_pca(X, center=True, mode='svd'):
+def torch_pca(X, center=True, mode='svd') -> (torch.Tensor, torch.Tensor):
+ # number of values returned may differ depending on the method!
+ # -- svd returns: min(num, z_size)
+ # -- eig returns: num
if mode == 'svd':
return torch_pca_svd(X, center=center)
elif mode == 'eig':
diff --git a/disent/nn/loss/reduction.py b/disent/nn/loss/reduction.py
index a1f08efc..ea021b33 100644
--- a/disent/nn/loss/reduction.py
+++ b/disent/nn/loss/reduction.py
@@ -78,6 +78,7 @@ def get_mean_loss_scale(x: torch.Tensor, reduction: str):
}
+# TODO: this is duplicated in research ... pairwise_loss
# applying this function and then taking the
# mean should give the same result as loss_reduction
def batch_loss_reduction(tensor: torch.Tensor, reduction_dtype=None, reduction='mean') -> torch.Tensor:
diff --git a/disent/nn/loss/softsort.py b/disent/nn/loss/softsort.py
index 2f50d16e..19158ae4 100644
--- a/disent/nn/loss/softsort.py
+++ b/disent/nn/loss/softsort.py
@@ -104,7 +104,7 @@ def torch_soft_sort(
dims: Union[int, Tuple[int, ...]] = -1,
regularization='l2',
regularization_strength=1.0,
- dims_at_end=False,
+ leave_dims_at_end=False,
):
# we import it locally so that we don't have to install this
import torchsort
@@ -113,7 +113,7 @@ def torch_soft_sort(
# sort the last dimension of the 2D tensors
tensor = torchsort.soft_sort(tensor, regularization=regularization, regularization_strength=regularization_strength)
# undo the reorder operation
- if dims_at_end:
+ if leave_dims_at_end:
return tensor
return torch_undo_dims_at_end_2d(tensor, moved_shape=moved_shape, moved_end_dims=moved_end_dims)
diff --git a/disent/nn/loss/triplet.py b/disent/nn/loss/triplet.py
index f1e55f9d..4ae29afd 100644
--- a/disent/nn/loss/triplet.py
+++ b/disent/nn/loss/triplet.py
@@ -70,7 +70,7 @@ def dist_triplet_sigmoid_loss(pos_delta, neg_delta, margin_min=None, margin_max=
https://arxiv.org/pdf/2003.14021.pdf
"""
if margin_min is not None:
- warnings.warn('triplet_loss does not support margin_min')
+ warnings.warn('triplet_sigmoid_loss does not support margin_min')
p_dist = torch.norm(pos_delta, p=p, dim=-1)
n_dist = torch.norm(neg_delta, p=p, dim=-1)
loss = torch.sigmoid((1/margin_max) * (p_dist - n_dist))
@@ -80,6 +80,32 @@ def dist_triplet_sigmoid_loss(pos_delta, neg_delta, margin_min=None, margin_max=
# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
+def triplet_soft_loss(anc, pos, neg, margin_min=None, margin_max=None, p=1):
+ """
+ Triplet Loss With Soft-Margin
+ https://arxiv.org/pdf/1703.07737.pdf
+ """
+ return dist_triplet_soft_loss(anc - pos, anc - neg, margin_min=margin_min, margin_max=margin_max, p=p)
+
+
+def dist_triplet_soft_loss(pos_delta, neg_delta, margin_min=None, margin_max=None, p=1):
+ """
+ Triplet Loss With Soft-Margin
+ https://arxiv.org/pdf/1703.07737.pdf
+ """
+ if margin_min is not None:
+ warnings.warn('triplet_soft_loss does not support margin_min')
+ if margin_max is not None:
+ warnings.warn('triplet_soft_loss does not support margin_max')
+ p_dist = torch.norm(pos_delta, p=p, dim=-1)
+ n_dist = torch.norm(neg_delta, p=p, dim=-1)
+ loss = torch.log(1 + torch.exp(p_dist - n_dist))
+ return loss.mean()
+
+
+# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- #
+
+
# def elem_triplet_loss(anc, pos, neg, margin_min=None, margin_max=1., p=1):
# """
# Element-Wise Triplet Loss
@@ -132,6 +158,7 @@ def min_clamped_triplet_loss(anc, pos, neg, margin_min=0.01, margin_max=1., p=1)
"""
Min Margin Triplet Loss
TODO: is this better, or clamped_triplet_loss?
+ TODO: could take idea from soft-margin to make this continuously differentiable?
"""
return dist_min_clamped_triplet_loss(anc - pos, anc - neg, margin_min=margin_min, margin_max=margin_max, p=p)
@@ -140,6 +167,7 @@ def dist_min_clamped_triplet_loss(pos_delta, neg_delta, margin_min=0.01, margin_
"""
Min Margin Triplet Loss
TODO: is this better, or dist_clamped_triplet_loss?
+ TODO: could take idea from soft-margin to make this continuously differentiable?
"""
p_dist = torch.norm(pos_delta, p=p, dim=-1)
n_dist = torch.norm(neg_delta, p=p, dim=-1)
@@ -153,6 +181,7 @@ def split_clamped_triplet_loss(anc, pos, neg, margin_min=0.01, margin_max=1., p=
"""
Min Margin Triplet Loss
TODO: is this better, or min_clamp_triplet_loss?
+ TODO: could take idea from soft-margin to make this continuously differentiable?
"""
return dist_split_clamped_triplet_loss(anc - pos, anc - neg, margin_min=margin_min, margin_max=margin_max, p=p)
@@ -161,6 +190,7 @@ def dist_split_clamped_triplet_loss(pos_delta, neg_delta, margin_min=0.01, margi
"""
Min Margin Triplet Loss
TODO: is this better, or dist_min_clamp_triplet_loss?
+ TODO: could take idea from soft-margin to make this continuously differentiable?
"""
p_dist = torch.norm(pos_delta, p=p, dim=-1)
n_dist = torch.norm(neg_delta, p=p, dim=-1)
@@ -212,12 +242,13 @@ class TripletLossConfig(object):
triplet_margin_min: float = 0.1
triplet_margin_max: float = 10
triplet_scale: float = 100
- triplet_p: int = 2
+ triplet_p: float = 2
_TRIPLET_LOSSES = {
'triplet': triplet_loss,
'triplet_sigmoid': triplet_sigmoid_loss,
+ 'triplet_soft': triplet_soft_loss,
# 'elem_triplet': elem_triplet_loss,
# 'min_margin_triplet': min_margin_triplet_loss,
'min_clamped_triplet': min_clamped_triplet_loss,
@@ -229,6 +260,7 @@ class TripletLossConfig(object):
_DIST_TRIPLET_LOSSES = {
'triplet': dist_triplet_loss,
'triplet_sigmoid': dist_triplet_sigmoid_loss,
+ 'triplet_soft': dist_triplet_soft_loss,
# 'elem_triplet': dist_elem_triplet_loss,
# 'min_margin_triplet': dist_min_margin_triplet_loss,
'min_clamped_triplet': dist_min_clamped_triplet_loss,
@@ -281,11 +313,3 @@ def compute_dist_triplet_loss(zs_deltas: Sequence[torch.Tensor], cfg: TripletCon
# ========================================================================= #
# END #
# ========================================================================= #
-
-
-
-
-
-
-
-
diff --git a/disent/nn/loss/triplet_mining.py b/disent/nn/loss/triplet_mining.py
index 84f71ce9..0eaa0ac6 100644
--- a/disent/nn/loss/triplet_mining.py
+++ b/disent/nn/loss/triplet_mining.py
@@ -157,8 +157,13 @@ def configured_idx_mine(
cfg: SampledTripletMineCfgProto,
pairwise_loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], # should return arrays with ndim == 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # TODO: SIMPLIFY THIS FUNCTION HIERARCHY, THERE ARE A LOT OF UNNECESSARY CALLS!
# TODO: this function is quite useless, its easier to just use configured_mine_random_mode
+ # skip mining if mode is None!
+ if cfg.overlap_mine_triplet_mode == 'none':
+ return a_idxs, p_idxs, n_idxs
# compute differences
+ # TODO: this is computationally expensive! sometimes the dist_ap and dist_an may not be used depending on the mode!
dist_ap = pairwise_loss_fn(x_targ[a_idxs], x_targ[p_idxs])
dist_an = pairwise_loss_fn(x_targ[a_idxs], x_targ[n_idxs])
# mine indices
diff --git a/disent/nn/modules.py b/disent/nn/modules.py
index 9de5f574..fdeac6cb 100644
--- a/disent/nn/modules.py
+++ b/disent/nn/modules.py
@@ -43,11 +43,12 @@ def forward(self, *args, **kwargs):
class DisentLightningModule(pl.LightningModule):
-
- def _forward_unimplemented(self, *args):
- # Annoying fix applied by torch for Module.forward:
- # https://github.com/python/mypy/issues/8795
- raise RuntimeError('This should never run!')
+ # make sure we don't get complaints about the missing methods!
+ # -- we prefer to use LightningDataModule
+ train_dataloader = None
+ test_dataloader = None
+ val_dataloader = None
+ predict_dataloader = None
# ========================================================================= #
diff --git a/disent/nn/weights.py b/disent/nn/weights.py
index 1a09b852..6c3ef56c 100644
--- a/disent/nn/weights.py
+++ b/disent/nn/weights.py
@@ -25,6 +25,7 @@
import logging
from typing import Optional
+import torch
from torch import nn
from disent.util.strings import colors as c
@@ -37,6 +38,24 @@
# ========================================================================= #
+_WEIGHT_INIT_FNS = {
+ 'xavier_uniform': lambda weight: nn.init.xavier_uniform_(weight, gain=1.0), # gain=1
+ 'xavier_normal': lambda weight: nn.init.xavier_normal_(weight, gain=1.0), # gain=1
+ 'xavier_normal__0.1': lambda weight: nn.init.xavier_normal_(weight, gain=0.1), # gain=0.1
+ # kaiming -- also known as "He initialisation"
+ 'kaiming_uniform': lambda weight: nn.init.kaiming_uniform_(weight, a=0, mode='fan_in', nonlinearity='relu'), # fan_in, relu
+ 'kaiming_normal': lambda weight: nn.init.kaiming_normal_(weight, a=0, mode='fan_in', nonlinearity='relu'), # fan_in, relu
+ 'kaiming_normal__fan_out': lambda weight: nn.init.kaiming_normal_(weight, a=0, mode='fan_out', nonlinearity='relu'), # fan_in, relu
+ # other
+ 'orthogonal': lambda weight: nn.init.orthogonal_(weight, gain=1), # gain=1
+ 'normal': lambda weight: nn.init.normal_(weight, mean=0., std=1.), # gain=1
+ 'normal__0.1': lambda weight: nn.init.normal_(weight, mean=0., std=0.1), # gain=0.1
+ 'normal__0.01': lambda weight: nn.init.normal_(weight, mean=0., std=0.01), # gain=0.01
+ 'normal__0.001': lambda weight: nn.init.normal_(weight, mean=0., std=0.001), # gain=0.01
+}
+
+
+# TODO: clean this up! this is terrible...
def init_model_weights(model: nn.Module, mode: Optional[str] = 'xavier_normal', log_level=logging.INFO) -> nn.Module:
count = 0
@@ -44,20 +63,20 @@ def init_model_weights(model: nn.Module, mode: Optional[str] = 'xavier_normal',
if mode is None:
mode = 'default'
- def init_normal(m):
+ def _apply_init_weights(m):
nonlocal count
init, count = False, count + 1
# actually initialise!
- if mode == 'xavier_normal':
+ if mode == 'default':
+ pass
+ elif mode in _WEIGHT_INIT_FNS:
if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
- nn.init.xavier_normal_(m.weight)
+ _WEIGHT_INIT_FNS[mode](m.weight)
nn.init.zeros_(m.bias)
init = True
- elif mode == 'default':
- pass
else:
- raise KeyError(f'Unknown init mode: {repr(mode)}, valid modes are: {["xavier_normal", "default"]}')
+ raise KeyError(f'Unknown init mode: {repr(mode)}, valid modes are: {["default"] + sorted(_WEIGHT_INIT_FNS)}')
# print messages
if init:
@@ -66,7 +85,7 @@ def init_normal(m):
log.log(log_level, f'| {count:03d} {c.lRED}SKIP{c.RST}: {m.__class__.__name__}')
log.log(log_level, f'Initialising Model Layers: {mode}')
- model.apply(init_normal)
+ model.apply(_apply_init_weights)
return model
diff --git a/disent/registry/__init__.py b/disent/registry/__init__.py
index 82814540..d834f1a8 100644
--- a/disent/registry/__init__.py
+++ b/disent/registry/__init__.py
@@ -34,8 +34,18 @@
eg. `DATASET.register(...options...)(your_function_or_class)`
"""
-from disent.registry._registry import Registry as _Registry
-from disent.registry._registry import LazyImport as _LazyImport
+# from disent.registry._registry import ProvidedValue
+# from disent.registry._registry import StaticImport
+# from disent.registry._registry import DictProviders
+# from disent.registry._registry import RegexProvidersSearch
+
+from disent.registry._registry import StaticValue
+from disent.registry._registry import LazyValue
+from disent.registry._registry import LazyImport
+from disent.registry._registry import Registry
+from disent.registry._registry import RegistryImports
+from disent.registry._registry import RegexConstructor
+from disent.registry._registry import RegexRegistry
# ========================================================================= #
@@ -44,15 +54,19 @@
# TODO: this is not yet used in disent.data or disent.frameworks
-DATASETS = _Registry('DATASETS')
+DATASETS: RegistryImports['torch.utils.data.Dataset'] = RegistryImports('DATASETS')
# groundtruth -- impl
-DATASETS['cars3d'] = _LazyImport('disent.dataset.data._groundtruth__cars3d')
-DATASETS['dsprites'] = _LazyImport('disent.dataset.data._groundtruth__dsprites')
-DATASETS['mpi3d'] = _LazyImport('disent.dataset.data._groundtruth__mpi3d')
-DATASETS['smallnorb'] = _LazyImport('disent.dataset.data._groundtruth__norb')
-DATASETS['shapes3d'] = _LazyImport('disent.dataset.data._groundtruth__shapes3d')
+DATASETS['cars3d_x128'] = LazyImport('disent.dataset.data._groundtruth__cars3d.Cars3dData')
+DATASETS['cars3d'] = LazyImport('disent.dataset.data._groundtruth__cars3d.Cars3d64Data')
+DATASETS['dsprites'] = LazyImport('disent.dataset.data._groundtruth__dsprites.DSpritesData')
+DATASETS['mpi3d_toy'] = LazyImport('disent.dataset.data._groundtruth__mpi3d.Mpi3dData', subset='toy')
+DATASETS['mpi3d_realistic'] = LazyImport('disent.dataset.data._groundtruth__mpi3d.Mpi3dData', subset='realistic')
+DATASETS['mpi3d_real'] = LazyImport('disent.dataset.data._groundtruth__mpi3d.Mpi3dData', subset='real')
+DATASETS['smallnorb_x96'] = LazyImport('disent.dataset.data._groundtruth__norb.SmallNorbData')
+DATASETS['smallnorb'] = LazyImport('disent.dataset.data._groundtruth__norb.SmallNorb64Data')
+DATASETS['shapes3d'] = LazyImport('disent.dataset.data._groundtruth__shapes3d.Shapes3dData')
# groundtruth -- impl synthetic
-DATASETS['xyobject'] = _LazyImport('disent.dataset.data._groundtruth__xyobject')
+DATASETS['xyobject'] = LazyImport('disent.dataset.data._groundtruth__xyobject.XYObjectData')
# ========================================================================= #
@@ -63,18 +77,18 @@
# TODO: this is not yet used in disent.data or disent.frameworks
# changes here should also update
-SAMPLERS = _Registry('SAMPLERS')
+SAMPLERS: RegistryImports['disent.dataset.sampling.BaseDisentSampler'] = RegistryImports('SAMPLERS')
# [ground truth samplers]
-SAMPLERS['gt_dist'] = _LazyImport('disent.dataset.sampling._groundtruth__dist.GroundTruthDistSampler')
-SAMPLERS['gt_pair'] = _LazyImport('disent.dataset.sampling._groundtruth__pair.GroundTruthPairSampler')
-SAMPLERS['gt_pair_orig'] = _LazyImport('disent.dataset.sampling._groundtruth__pair_orig.GroundTruthPairOrigSampler')
-SAMPLERS['gt_single'] = _LazyImport('disent.dataset.sampling._groundtruth__single.GroundTruthSingleSampler')
-SAMPLERS['gt_triple'] = _LazyImport('disent.dataset.sampling._groundtruth__triplet.GroundTruthTripleSampler')
+SAMPLERS['gt_dist'] = LazyImport('disent.dataset.sampling._groundtruth__dist.GroundTruthDistSampler')
+SAMPLERS['gt_pair'] = LazyImport('disent.dataset.sampling._groundtruth__pair.GroundTruthPairSampler')
+SAMPLERS['gt_pair_orig'] = LazyImport('disent.dataset.sampling._groundtruth__pair_orig.GroundTruthPairOrigSampler')
+SAMPLERS['gt_single'] = LazyImport('disent.dataset.sampling._groundtruth__single.GroundTruthSingleSampler')
+SAMPLERS['gt_triple'] = LazyImport('disent.dataset.sampling._groundtruth__triplet.GroundTruthTripleSampler')
# [any dataset samplers]
-SAMPLERS['single'] = _LazyImport('disent.dataset.sampling._single.SingleSampler')
-SAMPLERS['random'] = _LazyImport('disent.dataset.sampling._random__any.RandomSampler')
+SAMPLERS['single'] = LazyImport('disent.dataset.sampling._single.SingleSampler')
+SAMPLERS['random'] = LazyImport('disent.dataset.sampling._random__any.RandomSampler')
# [episode samplers]
-SAMPLERS['random_episode'] = _LazyImport('disent.dataset.sampling._random__episodes.RandomEpisodeSampler')
+SAMPLERS['random_episode'] = LazyImport('disent.dataset.sampling._random__episodes.RandomEpisodeSampler')
# ========================================================================= #
@@ -87,19 +101,19 @@
# TODO: this is not yet used in disent.frameworks
-FRAMEWORKS = _Registry('FRAMEWORKS')
+FRAMEWORKS: RegistryImports['disent.frameworks.DisentFramework'] = RegistryImports('FRAMEWORKS')
# [AE]
-FRAMEWORKS['tae'] = _LazyImport('disent.frameworks.ae._supervised__tae.TripletAe')
-FRAMEWORKS['ae'] = _LazyImport('disent.frameworks.ae._unsupervised__ae.Ae')
+FRAMEWORKS['tae'] = LazyImport('disent.frameworks.ae._supervised__tae.TripletAe')
+FRAMEWORKS['ae'] = LazyImport('disent.frameworks.ae._unsupervised__ae.Ae')
# [VAE]
-FRAMEWORKS['tvae'] = _LazyImport('disent.frameworks.vae._supervised__tvae.TripletVae')
-FRAMEWORKS['betatc_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__betatcvae.BetaTcVae')
-FRAMEWORKS['beta_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__betavae.BetaVae')
-FRAMEWORKS['dfc_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__dfcvae.DfcVae')
-FRAMEWORKS['dip_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__dipvae.DipVae')
-FRAMEWORKS['info_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__infovae.InfoVae')
-FRAMEWORKS['vae'] = _LazyImport('disent.frameworks.vae._unsupervised__vae.Vae')
-FRAMEWORKS['ada_vae'] = _LazyImport('disent.frameworks.vae._weaklysupervised__adavae.AdaVae')
+FRAMEWORKS['tvae'] = LazyImport('disent.frameworks.vae._supervised__tvae.TripletVae')
+FRAMEWORKS['betatc_vae'] = LazyImport('disent.frameworks.vae._unsupervised__betatcvae.BetaTcVae')
+FRAMEWORKS['beta_vae'] = LazyImport('disent.frameworks.vae._unsupervised__betavae.BetaVae')
+FRAMEWORKS['dfc_vae'] = LazyImport('disent.frameworks.vae._unsupervised__dfcvae.DfcVae')
+FRAMEWORKS['dip_vae'] = LazyImport('disent.frameworks.vae._unsupervised__dipvae.DipVae')
+FRAMEWORKS['info_vae'] = LazyImport('disent.frameworks.vae._unsupervised__infovae.InfoVae')
+FRAMEWORKS['vae'] = LazyImport('disent.frameworks.vae._unsupervised__vae.Vae')
+FRAMEWORKS['ada_vae'] = LazyImport('disent.frameworks.vae._weaklysupervised__adavae.AdaVae')
# ========================================================================= #
@@ -108,27 +122,30 @@
# ========================================================================= #
-RECON_LOSSES = _Registry('RECON_LOSSES')
+RECON_LOSSES: RegexRegistry['disent.frameworks.helper.reconstructions.ReconLossHandler'] = RegexRegistry('RECON_LOSSES') # TODO: we need a regex version of RegistryImports
# [STANDARD LOSSES]
-RECON_LOSSES['mse'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMse') # from the normal distribution - real values in the range [0, 1]
-RECON_LOSSES['mae'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMae') # mean absolute error
+RECON_LOSSES['mse'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMse') # from the normal distribution - real values in the range [0, 1]
+RECON_LOSSES['mae'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMae') # mean absolute error
# [STANDARD DISTRIBUTIONS]
-RECON_LOSSES['bce'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBce') # from the bernoulli distribution - binary values in the set {0, 1}
-RECON_LOSSES['bernoulli'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBernoulli') # reduces to bce - binary values in the set {0, 1}
-RECON_LOSSES['c_bernoulli'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerContinuousBernoulli') # bernoulli with a computed offset to handle values in the range [0, 1]
-RECON_LOSSES['normal'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerNormal') # handle all real values
+RECON_LOSSES['bce'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBce') # from the bernoulli distribution - binary values in the set {0, 1}
+RECON_LOSSES['bernoulli'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBernoulli') # reduces to bce - binary values in the set {0, 1}
+RECON_LOSSES['c_bernoulli'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerContinuousBernoulli') # bernoulli with a computed offset to handle values in the range [0, 1]
+RECON_LOSSES['normal'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerNormal') # handle all real values
+# [REGEX LOSSES]
+RECON_LOSSES.register_regex(pattern=r'^([a-z\d]+)_([a-z\d]+_[a-z\d]+)_l(\d+\.\d+)_k(\d+\.\d+)_norm_([a-z]+)$', example='mse_xy8_abs63_l1.0_k1.0_norm_none', factory_fn='disent.frameworks.helper.reconstructions._make_aug_recon_loss_l_w_n')
+RECON_LOSSES.register_regex(pattern=r'^([a-z\d]+)_([a-z\d]+_[a-z\d]+)_norm_([a-z]+)$', example='mse_xy8_abs63_norm_none', factory_fn='disent.frameworks.helper.reconstructions._make_aug_recon_loss_l1_w1_n')
+RECON_LOSSES.register_regex(pattern=r'^([a-z\d]+)_([a-z\d]+_[a-z\d]+)$', example='mse_xy8_abs63', factory_fn='disent.frameworks.helper.reconstructions._make_aug_recon_loss_l1_w1_nnone')
# ========================================================================= #
-# LATENT_DISTS - should be synchronized with: #
-# `disent/frameworks/helper/latent_distributions.py` #
+# LATENT_HANDLERS - should be synchronized with: #
+# `disent/frameworks/helper/latent_distributions.py` #
# ========================================================================= #
-# TODO: this is not yet used in disent.frameworks or disent.frameworks.helper.latent_distributions
-LATENT_DISTS = _Registry('LATENT_DISTS')
-LATENT_DISTS['normal'] = _LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerNormal')
-LATENT_DISTS['laplace'] = _LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerLaplace')
+LATENT_HANDLERS: RegistryImports['disent.frameworks.helper.latent_distributions.LatentDistsHandler'] = RegistryImports('LATENT_HANDLERS')
+LATENT_HANDLERS['normal'] = LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerNormal')
+LATENT_HANDLERS['laplace'] = LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerLaplace')
# ========================================================================= #
@@ -139,41 +156,40 @@
# default learning rate for each optimizer
_LR = 1e-3
-
-OPTIMIZERS = _Registry('OPTIMIZERS')
+OPTIMIZERS: RegistryImports['torch.optim.Optimizer'] = RegistryImports('OPTIMIZERS')
# [torch]
-OPTIMIZERS['adadelta'] = _LazyImport(lr=_LR, import_path='torch.optim.adadelta.Adadelta')
-OPTIMIZERS['adagrad'] = _LazyImport(lr=_LR, import_path='torch.optim.adagrad.Adagrad')
-OPTIMIZERS['adam'] = _LazyImport(lr=_LR, import_path='torch.optim.adam.Adam')
-OPTIMIZERS['adamax'] = _LazyImport(lr=_LR, import_path='torch.optim.adamax.Adamax')
-OPTIMIZERS['adam_w'] = _LazyImport(lr=_LR, import_path='torch.optim.adamw.AdamW')
-OPTIMIZERS['asgd'] = _LazyImport(lr=_LR, import_path='torch.optim.asgd.ASGD')
-OPTIMIZERS['lbfgs'] = _LazyImport(lr=_LR, import_path='torch.optim.lbfgs.LBFGS')
-OPTIMIZERS['rmsprop'] = _LazyImport(lr=_LR, import_path='torch.optim.rmsprop.RMSprop')
-OPTIMIZERS['rprop'] = _LazyImport(lr=_LR, import_path='torch.optim.rprop.Rprop')
-OPTIMIZERS['sgd'] = _LazyImport(lr=_LR, import_path='torch.optim.sgd.SGD')
-OPTIMIZERS['sparse_adam'] = _LazyImport(lr=_LR, import_path='torch.optim.sparse_adam.SparseAdam')
+OPTIMIZERS['adadelta'] = LazyImport(lr=_LR, import_path='torch.optim.adadelta.Adadelta')
+OPTIMIZERS['adagrad'] = LazyImport(lr=_LR, import_path='torch.optim.adagrad.Adagrad')
+OPTIMIZERS['adam'] = LazyImport(lr=_LR, import_path='torch.optim.adam.Adam')
+OPTIMIZERS['adamax'] = LazyImport(lr=_LR, import_path='torch.optim.adamax.Adamax')
+OPTIMIZERS['adam_w'] = LazyImport(lr=_LR, import_path='torch.optim.adamw.AdamW')
+OPTIMIZERS['asgd'] = LazyImport(lr=_LR, import_path='torch.optim.asgd.ASGD')
+OPTIMIZERS['lbfgs'] = LazyImport(lr=_LR, import_path='torch.optim.lbfgs.LBFGS')
+OPTIMIZERS['rmsprop'] = LazyImport(lr=_LR, import_path='torch.optim.rmsprop.RMSprop')
+OPTIMIZERS['rprop'] = LazyImport(lr=_LR, import_path='torch.optim.rprop.Rprop')
+OPTIMIZERS['sgd'] = LazyImport(lr=_LR, import_path='torch.optim.sgd.SGD')
+OPTIMIZERS['sparse_adam'] = LazyImport(lr=_LR, import_path='torch.optim.sparse_adam.SparseAdam')
# [torch_optimizer]
-OPTIMIZERS['acc_sgd'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AccSGD')
-OPTIMIZERS['ada_bound'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AdaBound')
-OPTIMIZERS['ada_mod'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AdaMod')
-OPTIMIZERS['adam_p'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AdamP')
-OPTIMIZERS['agg_mo'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AggMo')
-OPTIMIZERS['diff_grad'] = _LazyImport(lr=_LR, import_path='torch_optimizer.DiffGrad')
-OPTIMIZERS['lamb'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Lamb')
+OPTIMIZERS['acc_sgd'] = LazyImport(lr=_LR, import_path='torch_optimizer.AccSGD')
+OPTIMIZERS['ada_bound'] = LazyImport(lr=_LR, import_path='torch_optimizer.AdaBound')
+OPTIMIZERS['ada_mod'] = LazyImport(lr=_LR, import_path='torch_optimizer.AdaMod')
+OPTIMIZERS['adam_p'] = LazyImport(lr=_LR, import_path='torch_optimizer.AdamP')
+OPTIMIZERS['agg_mo'] = LazyImport(lr=_LR, import_path='torch_optimizer.AggMo')
+OPTIMIZERS['diff_grad'] = LazyImport(lr=_LR, import_path='torch_optimizer.DiffGrad')
+OPTIMIZERS['lamb'] = LazyImport(lr=_LR, import_path='torch_optimizer.Lamb')
# 'torch_optimizer.Lookahead' is skipped because it is wrapped
-OPTIMIZERS['novograd'] = _LazyImport(lr=_LR, import_path='torch_optimizer.NovoGrad')
-OPTIMIZERS['pid'] = _LazyImport(lr=_LR, import_path='torch_optimizer.PID')
-OPTIMIZERS['qh_adam'] = _LazyImport(lr=_LR, import_path='torch_optimizer.QHAdam')
-OPTIMIZERS['qhm'] = _LazyImport(lr=_LR, import_path='torch_optimizer.QHM')
-OPTIMIZERS['radam'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RAdam')
-OPTIMIZERS['ranger'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Ranger')
-OPTIMIZERS['ranger_qh'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RangerQH')
-OPTIMIZERS['ranger_va'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RangerVA')
-OPTIMIZERS['sgd_w'] = _LazyImport(lr=_LR, import_path='torch_optimizer.SGDW')
-OPTIMIZERS['sgd_p'] = _LazyImport(lr=_LR, import_path='torch_optimizer.SGDP')
-OPTIMIZERS['shampoo'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Shampoo')
-OPTIMIZERS['yogi'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Yogi')
+OPTIMIZERS['novograd'] = LazyImport(lr=_LR, import_path='torch_optimizer.NovoGrad')
+OPTIMIZERS['pid'] = LazyImport(lr=_LR, import_path='torch_optimizer.PID')
+OPTIMIZERS['qh_adam'] = LazyImport(lr=_LR, import_path='torch_optimizer.QHAdam')
+OPTIMIZERS['qhm'] = LazyImport(lr=_LR, import_path='torch_optimizer.QHM')
+OPTIMIZERS['radam'] = LazyImport(lr=_LR, import_path='torch_optimizer.RAdam')
+OPTIMIZERS['ranger'] = LazyImport(lr=_LR, import_path='torch_optimizer.Ranger')
+OPTIMIZERS['ranger_qh'] = LazyImport(lr=_LR, import_path='torch_optimizer.RangerQH')
+OPTIMIZERS['ranger_va'] = LazyImport(lr=_LR, import_path='torch_optimizer.RangerVA')
+OPTIMIZERS['sgd_w'] = LazyImport(lr=_LR, import_path='torch_optimizer.SGDW')
+OPTIMIZERS['sgd_p'] = LazyImport(lr=_LR, import_path='torch_optimizer.SGDP')
+OPTIMIZERS['shampoo'] = LazyImport(lr=_LR, import_path='torch_optimizer.Shampoo')
+OPTIMIZERS['yogi'] = LazyImport(lr=_LR, import_path='torch_optimizer.Yogi')
# ========================================================================= #
@@ -182,12 +198,12 @@
# TODO: this is not yet used in disent.util.lightning.callbacks or disent.metrics
-METRICS = _Registry('METRICS')
-METRICS['dci'] = _LazyImport('disent.metrics._dci.metric_dci')
-METRICS['factor_vae'] = _LazyImport('disent.metrics._factor_vae.metric_factor_vae')
-METRICS['mig'] = _LazyImport('disent.metrics._mig.metric_mig')
-METRICS['sap'] = _LazyImport('disent.metrics._sap.metric_sap')
-METRICS['unsupervised'] = _LazyImport('disent.metrics._unsupervised.metric_unsupervised')
+METRICS: RegistryImports['disent.metrics.utils._Metric'] = RegistryImports('METRICS')
+METRICS['dci'] = LazyImport('disent.metrics._dci.metric_dci')
+METRICS['factor_vae'] = LazyImport('disent.metrics._factor_vae.metric_factor_vae')
+METRICS['mig'] = LazyImport('disent.metrics._mig.metric_mig')
+METRICS['sap'] = LazyImport('disent.metrics._sap.metric_sap')
+METRICS['unsupervised'] = LazyImport('disent.metrics._unsupervised.metric_unsupervised')
# ========================================================================= #
@@ -196,12 +212,12 @@
# TODO: this is not yet used in disent.framework or disent.schedule
-SCHEDULES = _Registry('SCHEDULES')
-SCHEDULES['clip'] = _LazyImport('disent.schedule._schedule.ClipSchedule')
-SCHEDULES['cosine_wave'] = _LazyImport('disent.schedule._schedule.CosineWaveSchedule')
-SCHEDULES['cyclic'] = _LazyImport('disent.schedule._schedule.CyclicSchedule')
-SCHEDULES['linear'] = _LazyImport('disent.schedule._schedule.LinearSchedule')
-SCHEDULES['noop'] = _LazyImport('disent.schedule._schedule.NoopSchedule')
+SCHEDULES: RegistryImports['disent.schedule.Schedule'] = RegistryImports('SCHEDULES')
+SCHEDULES['clip'] = LazyImport('disent.schedule._schedule.ClipSchedule')
+SCHEDULES['cosine_wave'] = LazyImport('disent.schedule._schedule.CosineWaveSchedule')
+SCHEDULES['cyclic'] = LazyImport('disent.schedule._schedule.CyclicSchedule')
+SCHEDULES['linear'] = LazyImport('disent.schedule._schedule.LinearSchedule')
+SCHEDULES['noop'] = LazyImport('disent.schedule._schedule.NoopSchedule')
# ========================================================================= #
@@ -210,17 +226,28 @@
# TODO: this is not yet used in disent.framework or disent.model
-MODELS = _Registry('MODELS')
+MODELS: RegistryImports['disent.model._base.DisentLatentsModule'] = RegistryImports('MODELS')
# [DECODER]
-MODELS['encoder_conv64'] = _LazyImport('disent.model.ae._vae_conv64.EncoderConv64')
-MODELS['encoder_conv64norm'] = _LazyImport('disent.model.ae._norm_conv64.EncoderConv64Norm')
-MODELS['encoder_fc'] = _LazyImport('disent.model.ae._vae_fc.EncoderFC')
-MODELS['encoder_linear'] = _LazyImport('disent.model.ae._linear.EncoderLinear')
+MODELS['encoder_conv64'] = LazyImport('disent.model.ae._vae_conv64.EncoderConv64')
+MODELS['encoder_conv64norm'] = LazyImport('disent.model.ae._norm_conv64.EncoderConv64Norm')
+MODELS['encoder_fc'] = LazyImport('disent.model.ae._vae_fc.EncoderFC')
+MODELS['encoder_linear'] = LazyImport('disent.model.ae._linear.EncoderLinear')
# [ENCODER]
-MODELS['decoder_conv64'] = _LazyImport('disent.model.ae._vae_conv64.DecoderConv64')
-MODELS['decoder_conv64norm'] = _LazyImport('disent.model.ae._norm_conv64.DecoderConv64Norm')
-MODELS['decoder_fc'] = _LazyImport('disent.model.ae._vae_fc.DecoderFC')
-MODELS['decoder_linear'] = _LazyImport('disent.model.ae._linear.DecoderLinear')
+MODELS['decoder_conv64'] = LazyImport('disent.model.ae._vae_conv64.DecoderConv64')
+MODELS['decoder_conv64norm'] = LazyImport('disent.model.ae._norm_conv64.DecoderConv64Norm')
+MODELS['decoder_fc'] = LazyImport('disent.model.ae._vae_fc.DecoderFC')
+MODELS['decoder_linear'] = LazyImport('disent.model.ae._linear.DecoderLinear')
+
+
+# ========================================================================= #
+# HELPER registries #
+# ========================================================================= #
+
+
+# TODO: add norm support with regex?
+KERNELS: RegexRegistry['torch.Tensor'] = RegexRegistry('KERNELS')
+KERNELS.register_regex(pattern=r'^box_r(\d+)$', example='box_r31', factory_fn='disent.dataset.transform._augment._make_box_kernel')
+KERNELS.register_regex(pattern=r'^gau_r(\d+)$', example='gau_r31', factory_fn='disent.dataset.transform._augment._make_gaussian_kernel')
# ========================================================================= #
@@ -229,16 +256,17 @@
# registry of registries
-REGISTRIES = _Registry('REGISTRIES')
-REGISTRIES['DATASETS'] = DATASETS
-REGISTRIES['SAMPLERS'] = SAMPLERS
-REGISTRIES['FRAMEWORKS'] = FRAMEWORKS
-REGISTRIES['RECON_LOSSES'] = RECON_LOSSES
-REGISTRIES['LATENT_DISTS'] = LATENT_DISTS
-REGISTRIES['OPTIMIZERS'] = OPTIMIZERS
-REGISTRIES['METRICS'] = METRICS
-REGISTRIES['SCHEDULES'] = SCHEDULES
-REGISTRIES['MODELS'] = MODELS
+REGISTRIES: Registry[Registry] = Registry('REGISTRIES')
+REGISTRIES['DATASETS'] = StaticValue(DATASETS)
+REGISTRIES['SAMPLERS'] = StaticValue(SAMPLERS)
+REGISTRIES['FRAMEWORKS'] = StaticValue(FRAMEWORKS)
+REGISTRIES['RECON_LOSSES'] = StaticValue(RECON_LOSSES)
+REGISTRIES['LATENT_HANDLERS'] = StaticValue(LATENT_HANDLERS)
+REGISTRIES['OPTIMIZERS'] = StaticValue(OPTIMIZERS)
+REGISTRIES['METRICS'] = StaticValue(METRICS)
+REGISTRIES['SCHEDULES'] = StaticValue(SCHEDULES)
+REGISTRIES['MODELS'] = StaticValue(MODELS)
+REGISTRIES['KERNELS'] = StaticValue(KERNELS)
# ========================================================================= #
diff --git a/disent/registry/_registry.py b/disent/registry/_registry.py
index dbf4d911..05afba22 100644
--- a/disent/registry/_registry.py
+++ b/disent/registry/_registry.py
@@ -22,202 +22,663 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+import inspect
+import re
+from abc import ABC
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generic
+from typing import Iterator
+from typing import List
+from typing import MutableMapping
from typing import NoReturn
from typing import Optional
-from typing import Sequence
+from typing import Protocol
+from typing import Set
from typing import Tuple
from typing import TypeVar
+from typing import Union
from disent.util.function import wrapped_partial
-from disent.util.imports import import_obj_partial
from disent.util.imports import _check_and_split_path
+from disent.util.imports import import_obj
+from disent.util.imports import import_obj_partial
# ========================================================================= #
-# Basic Cached Item #
+# Type Hints #
# ========================================================================= #
+K = TypeVar('K')
+V = TypeVar('V')
T = TypeVar('T')
+AliasesHint = Union[str, Tuple[str, ...]]
+
+
+class _FactoryFn(Protocol[V]):
+ def __call__(self, *args) -> V: ...
+
+
+
+
+# ========================================================================= #
+# Provided Values #
+# ========================================================================= #
+
+
+class ProvidedValue(Generic[V], ABC):
+ """
+ Base class for providing immutable values using the `get` method.
+ - Subclasses should override this
+ """
+
+ def get(self) -> V:
+ raise NotImplementedError
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
+
+class StaticValue(ProvidedValue[V]):
+ """
+ Provide static values. Simply a see-through wrapper
+ around already generated / constant values.
+ """
-class LazyValue(object):
+ def __init__(self, value: V):
+ self._value = value
- def __init__(self, generate_fn: Callable[[], T]):
+ def get(self) -> V:
+ return self._value
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({repr(self._value)})'
+
+
+class StaticImport(StaticValue[V]):
+ def __init__(self, fn: V, *partial_args, **partial_kwargs):
+ super().__init__(wrapped_partial(fn, *partial_args, **partial_kwargs))
+
+
+class LazyValue(ProvidedValue[V]):
+ """
+ Use a function to provide a value by generating and caching
+ the result only when this value is first needed.
+ """
+
+ def __init__(self, generate_fn: Callable[[], V]):
assert callable(generate_fn)
- self._is_generated = False
self._generate_fn = generate_fn
+ self._is_generated = False
self._value = None
- def generate(self) -> T:
- # replace value -- we don't actually need caching of the
- # values since the registry replaces these items automatically,
- # but LazyValue is exposed so it might be used unexpectedly.
+ def get(self) -> V:
+ # cache the value
if not self._is_generated:
self._is_generated = True
self._value = self._generate_fn()
- self._generate_fn = None
- # return value
return self._value
+ def clear(self):
+ self._is_generated = False
+ self._value = None
+
def __repr__(self):
return f'{self.__class__.__name__}({repr(self._generate_fn)})'
+class LazyImport(LazyValue[V]):
+ """
+ Like lazy value, but instead takes in the import path to a callable object.
+ Any remaining args and kwargs are used to partially parameterize the object.
+ - The partial object is returned, and not called. This should be
+ the same as importing the value directly when `get` is called!
+ """
+
+ def __init__(self, import_path: str, *partial_args, **partial_kwargs):
+ # function imports the object when called
+ def generate_fn():
+ return import_obj_partial(import_path, *partial_args, **partial_kwargs)
+ # initialise the lazy value
+ super().__init__(generate_fn=generate_fn)
+
+
# ========================================================================= #
-# Import Helper #
+# Provided Dict #
# ========================================================================= #
-class LazyImport(LazyValue):
- def __init__(self, import_path: str, *partial_args, **partial_kwargs):
- super().__init__(
- generate_fn=lambda: import_obj_partial(import_path, *partial_args, **partial_kwargs),
- )
+class DictProviders(MutableMapping[K, V]):
+ """
+ A dictionary that only allows instances of
+ provided values to be added to it.
+ - The returned values are obtained directly from the providers
+ """
+
+ def __init__(self):
+ self._providers: Dict[K, ProvidedValue[V]] = {}
+
+ def __getitem__(self, k: K) -> V:
+ return self._getitem(k)
+
+ def __contains__(self, k: K):
+ return k in self._providers
+
+ def __setitem__(self, k: K, v: ProvidedValue[V]) -> NoReturn:
+ self._setitem(k, v)
+
+ def __delitem__(self, k: K) -> NoReturn:
+ del self._providers[k]
+
+ def __len__(self) -> int:
+ return len(self._providers)
+
+ def __iter__(self) -> Iterator[K]:
+ yield from self._providers
+
+ # allow easier overriding in subclasses without calling super() which can get confusing
+
+ def _getitem(self, k: K) -> V:
+ provider = self._providers[k]
+ return provider.get()
+
+ def _setitem(self, k: K, v: ProvidedValue[V]) -> NoReturn:
+ if not isinstance(v, ProvidedValue):
+ raise TypeError(f'Values stored in {self.__class__.__name__} must be instances of: {ProvidedValue.__name__}, got: {repr(v)}')
+ self._providers[k] = v
# ========================================================================= #
-# Registry #
+# Registry - Mixin #
# ========================================================================= #
-_NONE = object()
+class Registry(DictProviders[str, V]):
+ def __init__(self, name: str):
+ if not str.isidentifier(name):
+ raise ValueError(f'Registry names must be valid identifiers, got: {repr(name)}')
+ # initialise
+ self._name = name
+ super().__init__()
-class Registry(object):
+ @property
+ def static_examples(self) -> List[str]:
+ return list(self._providers.keys())
- def __init__(
- self,
- name: str,
- assert_valid_key: Optional[Callable[[str], NoReturn]] = None,
- assert_valid_value: Optional[Callable[[T], NoReturn]] = None,
- ):
- # checks!
- assert str.isidentifier(name), f'The given name for the registry is not a valid identifier: {repr(name)}'
- self._name = name
- assert (assert_valid_key is None) or callable(assert_valid_key), f'assert_valid_key must be None or callable'
- assert (assert_valid_value is None) or callable(assert_valid_value), f'assert_valid_value must be None or callable'
- self._assert_valid_key = assert_valid_key
- self._assert_valid_value = assert_valid_value
- # storage
- self._keys_to_values: Dict[str, Any] = {}
+ @property
+ def examples(self) -> List[str]:
+ return self.static_examples
@property
def name(self) -> str:
return self._name
- def _get_aliases(self, name, aliases, auto_alias: bool):
- if auto_alias:
- if name not in self:
- aliases = (name, *aliases)
- elif not aliases:
- raise RuntimeError(f'automatic alias: {repr(name)} already exists but no alternative aliases were specified.')
+ def __repr__(self):
+ return f'{self.__class__.__name__}({self._name})'
+
+ # --- CORE --- #
+
+ def __setitem__(self, aliases: AliasesHint, v: ProvidedValue[V]) -> NoReturn:
+ self._setitems(aliases, v)
+
+ def __getitem__(self, k: str) -> V:
+ value = self._getitem(k)
+ self._check_provided_value(value)
+ return value
+
+ def __delitem__(self, k: str) -> None:
+ raise RuntimeError(f'Registry: {repr(self.name)} does not support item deletion. Tried to remove key: {repr(k)}')
+
+ # --- HELPER --- #
+
+ def _setitems(self, aliases: AliasesHint, v: Union[V, ProvidedValue[V]]) -> None:
+ aliases = self._normalise_aliases(aliases)
+ # check all the aliases
+ for k in aliases:
+ if not str.isidentifier(k):
+ raise ValueError(f'Keys stored in registry: {repr(self.name)} must be valid identifiers, got: {repr(k)}')
+ if k in self:
+ raise RuntimeError(f'Tried to overwrite existing key: {repr(k)} in registry: {repr(self.name)}')
+ self._check_key(k)
+ # check the value
+ v = self._check_and_normalise_value(v)
+ # set all the aliases
+ for k in aliases:
+ self._setitem(k, v)
+
+ def _normalise_aliases(self, aliases: AliasesHint, check_nonempty: bool = True) -> Tuple[str]:
+ if isinstance(aliases, str):
+ aliases = (aliases,)
+ if not isinstance(aliases, tuple):
+ raise TypeError(f'Multiple aliases must be provided to registry: {repr(self.name)} as a Tuple[str], got: {repr(aliases)}')
+ if check_nonempty:
+ if len(aliases) < 1:
+ raise ValueError(f'At least one alias must be provided to registry: {repr(self.name)}, got: {repr(aliases)}')
return aliases
+ # --- OVERRIDABLE --- #
+
+ def _check_and_normalise_value(self, v: ProvidedValue[V]) -> ProvidedValue[V]:
+ return v
+
+ def _check_provided_value(self, v: V) -> NoReturn:
+ pass
+
+ def _check_key(self, k: str) -> NoReturn:
+ pass
+
+ # --- MISSING VALUES --- #
+
+ def setmissing(self, alias: AliasesHint, value: V) -> NoReturn:
+ # find missing keys
+ aliases = self._normalise_aliases(alias)
+ missing = tuple(alias for alias in aliases if (alias not in self))
+ # register missing keys
+ if missing:
+ self._setitems(missing, value)
+
+ @property
+ def setm(self) -> '_RegistrySetMissing':
+ # instead of checking values manually, at the cost of some efficiency,
+ # this allows us to register values multiple times with hardly modified notation!
+ # -- only modifies unset values
+ # set once: `REGISTRY['key'] = val`
+ # set default: `REGISTRY.setm['key'] = val`
+ return self._RegistrySetMissing(self)
+
+ class _RegistrySetMissing(object):
+ def __init__(self, registry: 'Registry'):
+ self._registry = registry
+
+ def __setitem__(self, aliases: str, v: ProvidedValue[V]) -> NoReturn:
+ self._registry.setmissing(aliases, v)
+
+
+# ========================================================================= #
+# Import Registry #
+# ========================================================================= #
+
+
+# TODO: merge this with the dynamic registry below?
+class RegistryImports(Registry[V]):
+ """
+ A registry for arbitrary imports.
+ -- supports decorating functions and classes
+ """
+
def register(
self,
- fn=_NONE,
- aliases: Sequence[str] = (),
+ aliases: Optional[AliasesHint] = None,
auto_alias: bool = True,
partial_args: Tuple[Any, ...] = None,
partial_kwargs: Dict[str, Any] = None,
) -> Callable[[T], T]:
+ """
+ Register a function or object to this registry.
+ - can be used as a decorator @register(...)
+ - automatically chooses an alias based on the function name
+ - specify defaults for the function with the args and kwargs
+ """
+ # default values
+ if aliases is None: aliases = ()
+ if partial_args is None: partial_args = ()
+ if partial_kwargs is None: partial_kwargs = {}
+ aliases = self._normalise_aliases(aliases, check_nonempty=False)
+
+ # add the function name as an alias if it does not already exist,
+ # then register the partially parameterised function as a static value
def _decorator(orig_fn):
- # try add the name of the object
- keys = self._get_aliases(orig_fn.__name__, aliases=aliases, auto_alias=auto_alias)
- # wrap function
- new_fn = orig_fn
- if (partial_args is not None) or (partial_kwargs is not None):
- new_fn = wrapped_partial(
- orig_fn,
- *(() if partial_args is None else partial_args),
- **({} if partial_kwargs is None else partial_kwargs),
- )
- # register the function
- self.register_value(value=new_fn, aliases=keys)
+ keys = self._append_auto_alias(self._get_fn_alias(orig_fn), aliases=aliases, auto_alias=auto_alias)
+ self[keys] = StaticImport(orig_fn, *partial_args, **partial_kwargs)
return orig_fn
- # handle case
- if fn is _NONE:
- return _decorator
- else:
- return _decorator(fn)
+ return _decorator
def register_import(
self,
import_path: str,
- aliases: Sequence[str] = (),
+ aliases: Optional[AliasesHint] = None,
auto_alias: bool = True,
*partial_args,
**partial_kwargs,
- ) -> 'Registry':
- (*_, name) = _check_and_split_path(import_path)
- return self.register_value(
- value=LazyImport(import_path=import_path, *partial_args, **partial_kwargs),
- aliases=self._get_aliases(name, aliases=aliases, auto_alias=auto_alias),
- )
-
- def register_value(self, value: T, aliases: Sequence[str]) -> 'Registry':
- # check keys
- if len(aliases) < 1:
- raise ValueError(f'aliases must be specified, got an empty sequence')
- for k in aliases:
- if not str.isidentifier(k):
- raise ValueError(f'alias is not a valid identifier: {repr(k)}')
- if k in self._keys_to_values:
- raise RuntimeError(f'registry already contains key: {repr(k)}')
- self.assert_valid_key(k)
- # handle lazy values -- defer checking if a lazy value
- if not isinstance(value, LazyValue):
- self.assert_valid_value(value)
- # register keys
- for k in aliases:
- self._keys_to_values[k] = value
- return self
+ ) -> NoReturn:
+ """
+ Register an import path and automatically obtain an alias from it.
+ - This is the same as: registry[(import_name, *aliases)] = LazyImport(import_path, *partial_args, **partial_kwargs)
+ """
+ # normalise aliases
+ if aliases is None: aliases = ()
+ aliases = self._normalise_aliases(aliases, check_nonempty=False)
+ # add object alias
+ (*_, alias) = _check_and_split_path(import_path)
+ aliases = self._append_auto_alias(alias, aliases=aliases, auto_alias=auto_alias)
+ # register the lazy import
+ self[aliases] = LazyImport(import_path=import_path, *partial_args, **partial_kwargs)
+
+ # --- ALIAS HELPER --- #
+
+ def _append_auto_alias(self, alias: Optional[str], aliases: Tuple[str, ...], auto_alias: bool):
+ if auto_alias:
+ if alias is not None:
+ if alias not in self:
+ aliases = (alias, *aliases)
+ elif not aliases:
+ raise RuntimeError(f'automatic alias: {repr(alias)} already exists for registry: {repr(self.name)} and no alternative aliases were specified.')
+ elif not aliases:
+ raise RuntimeError(f'Cannot add value to registry: {repr(self.name)}, no automatic alias was found!')
+ elif not aliases:
+ raise RuntimeError(f'Cannot add value to registry: {repr(self.name)}, no manual aliases were specified and automatic aliasing is disabled!')
+ return aliases
- def __setitem__(self, aliases: str, value: T):
- if isinstance(aliases, str):
- aliases = (aliases,)
- assert isinstance(aliases, tuple), f'multiple aliases must be provided as a Tuple[str], got: {repr(aliases)}'
- self.register_value(value=value, aliases=aliases)
-
- def __contains__(self, key: str):
- return key in self._keys_to_values
-
- def __getitem__(self, key: str):
- if key not in self._keys_to_values:
- raise KeyError(f'registry does not contain the key: {repr(key)}, valid keys include: {sorted(self._keys_to_values.keys())}')
- # get the value
- value = self._keys_to_values[key]
- # replace lazy value
- if isinstance(value, LazyValue):
- value = value.generate()
- # check value & run deferred checks
- if isinstance(value, LazyValue):
- raise RuntimeError(f'{LazyValue.__name__} should never return other lazy values.')
- self.assert_valid_value(value)
- # update the value
- self._keys_to_values[key] = value
- # return the value
- return value
+ @staticmethod
+ def _get_fn_alias(fn) -> Optional[str]:
+ if hasattr(fn, '__name__'):
+ if str.isidentifier(fn.__name__):
+ return fn.__name__
+ return None
- def __iter__(self):
- yield from self._keys_to_values.keys()
+ # --- OVERRIDDEN --- #
- def assert_valid_value(self, value: T) -> T:
- if self._assert_valid_value is not None:
- self._assert_valid_value(value)
- return value
+ def _check_and_normalise_value(self, v: ProvidedValue[V]) -> ProvidedValue[V]:
+ if not isinstance(v, (LazyImport, StaticImport)):
+ raise TypeError(f'Values stored in registry: {repr(self.name)} must be instances of: {(LazyImport.__name__, StaticImport.__name__)}, got: {repr(v)}')
+ return v
- def assert_valid_key(self, key: str) -> str:
- if self._assert_valid_key is not None:
- self._assert_valid_key(key)
- return key
+ # --- OVERRIDABLE --- #
- def __repr__(self):
- return f'{self.__class__.__name__}({repr(self._name)}, ...)'
+ def _check_provided_value(self, v: V) -> NoReturn:
+ pass
+
+ def _check_key(self, k: str) -> NoReturn:
+ pass
+
+
+# ========================================================================= #
+# Dynamic Registry #
+# ========================================================================= #
+
+# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+# Constructor #
+# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+class RegexConstructor(object):
+
+ def __init__(
+ self,
+ pattern: Union[str, re.Pattern],
+ example: str,
+ factory_fn: Union[_FactoryFn[V], str],
+ ):
+ self._pattern = self._check_pattern(pattern)
+ self._example = self._check_example(example, self._pattern)
+ # we can delay loading of the function if it is a string!
+ self._factory_fn = factory_fn if isinstance(factory_fn, str) else self._check_factory_fn(factory_fn, self._pattern)
+
+ @classmethod
+ def _check_pattern(cls, pattern: Union[str, re.Pattern]):
+ # check the regex type & convert
+ if isinstance(pattern, str):
+ pattern = re.compile(pattern)
+ if not isinstance(pattern, re.Pattern):
+ raise TypeError(f'regex pattern must be a regex `str` or `re.Pattern`, got: {repr(pattern)}')
+ if pattern.groups < 1:
+ raise ValueError(f'regex pattern must contain at least one group, got: {repr(pattern)}')
+ return pattern
+
+ @classmethod
+ def _check_factory_fn(cls, factory_fn: _FactoryFn[V], pattern: re.Pattern) -> _FactoryFn[V]:
+ # we have an actual function, we can check it!
+ if not callable(factory_fn):
+ raise TypeError(f'generator function must be callable, got: {factory_fn}')
+ signature = inspect.signature(factory_fn)
+ if len(signature.parameters) != pattern.groups:
+ raise ValueError(f'signature has incorrect number of parameters: {repr(signature)} compared to the number of groups in the regex pattern: {repr(pattern)}')
+ return factory_fn
+
+ @classmethod
+ def _check_example(cls, example: str, pattern: re.Pattern) -> str:
+ # check the example
+ if not isinstance(example, str):
+ raise TypeError(f'example must be a `str`, got: {type(example)}')
+ if not example:
+ raise ValueError(f'example must not be empty, got: {type(example)}')
+ # check that the regex matches the example!
+ if pattern.search(example) is None:
+ raise ValueError(f'could not match example: {repr(example)} to regex: {repr(pattern)}')
+ return example
+
+ @property
+ def pattern(self) -> re.Pattern:
+ return self._pattern
+
+ @property
+ def example(self) -> str:
+ return self._example
+
+ def construct(self, name: str) -> V:
+ # get the results
+ result = self._pattern.search(name)
+ if result is None:
+ raise KeyError(f'pattern: {self.pattern} does not match given name: {repr(name)}. The following example would be valid: {repr(self.example)}')
+ # get the function -- load via the path
+ if isinstance(self._factory_fn, str):
+ fn = import_obj(self._factory_fn)
+ self._factory_fn = self._check_factory_fn(fn, self._pattern)
+ # construct
+ return self._factory_fn(*result.groups())
+
+ def can_construct(self, name: str) -> bool:
+ return self._pattern.search(name) is not None
+
+
+# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+# Cached Linear Search #
+# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+class RegexProvidersSearch(object):
+
+ def __init__(self):
+ self._patterns: Set[re.Pattern] = set()
+ self._constructors: List[RegexConstructor] = []
+ # caching
+ self._cache = {}
+ self._cache_dirty = False
+
+ @property
+ def regex_constructors(self) -> List[RegexConstructor]:
+ return list(self._constructors)
+
+ def construct(self, arg_str: str):
+ provider = self.get_constructor(arg_str)
+ # build the object
+ if provider is not None:
+ return provider.construct(arg_str)
+ # no result was found!
+ raise KeyError(f'could not construct an item from the given argument string: {repr(arg_str)}, valid patterns include: {[p.pattern for p in self._constructors]}')
+
+ def can_construct(self, arg_str: str) -> bool:
+ return self.get_constructor(arg_str) is not None
+
+ def get_constructor(self, arg_str: str) -> Optional[RegexConstructor]:
+ # TODO: clean up this cache!
+ # check cache -- remove None entries if dirty
+ if self._cache_dirty:
+ self._cache = {k: v for k, v in self._cache.items() if v is not None}
+ self._cache_dirty = False
+ if arg_str in self._cache:
+ return self._cache[arg_str]
+ # check the input string
+ if not isinstance(arg_str, str):
+ raise TypeError(f'regex factory can only construct from `str`, got: {repr(arg_str)}')
+ if not arg_str:
+ raise ValueError(f'regex factory can only construct from non-empty `str`, got: {repr(arg_str)}')
+ # match the values
+ constructor = None
+ for c in self:
+ if c.can_construct(arg_str):
+ constructor = c
+ break
+ # cache the value
+ self._cache[arg_str] = constructor
+ if len(self._cache) > 128:
+ self._cache.popitem()
+ return constructor
+
+ def has_pattern(self, pattern: Union[str, re.Pattern]) -> bool:
+ if isinstance(pattern, str):
+ pattern = re.compile(pattern)
+ return pattern in self._patterns
+
+ def __len__(self) -> int:
+ return len(self._constructors)
+
+ def __iter__(self) -> Iterator[RegexConstructor]:
+ yield from self._constructors
+
+ def append(self, constructor: RegexConstructor):
+ if not isinstance(constructor, RegexConstructor):
+ raise TypeError(f'regex factory only accepts {RegexConstructor.__name__} providers.')
+ if constructor.pattern in self._patterns:
+ raise RuntimeError(f'regex factory already contains the regex pattern: {repr(constructor.pattern)}')
+ # append value!
+ self._patterns.add(constructor.pattern)
+ self._constructors.append(constructor)
+ self._cache_dirty = True
+
+
+# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+# Dynamic Registry #
+# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+class RegexRegistry(Registry[V]):
+
+ """
+ Registry that allows registering of regex expressions that can be used to
+ construct values if there is no static value found!
+ - Regular expressions are checked in the order they are registered.
+ - `name in registry` checks if any of the expression matches, it does not check for an existing regex
+ - `len(registry)` returns the number of examples available, each item & regex factory
+ - `for example in registry` returns the examples available, each item & regex factory should be called if we use these to construct values `registry[example]`
+
+ To check for an already added regex expression, use:
+ - `has_regex(expr)`
+ """
+
+ def __init__(self, name: str):
+ self._regex_providers = RegexProvidersSearch()
+ super().__init__(name)
+
+ # --- CORE ... UPDATED WITH LINEAR SEARCH --- #
+
+ @property
+ def regex_constructors(self) -> List[RegexConstructor]:
+ return self._regex_providers.regex_constructors
+
+ @property
+ def regex_examples(self) -> List[str]:
+ return [constructor.example for constructor in self._regex_providers.regex_constructors]
+
+ @property
+ def examples(self) -> List[str]:
+ return [*self.static_examples, *self.regex_examples]
+
+ def __getitem__(self, k: str) -> V:
+ assert isinstance(k, str), f'invalid key: {repr(k)}, must be a `str`'
+ # the regex provider is cached so this should be efficient for the same value calls
+ # -- we do not cache the actual provided value!
+ if k in self._providers:
+ return self._getitem(k)
+ elif self._regex_providers.can_construct(k):
+ return self._regex_providers.construct(k)
+ raise KeyError(f'dynamic registry: {repr(self.name)} cannot construct item with key: {repr(k)}. Valid static values: {sorted(self._providers.keys())}. Valid dynamic examples: {[p.example for p in self._regex_providers]}')
+
+ def __setitem__(self, aliases: AliasesHint, v: ProvidedValue[V]) -> NoReturn:
+ if isinstance(aliases, re.Pattern) or isinstance(v, RegexConstructor):
+ raise RuntimeError(f'register dynamic values to the dynamic registry: {repr(self.name)} with the `register_regex` or `register_constructor` methods.')
+ super().__setitem__(aliases, v)
+
+ def __contains__(self, k: K):
+ if k in self._providers:
+ return True
+ if self._regex_providers.can_construct(k):
+ return True
+ return False
+
+ def __len__(self) -> int:
+ return len(self._providers) + len(self._regex_providers)
+
+ def __iter__(self) -> Iterator[K]:
+ yield from self._providers
+ yield from (p.example for p in self._regex_providers)
+
+ # --- OVERRIDABLE --- #
+
+ def _check_regex_constructor(self, constructor: RegexConstructor):
+ pass
+
+ # --- DYNAMIC VALUES --- #
+
+ def has_regex(self, pattern: Union[str, re.Pattern]) -> bool:
+ return self._regex_providers.has_pattern(pattern)
+
+ def register_constructor(self, constructor: RegexConstructor) -> 'RegexRegistry':
+ """
+ Register a regex constructor
+ """
+ if not isinstance(constructor, RegexConstructor):
+ raise TypeError(f'dynamic registry: {repr(self.name)} only accepts dynamic {RegexConstructor.__name__}, got: {repr(constructor)}')
+ self._check_regex_constructor(constructor)
+ self._regex_providers.append(constructor)
+ return self
+
+ def register_regex(self, pattern: Union[str, re.Pattern], example: str, factory_fn: Optional[Union[_FactoryFn[V], str]] = None):
+ """
+ Register and create a regex constructor
+ """
+ def _register_wrapper(fn: T) -> T:
+ self.register_constructor(RegexConstructor(pattern=pattern, example=example, factory_fn=fn))
+ return fn
+ return _register_wrapper if (factory_fn is None) else _register_wrapper(factory_fn)
+
+ def register_missing_constructor(self, constructor: RegexConstructor):
+ """
+ Only register a regex constructor if the pattern does not already exist!
+ """
+ if not self.has_regex(constructor.pattern):
+ return self.register_constructor(constructor)
+
+ def register_missing_regex(self, pattern: Union[str, re.Pattern], example: str, factory_fn: Optional[Union[_FactoryFn[V], str]] = None):
+ """
+ Only register and create a regex constructor if the pattern does not already exist!
+ """
+ if not self.has_regex(pattern):
+ return self.register_regex(pattern=pattern, example=example, factory_fn=factory_fn)
+ elif factory_fn is None:
+ return lambda fn: fn # dummy wrapper
+
+ # --- MISSING VALUES --- #
+
+ # override from the parent class!
+ class _RegistrySetMissing(Registry._RegistrySetMissing):
+
+ _registry: 'RegexRegistry'
+
+ def register_constructor(self, constructor: RegexConstructor):
+ """
+ Only register a regex constructor if the pattern does not already exist!
+ """
+ return self._registry.register_missing_constructor(constructor=constructor)
+
+ def register_regex(self, pattern: Union[str, re.Pattern], example: str, factory_fn: Optional[Union[_FactoryFn[V], str]] = None):
+ """
+ Only register and create a regex constructor if the pattern does not already exist!
+ """
+ return self._registry.register_missing_regex(pattern=pattern, example=example, factory_fn=factory_fn, )
# ========================================================================= #
diff --git a/disent/schedule/__init__.py b/disent/schedule/__init__.py
index c9ea8030..f20d4842 100644
--- a/disent/schedule/__init__.py
+++ b/disent/schedule/__init__.py
@@ -30,12 +30,17 @@
from ._schedule import CyclicSchedule
from ._schedule import LinearSchedule
from ._schedule import NoopSchedule
+from ._schedule import MultiplySchedule
+from ._schedule import FixedValueSchedule
from ._schedule import SingleSchedule
+
# aliases
from ._schedule import ClipSchedule as Clip
from ._schedule import CosineWaveSchedule as CosineWave
from ._schedule import CyclicSchedule as Cyclic
from ._schedule import LinearSchedule as Linear
from ._schedule import NoopSchedule as Noop
+from ._schedule import MultiplySchedule as Multiply
+from ._schedule import FixedValueSchedule as FixedValue
from ._schedule import SingleSchedule as Single
diff --git a/disent/schedule/_schedule.py b/disent/schedule/_schedule.py
index 41db86d5..13632676 100644
--- a/disent/schedule/_schedule.py
+++ b/disent/schedule/_schedule.py
@@ -58,6 +58,52 @@ def compute_value(self, step: int, value):
return value
+class MultiplySchedule(Schedule):
+ """
+ A schedule that always applies a constant multiplier/ratio to the input value
+ """
+
+ # This schedule will always return a constant value!
+
+ def __init__(self, r: float = 1.0):
+ """
+ :param r: The constant ratio of the original value that the schedule will use
+ """
+ self.r = r
+
+ def compute_value(self, step: int, value):
+ # we always return a constant value!
+ return value * self.r
+
+
+class FixedValueSchedule(Schedule):
+ """
+ Set a new constant value, instead of using the value passed to `compute_value`.
+ - We can override config values using this class
+ """
+
+ def __init__(
+ self,
+ value: float,
+ schedule: Optional[Schedule] = None,
+ ):
+ """
+ :param schedule: The wrapped schedule that is passed the new constant value
+ :param value: The value that should be used to replace the original values from the config.
+ If `compute_value` is called, the value passed to the function is replaced with this one!
+ """
+ assert (schedule is None) or isinstance(schedule, Schedule)
+ self.schedule = schedule
+ self.value = value
+
+ def compute_value(self, step: int, value):
+ del value
+ # we override the passed value, and pass in a constant value instead!
+ if self.schedule is None:
+ return self.value
+ else:
+ return self.schedule(step, self.value)
+
# ========================================================================= #
# Value Schedules #
# ========================================================================= #
diff --git a/disent/util/deprecate.py b/disent/util/deprecate.py
index 96fc654c..75b9a8a8 100644
--- a/disent/util/deprecate.py
+++ b/disent/util/deprecate.py
@@ -69,7 +69,7 @@ def _get_stack_file_strings():
DEFAULT_TRACEBACK_MODE = 'first'
-def deprecated(msg: str, traceback_mode: Optional[str] = None):
+def deprecated(msg: str, traceback_mode: Optional[str] = None, fn=None):
"""
Mark a function or class as deprecated, and print a warning the
first time it is used.
@@ -114,7 +114,12 @@ def _caller(*args, **kwargs):
else:
fn = _caller
return fn
- return _decorator
+
+ # handle function used as decorator, or called directly
+ if fn is not None:
+ return _decorator(fn)
+ else:
+ return _decorator
# ========================================================================= #
diff --git a/disent/util/inout/paths.py b/disent/util/inout/paths.py
index 1ad76ded..7f6ea5ac 100644
--- a/disent/util/inout/paths.py
+++ b/disent/util/inout/paths.py
@@ -42,10 +42,33 @@ def modify_file_name(file: Union[str, Path], prefix: str = None, suffix: str = N
path = Path(file)
assert path.name, f'file name cannot be empty: {repr(path)}, for name: {repr(path.name)}'
# create new path
- prefix = '' if (prefix is None) else f'{prefix}{sep}'
- suffix = '' if (suffix is None) else f'{sep}{suffix}'
+ prefix = '' if (not prefix) else f'{prefix}{sep}'
+ suffix = '' if (not suffix) else f'{sep}{suffix}'
new_path = path.parent.joinpath(f'{prefix}{path.name}{suffix}')
- # return path
+ # return path with same format as input
+ return str(new_path) if isinstance(file, str) else new_path
+
+
+def modify_name_keep_ext(file: Union[str, Path], prefix: str = None, suffix: str = None, name_contains_sep: bool = False):
+ # get path components
+ path = Path(file)
+ name = path.name
+ assert name, f'file name cannot be empty: {repr(path)}, for name: {repr(name)}'
+ # handle suffix
+ if suffix:
+ components = name.rsplit('.', 1) if name_contains_sep else path.name.split('.', 1)
+ if len(components) >= 2:
+ [name, ext] = components
+ name = f'{name}{suffix}.{ext}'
+ else:
+ [name] = components
+ name = f'{name}{suffix}'
+ # handle prefix
+ if prefix:
+ name = f'{prefix}{name}'
+ # create new path
+ new_path = path.parent.joinpath(name)
+ # return path with same format as input
return str(new_path) if isinstance(file, str) else new_path
diff --git a/disent/util/jit.py b/disent/util/jit.py
new file mode 100644
index 00000000..38083628
--- /dev/null
+++ b/disent/util/jit.py
@@ -0,0 +1,53 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+
+# ========================================================================= #
+# Numba Is An Optional Dependency #
+# ========================================================================= #
+
+
+def try_njit(*args, **kwargs):
+ """
+ Wrapper around numba.njit
+ - If numba is installed, then we JIT the decorated function
+ - If numba is missing, then we do nothing and leave the function untouched!
+ """
+ try:
+ from numba import njit
+ except ImportError:
+ # dummy njit
+ def njit(*args, **kwargs):
+ def _wrapper(func):
+ import warnings
+ warnings.warn(f'failed to JIT compile: {func}, numba is not installed!')
+ return func
+ return _wrapper
+ # try and JIT compile function!
+ return njit(*args, **kwargs)
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/util/lightning/callbacks/__init__.py b/disent/util/lightning/callbacks/__init__.py
index 60b1dff4..ec300e21 100644
--- a/disent/util/lightning/callbacks/__init__.py
+++ b/disent/util/lightning/callbacks/__init__.py
@@ -25,8 +25,7 @@
from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic
from disent.util.lightning.callbacks._callbacks_base import BaseCallbackTimed
-from disent.util.lightning.callbacks._callbacks_pl import LoggerProgressCallback
-
-from disent.util.lightning.callbacks._callbacks_vae import VaeMetricLoggingCallback
-from disent.util.lightning.callbacks._callbacks_vae import VaeLatentCycleLoggingCallback
-from disent.util.lightning.callbacks._callbacks_vae import VaeGtDistsLoggingCallback
+from disent.util.lightning.callbacks._callback_print_progress import LoggerProgressCallback
+from disent.util.lightning.callbacks._callback_log_metrics import VaeMetricLoggingCallback
+from disent.util.lightning.callbacks._callback_vis_latents import VaeLatentCycleLoggingCallback
+from disent.util.lightning.callbacks._callback_vis_dists import VaeGtDistsLoggingCallback
diff --git a/disent/util/lightning/callbacks/_callback_log_metrics.py b/disent/util/lightning/callbacks/_callback_log_metrics.py
new file mode 100644
index 00000000..b2fafead
--- /dev/null
+++ b/disent/util/lightning/callbacks/_callback_log_metrics.py
@@ -0,0 +1,129 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import logging
+import warnings
+from typing import Optional
+from typing import Sequence
+
+import pytorch_lightning as pl
+
+from disent import registry as R
+from disent.dataset.data import GroundTruthData
+from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic
+from disent.util.lightning.callbacks._helper import _get_dataset_and_ae_like
+from disent.util.lightning.logger_util import log_metrics
+from disent.util.lightning.logger_util import wb_log_reduced_summaries
+from disent.util.profiling import Timer
+from disent.util.strings import colors as c
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# Helper #
+# ========================================================================= #
+
+
+def _normalized_numeric_metrics(items: dict):
+ results = {}
+ for k, v in items.items():
+ if isinstance(v, (float, int)):
+ results[k] = v
+ else:
+ try:
+ results[k] = float(v)
+ except:
+ log.warning(f'SKIPPED: metric with key: {repr(k)}, result has invalid type: {type(v)} with value: {repr(v)}')
+ return results
+
+
+# ========================================================================= #
+# Metrics Callback #
+# ========================================================================= #
+
+
+class VaeMetricLoggingCallback(BaseCallbackPeriodic):
+
+ def __init__(
+ self,
+ step_end_metrics: Optional[Sequence[str]] = None,
+ train_end_metrics: Optional[Sequence[str]] = None,
+ every_n_steps: Optional[int] = None,
+ begin_first_step: bool = False,
+ ):
+ super().__init__(every_n_steps, begin_first_step)
+ self.step_end_metrics = step_end_metrics if step_end_metrics else []
+ self.train_end_metrics = train_end_metrics if train_end_metrics else []
+ assert isinstance(self.step_end_metrics, list)
+ assert isinstance(self.train_end_metrics, list)
+ assert self.step_end_metrics or self.train_end_metrics, 'No metrics given to step_end_metrics or train_end_metrics'
+
+ def _compute_metrics_and_log(self, trainer: pl.Trainer, pl_module: pl.LightningModule, metrics: list, is_final=False):
+ # get dataset and vae framework from trainer and module
+ dataset, vae = _get_dataset_and_ae_like(trainer, pl_module, unwrap_groundtruth=True)
+ # check if we need to skip
+ # TODO: dataset needs to be able to handle wrapped datasets!
+ if not dataset.is_ground_truth:
+ warnings.warn(f'{dataset.__class__.__name__} is not an instance of {GroundTruthData.__name__}. Skipping callback: {self.__class__.__name__}!')
+ return
+ # get padding amount
+ pad = max(7+len(k) for k in R.METRICS) # I know this is a magic variable... im just OCD
+ # compute all metrics
+ for metric in metrics:
+ if is_final:
+ log.info(f'| {metric.__name__:<{pad}} - computing...')
+ with Timer() as timer:
+ scores = metric(dataset, lambda x: vae.encode(x.to(vae.device)))
+ metric_results = ' '.join(f'{k}{c.GRY}={c.lMGT}{v:.3f}{c.RST}' for k, v in scores.items())
+ log.info(f'| {metric.__name__:<{pad}} - time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST} - {metric_results}')
+
+ # log to trainer
+ prefix = 'final_metric' if is_final else 'epoch_metric'
+ prefixed_scores = {f'{prefix}/{k}': v for k, v in scores.items()}
+ log_metrics(trainer.logger, _normalized_numeric_metrics(prefixed_scores))
+
+ # log summary for WANDB
+ # this is kinda hacky... the above should work for parallel coordinate plots
+ wb_log_reduced_summaries(trainer.logger, prefixed_scores, reduction='max')
+
+ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+ if self.step_end_metrics:
+ log.debug('Computing Epoch Metrics:')
+ with Timer() as timer:
+ self._compute_metrics_and_log(trainer, pl_module, metrics=self.step_end_metrics, is_final=False)
+ log.debug(f'Computed Epoch Metrics! {timer.pretty}')
+
+ def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+ if self.train_end_metrics:
+ log.debug('Computing Final Metrics...')
+ with Timer() as timer:
+ self._compute_metrics_and_log(trainer, pl_module, metrics=self.train_end_metrics, is_final=True)
+ log.debug(f'Computed Final Metrics! {timer.pretty}')
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/util/lightning/callbacks/_callbacks_pl.py b/disent/util/lightning/callbacks/_callback_print_progress.py
similarity index 78%
rename from disent/util/lightning/callbacks/_callbacks_pl.py
rename to disent/util/lightning/callbacks/_callback_print_progress.py
index 5e8086aa..f2ee345e 100644
--- a/disent/util/lightning/callbacks/_callbacks_pl.py
+++ b/disent/util/lightning/callbacks/_callback_print_progress.py
@@ -23,7 +23,6 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
import logging
-import warnings
import pytorch_lightning as pl
@@ -39,7 +38,11 @@
class LoggerProgressCallback(BaseCallbackTimed):
-
+
+ def __init__(self, interval: float = 10, log_level: int = logging.INFO):
+ super().__init__(interval=interval)
+ self._log_level = log_level
+
def do_interval(self, trainer: pl.Trainer, pl_module: pl.LightningModule, current_time, start_time):
# get missing vars
trainer_max_epochs = trainer.max_epochs if (trainer.max_epochs is not None) else float('inf')
@@ -50,37 +53,43 @@ def do_interval(self, trainer: pl.Trainer, pl_module: pl.LightningModule, curren
max_epochs = min(trainer_max_epochs, (trainer_max_steps + max_batches - 1) // max_batches)
max_steps = min(trainer_max_epochs * max_batches, trainer_max_steps)
elapsed_sec = current_time - start_time
+
# get vars
global_step = trainer.global_step + 1
epoch = trainer.current_epoch + 1
if hasattr(trainer, 'batch_idx'):
batch = (trainer.batch_idx + 1)
else:
- # TODO: re-enable this warning but only ever print once!
- # warnings.warn('batch_idx missing on pl.Trainer')
+ # warnings.warn('batch_idx missing on pl.Trainer') # TODO: re-enable this warning but only ever print once!
batch = global_step % max_batches # might not be int?
+
# completion
train_pct = global_step / max_steps
train_remain_time = elapsed_sec * (1 - train_pct) / train_pct # seconds
+
# get speed -- TODO: make this a moving average?
if global_step >= elapsed_sec:
step_speed_str = f'{global_step / elapsed_sec:4.2f}it/s'
else:
step_speed_str = f'{elapsed_sec / global_step:4.2f}s/it'
- # info dict
+
+ # get the metrics and format them
info_dict = {
k: f'{v:.4g}' if isinstance(v, (int, float)) else f'{v}'
- for k, v in trainer.progress_bar_dict.items()
- if k != 'v_num'
+ for k, v in trainer.progress_bar_metrics.items()
}
+
+ # sort the keys placing loss entries first
sorted_k = sorted(info_dict.keys(), key=lambda k: ('loss' != k.lower(), 'loss' not in k.lower(), k))
- # log
- log.info(
- f'[{int(elapsed_sec)}s, {step_speed_str}] '
- + f'EPOCH: {epoch}/{max_epochs} - {int(global_step):0{len(str(max_steps))}d}/{max_steps} '
- + f'({int(train_pct * 100):02d}%) [rem. {int(train_remain_time)}s] '
- + f'STEP: {int(batch):{len(str(max_batches))}d}/{max_batches} ({int(batch / max_batches * 100):02d}%) '
- + f'| {" ".join(f"{k}={info_dict[k]}" for k in sorted_k)}'
+
+ # log everything
+ log.log(
+ level=self._log_level,
+ msg=f'[{int(elapsed_sec)}s, {step_speed_str}] '
+ + f'EPOCH: {epoch}/{max_epochs} - {int(global_step):0{len(str(max_steps))}d}/{max_steps} '
+ + f'({int(train_pct * 100):02d}%) [rem. {int(train_remain_time)}s] '
+ + f'STEP: {int(batch):{len(str(max_batches))}d}/{max_batches} ({int(batch / max_batches * 100):02d}%) '
+ + f'| {" ".join(f"{k}={info_dict[k]}" for k in sorted_k)}'
)
diff --git a/disent/util/lightning/callbacks/_callback_vis_dists.py b/disent/util/lightning/callbacks/_callback_vis_dists.py
new file mode 100644
index 00000000..aa33646f
--- /dev/null
+++ b/disent/util/lightning/callbacks/_callback_vis_dists.py
@@ -0,0 +1,329 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2021 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import logging
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+
+# TODO: wandb and matplotlib are not in requirements
+import matplotlib.pyplot as plt
+import wandb
+
+import disent.util.strings.colors as c
+from disent.dataset import DisentDataset
+from disent.dataset.data import GroundTruthData
+from disent.frameworks.ae import Ae
+from disent.frameworks.vae import Vae
+from disent.util.function import wrapped_partial
+from disent.util.iters import chunked
+from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic
+from disent.util.lightning.callbacks._helper import _get_dataset_and_ae_like
+from disent.util.lightning.logger_util import wb_log_metrics
+from disent.util.profiling import Timer
+from disent.util.seeds import TempNumpySeed
+from disent.util.visualize.plot import plt_subplots_imshow
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# Helper Functions #
+# ========================================================================= #
+
+
+# helper
+def _to_dmat(
+ size: int,
+ i_a: np.ndarray,
+ i_b: np.ndarray,
+ dists: Union[torch.Tensor, np.ndarray],
+) -> np.ndarray:
+ if isinstance(dists, torch.Tensor):
+ dists = dists.detach().cpu().numpy()
+ # checks
+ assert i_a.ndim == 1
+ assert i_a.shape == i_b.shape
+ assert i_a.shape == dists.shape
+ # compute
+ dmat = np.zeros([size, size], dtype='float32')
+ dmat[i_a, i_b] = dists
+ dmat[i_b, i_a] = dists
+ return dmat
+
+
+_AE_DIST_NAMES = ('x', 'z', 'x_recon')
+_VAE_DIST_NAMES = ('x', 'z', 'kl', 'x_recon')
+
+
+@torch.no_grad()
+def _get_dists_ae(ae: Ae, x_a: torch.Tensor, x_b: torch.Tensor):
+ # feed forware
+ z_a, z_b = ae.encode(x_a), ae.encode(x_b)
+ r_a, r_b = ae.decode(z_a), ae.decode(z_b)
+ # distances
+ return [
+ ae.recon_handler.compute_pairwise_loss(x_a, x_b),
+ torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist
+ ae.recon_handler.compute_pairwise_loss(r_a, r_b),
+ ]
+
+
+@torch.no_grad()
+def _get_dists_vae(vae: Vae, x_a: torch.Tensor, x_b: torch.Tensor):
+ from torch.distributions import kl_divergence
+ # feed forward
+ (z_post_a, z_prior_a), (z_post_b, z_prior_b) = vae.encode_dists(x_a), vae.encode_dists(x_b)
+ z_a, z_b = z_post_a.mean, z_post_b.mean
+ r_a, r_b = vae.decode(z_a), vae.decode(z_b)
+ # dists
+ kl_ab = 0.5 * kl_divergence(z_post_a, z_post_b) + 0.5 * kl_divergence(z_post_b, z_post_a)
+ # distances
+ return [
+ vae.recon_handler.compute_pairwise_loss(x_a, x_b),
+ torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist
+ vae.recon_handler._pairwise_reduce(kl_ab),
+ vae.recon_handler.compute_pairwise_loss(r_a, r_b),
+ ]
+
+
+def _get_dists_fn(model: Ae) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], Sequence[Sequence[float]]]]]:
+ # get aggregate function
+ if isinstance(model, Vae):
+ dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model)
+ elif isinstance(model, Ae):
+ dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model)
+ else:
+ dists_names, dists_fn = None, None
+ return dists_names, dists_fn
+
+
+@torch.no_grad()
+def _collect_dists_subbatches(dists_fn: Callable[[object, object], Sequence[Sequence[float]]], batch: torch.Tensor, i_a: np.ndarray, i_b: np.ndarray, batch_size: int = 64):
+ # feed forward
+ results = []
+ for idxs in chunked(np.stack([i_a, i_b], axis=-1), chunk_size=batch_size):
+ ia, ib = idxs.T
+ x_a, x_b = batch[ia], batch[ib]
+ # feed forward
+ data = dists_fn(x_a, x_b)
+ results.append(data)
+ return [torch.cat(r, dim=0) for r in zip(*results)]
+
+
+def _compute_and_collect_dists(
+ dataset: DisentDataset,
+ dists_fn,
+ dists_names: Sequence[str],
+ traversal_repeats: int = 100,
+ batch_size: int = 32,
+ include_gt_factor_dists: bool = True,
+ transform_batch: Callable[[object], object] = None,
+ data_mode: str = 'input',
+) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]:
+ assert traversal_repeats > 0
+ gt_data = dataset.gt_data
+ # generate
+ f_grid = []
+ # generate
+ for f_idx, f_size in enumerate(gt_data.factor_sizes):
+ # save for the current factor (traversal_repeats, len(names), len(i_a))
+ f_dists = []
+ # upper triangle excluding diagonal
+ i_a, i_b = np.triu_indices(f_size, k=1)
+ # repeat over random traversals
+ for i in range(traversal_repeats):
+ # get random factor traversal
+ factors = gt_data.sample_random_factor_traversal(f_idx=f_idx)
+ indices = gt_data.pos_to_idx(factors)
+ # load data
+ batch = dataset.dataset_batch_from_indices(indices, data_mode)
+ if transform_batch is not None:
+ batch = transform_batch(batch)
+ # feed forward & compute dists -- (len(names), len(i_a))
+ dists = _collect_dists_subbatches(dists_fn=dists_fn, batch=batch, i_a=i_a, i_b=i_b, batch_size=batch_size)
+ assert len(dists) == len(dists_names)
+ # distances
+ f_dists.append(dists)
+ # aggregate all dists into distances matrices for current factor
+ f_dmats = [
+ _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=torch.stack(dists, dim=0).mean(dim=0))
+ for dists in zip(*f_dists)
+ ]
+ # handle factors
+ if include_gt_factor_dists:
+ i_dmat = _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=np.abs(factors[i_a] - factors[i_b]).sum(axis=-1))
+ f_dmats = [i_dmat, *f_dmats]
+ # append data
+ f_grid.append(f_dmats)
+ # handle factors
+ if include_gt_factor_dists:
+ dists_names = ('factors', *dists_names)
+ # done
+ return tuple(dists_names), f_grid
+
+
+def compute_factor_distances(
+ dataset: DisentDataset,
+ dists_fn,
+ dists_names: Sequence[str],
+ traversal_repeats: int = 100,
+ batch_size: int = 32,
+ include_gt_factor_dists: bool = True,
+ transform_batch: Callable[[object], object] = None,
+ seed: Optional[int] = 777,
+ data_mode: str = 'input',
+) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]:
+ # log this callback
+ gt_data = dataset.gt_data
+ log.info(f'| {gt_data.name} - computing factor distances...')
+ # compute various distances matrices for each factor
+ with Timer() as timer, TempNumpySeed(seed):
+ dists_names, f_grid = _compute_and_collect_dists(
+ dataset=dataset,
+ dists_fn=dists_fn,
+ dists_names=dists_names,
+ traversal_repeats=traversal_repeats,
+ batch_size=batch_size,
+ include_gt_factor_dists=include_gt_factor_dists,
+ transform_batch=transform_batch,
+ data_mode=data_mode,
+ )
+ # log this callback!
+ log.info(f'| {gt_data.name} - computed factor distances! time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST}')
+ return dists_names, f_grid
+
+
+def plt_factor_distances(
+ gt_data: GroundTruthData,
+ f_grid: List[List[np.ndarray]],
+ dists_names: Sequence[str],
+ title: str,
+ plt_block_size: float = 1.25,
+ plt_transpose: bool = False,
+ plt_cmap='Blues',
+):
+ # plot information
+ imshow_kwargs = dict(cmap=plt_cmap)
+ figsize = (plt_block_size*len(f_grid[0]), plt_block_size * gt_data.num_factors)
+ # plot!
+ if not plt_transpose:
+ fig, axs = plt_subplots_imshow(grid=f_grid, col_labels=dists_names, row_labels=gt_data.factor_names, figsize=figsize, title=title, imshow_kwargs=imshow_kwargs)
+ else:
+ fig, axs = plt_subplots_imshow(grid=list(zip(*f_grid)), col_labels=gt_data.factor_names, row_labels=dists_names, figsize=figsize[::-1], title=title, imshow_kwargs=imshow_kwargs)
+ # done
+ return fig, axs
+
+
+# ========================================================================= #
+# Data Dists Visualisation Callback #
+# ========================================================================= #
+
+
+class VaeGtDistsLoggingCallback(BaseCallbackPeriodic):
+
+ def __init__(
+ self,
+ seed: Optional[int] = 7777,
+ every_n_steps: Optional[int] = None,
+ traversal_repeats: int = 100,
+ begin_first_step: bool = False,
+ plt_block_size: float = 1.25,
+ plt_show: bool = False,
+ plt_transpose: bool = False,
+ log_wandb: bool = True, # TODO: detect this automatically?
+ batch_size: int = 128,
+ include_factor_dists: bool = True,
+ ):
+ assert traversal_repeats > 0
+ self._traversal_repeats = traversal_repeats
+ self._seed = seed
+ self._plt_block_size = plt_block_size
+ self._plt_show = plt_show
+ self._log_wandb = log_wandb
+ self._include_gt_factor_dists = include_factor_dists
+ self._transpose_plot = plt_transpose
+ self._batch_size = batch_size
+ super().__init__(every_n_steps, begin_first_step)
+
+ @torch.no_grad()
+ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+ # exit early
+ if not (self._plt_show or self._log_wandb):
+ log.warning(f'skipping {self.__class__.__name__} neither `plt_show` or `log_wandb` is `True`!')
+ return
+ # get dataset and vae framework from trainer and module
+ dataset, vae = _get_dataset_and_ae_like(trainer, pl_module, unwrap_groundtruth=True)
+ # exit early
+ if not dataset.is_ground_truth:
+ log.warning(f'cannot run {self.__class__.__name__} over non-ground-truth data, skipping!')
+ return
+ # get aggregate function
+ dists_names, dists_fn = _get_dists_fn(vae)
+ if (dists_names is None) or (dists_fn is None):
+ log.warning(f'cannot run {self.__class__.__name__}, unsupported model type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}')
+ return
+ # compute various distances matrices for each factor
+ dists_names, f_grid = compute_factor_distances(
+ dataset=dataset,
+ dists_fn=dists_fn,
+ dists_names=dists_names,
+ traversal_repeats=self._traversal_repeats,
+ batch_size=self._batch_size,
+ include_gt_factor_dists=self._include_gt_factor_dists,
+ transform_batch=lambda batch: batch.to(vae.device),
+ seed=self._seed,
+ data_mode='input',
+ )
+ # plot these results
+ fig, axs = plt_factor_distances(
+ gt_data=dataset.gt_data,
+ f_grid=f_grid,
+ dists_names=dists_names,
+ title=f'{vae.__class__.__name__}: {dataset.gt_data.name.capitalize()} Distances',
+ plt_block_size=self._plt_block_size,
+ plt_transpose=self._transpose_plot,
+ plt_cmap='Blues',
+ )
+ # show the plot
+ if self._plt_show:
+ plt.show()
+ # log the plot to wandb
+ if self._log_wandb:
+ wb_log_metrics(trainer.logger, {
+ 'factor_distances': wandb.Image(fig)
+ })
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/util/lightning/callbacks/_callback_vis_latents.py b/disent/util/lightning/callbacks/_callback_vis_latents.py
new file mode 100644
index 00000000..a33aae0c
--- /dev/null
+++ b/disent/util/lightning/callbacks/_callback_vis_latents.py
@@ -0,0 +1,336 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import logging
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+
+from disent.dataset import DisentDataset
+from disent.frameworks.ae import Ae
+from disent.frameworks.vae import Vae
+from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic
+from disent.util.lightning.callbacks._helper import _get_dataset_and_ae_like
+from disent.util.lightning.logger_util import wb_log_metrics
+from disent.util.seeds import TempNumpySeed
+from disent.util.visualize.vis_img import torch_to_images
+from disent.util.visualize.vis_latents import make_decoded_latent_cycles
+from disent.util.visualize.vis_util import make_animated_image_grid
+from disent.util.visualize.vis_util import make_image_grid
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# Helper #
+# ========================================================================= #
+
+
+MinMaxHint = Optional[Union[int, Literal['auto']]]
+MeanStdHint = Optional[Union[Tuple[float, ...], float]]
+
+
+def get_vis_min_max(
+ recon_min: MinMaxHint = None,
+ recon_max: MinMaxHint = None,
+ recon_mean: MeanStdHint = None,
+ recon_std: MeanStdHint = None,
+) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
+ # check recon_min and recon_max
+ if (recon_min is not None) or (recon_max is not None):
+ if (recon_mean is not None) or (recon_std is not None):
+ raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both')
+ if (recon_min is None) or (recon_max is None):
+ raise ValueError('both recon_min & recon_max must be specified')
+ # check strings
+ if isinstance(recon_min, str) or isinstance(recon_max, str):
+ if not (isinstance(recon_min, str) and isinstance(recon_max, str)):
+ raise ValueError('both recon_min & recon_max must be "auto" if one is "auto"')
+ return None, None # "auto" -> None
+ # check recon_mean and recon_std
+ elif (recon_mean is not None) or (recon_std is not None):
+ if (recon_min is not None) or (recon_max is not None):
+ raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both')
+ if (recon_mean is None) or (recon_std is None):
+ raise ValueError('both recon_mean & recon_std must be specified')
+ # set values:
+ # | ORIG: [0, 1]
+ # | TRANSFORM: (x - mean) / std -> [(0-mean)/std, (1-mean)/std]
+ # | REVERT: (x - min) / (max - min) -> [0, 1]
+ # | min=(0-mean)/std, max=(1-mean)/std
+ recon_mean, recon_std = np.array(recon_mean, dtype='float32'), np.array(recon_std, dtype='float32')
+ recon_min = np.divide(0 - recon_mean, recon_std)
+ recon_max = np.divide(1 - recon_mean, recon_std)
+ # set defaults
+ if recon_min is None: recon_min = 0.0
+ if recon_max is None: recon_max = 1.0
+ # change type
+ recon_min = np.array(recon_min)
+ recon_max = np.array(recon_max)
+ assert recon_min.ndim in (0, 1)
+ assert recon_max.ndim in (0, 1)
+ # checks
+ assert np.all(recon_min < recon_max), f'recon_min={recon_min} must be less than recon_max={recon_max}'
+ return recon_min, recon_max
+
+
+# ========================================================================= #
+# Latent Visualisation Callback #
+# ========================================================================= #
+
+
+class VaeLatentCycleLoggingCallback(BaseCallbackPeriodic):
+
+ def __init__(
+ self,
+ seed: Optional[int] = 7777,
+ every_n_steps: Optional[int] = None,
+ begin_first_step: bool = False,
+ num_frames: int = 17,
+ mode: str = 'minmax_interval_cycle',
+ num_stats_samples: int = 64,
+ log_wandb: bool = True, # TODO: detect this automatically?
+ wandb_mode: str = 'both',
+ wandb_fps: int = 4,
+ plt_show: bool = False,
+ plt_block_size: float = 1.0,
+ # recon_min & recon_max
+ recon_min: MinMaxHint = None, # scale data in this range [min, max] to [0, 1]
+ recon_max: MinMaxHint = None, # scale data in this range [min, max] to [0, 1]
+ recon_mean: MeanStdHint = None, # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1]
+ recon_std: MeanStdHint = None, # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1]
+ ):
+ super().__init__(every_n_steps, begin_first_step)
+ self._seed = seed
+ self._mode = mode
+ self._num_stats_samples = num_stats_samples
+ self._plt_show = plt_show
+ self._plt_block_size = plt_block_size
+ self._log_wandb = log_wandb
+ self._wandb_mode = wandb_mode
+ self._num_frames = num_frames
+ self._fps = wandb_fps
+ # checks
+ assert wandb_mode in {'img', 'vid', 'both'}, f'invalid wandb_mode={repr(wandb_mode)}, must be one of: ("img", "vid", "both")'
+ # normalize
+ self._recon_min, self._recon_max = get_vis_min_max(
+ recon_min=recon_min,
+ recon_max=recon_max,
+ recon_mean=recon_mean,
+ recon_std=recon_std,
+ )
+
+ @torch.no_grad()
+ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+ # exit early
+ if not (self._plt_show or self._log_wandb):
+ log.warning(f'skipping {self.__class__.__name__} neither `plt_show` or `log_wandb` is `True`!')
+ return
+
+ # feed forward and visualise everything!
+ stills, animation, image = self.get_visualisations(trainer, pl_module)
+
+ # log video -- none, img, vid, both
+ # TODO: there might be a memory leak in making the video below? Or there could be one in the actual DSPRITES dataset? memory usage seems to be very high and increase on this dataset.
+ if self._log_wandb:
+ import wandb
+ wandb_items = {}
+ if self._wandb_mode in ('img', 'both'): wandb_items[f'{self._mode}_img'] = wandb.Image(image)
+ if self._wandb_mode in ('vid', 'both'): wandb_items[f'{self._mode}_vid'] = wandb.Video(np.transpose(animation, [0, 3, 1, 2]), fps=self._fps, format='mp4'),
+ wb_log_metrics(trainer.logger, wandb_items)
+
+ # log locally
+ if self._plt_show:
+ from matplotlib import pyplot as plt
+ fig, ax = plt.subplots(1, 1, figsize=(self._plt_block_size*stills.shape[1], self._plt_block_size*stills.shape[0]))
+ ax.imshow(image)
+ ax.axis('off')
+ fig.tight_layout()
+ plt.show()
+
+ def get_visualisations(
+ self,
+ trainer_or_dataset: Union[pl.Trainer, DisentDataset],
+ pl_module: pl.LightningModule,
+ ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]]:
+ return self.generate_visualisations(
+ trainer_or_dataset,
+ pl_module,
+ seed=self._seed,
+ num_frames=self._num_frames,
+ mode=self._mode,
+ num_stats_samples=self._num_stats_samples,
+ recon_min=self._recon_min,
+ recon_max=self._recon_max,
+ recon_mean=None,
+ recon_std=None,
+ )
+
+ @classmethod
+ def generate_visualisations(
+ cls,
+ trainer_or_dataset: Union[pl.Trainer, DisentDataset],
+ pl_module: pl.LightningModule,
+ seed: Optional[int] = 7777,
+ num_frames: int = 17,
+ mode: str = 'fitted_gaussian_cycle',
+ num_stats_samples: int = 64,
+ # recon_min & recon_max
+ recon_min: MinMaxHint = None,
+ recon_max: MinMaxHint = None,
+ recon_mean: MeanStdHint = None,
+ recon_std: MeanStdHint = None,
+ ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]]:
+ # normalize
+ recon_min, recon_max = get_vis_min_max(
+ recon_min=recon_min,
+ recon_max=recon_max,
+ recon_mean=recon_mean,
+ recon_std=recon_std,
+ )
+
+ # get dataset and vae framework from trainer and module
+ dataset, vae = _get_dataset_and_ae_like(trainer_or_dataset, pl_module, unwrap_groundtruth=True)
+
+ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
+ # generate traversal
+ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
+
+ # get random sample of z_means and z_logvars for computing the range of values for the latent_cycle
+ with TempNumpySeed(seed):
+ batch, indices = dataset.dataset_sample_batch(num_stats_samples, mode='input', replace=True, return_indices=True) # replace just in case the dataset it tiny
+ batch = batch.to(vae.device)
+
+ # get representations
+ if isinstance(vae, Vae):
+ # variational auto-encoder
+ ds_posterior, ds_prior = vae.encode_dists(batch)
+ zs_mean, zs_logvar = ds_posterior.mean, torch.log(ds_posterior.variance)
+ elif isinstance(vae, Ae):
+ # auto-encoder
+ zs_mean = vae.encode(batch)
+ zs_logvar = torch.ones_like(zs_mean)
+ else:
+ log.warning(f'cannot run {cls.__name__}, unsupported type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}')
+ return
+
+ # get min and max if auto
+ if (recon_min is None) or (recon_max is None):
+ if recon_min is None: recon_min = float(torch.amin(batch).cpu())
+ if recon_max is None: recon_max = float(torch.amax(batch).cpu())
+ log.info(f'auto visualisation min: {recon_min} and max: {recon_max} obtained from {len(batch)} samples')
+
+ # produce latent cycle still images & convert them to images
+ stills = make_decoded_latent_cycles(vae.decode, zs_mean, zs_logvar, mode=mode, num_animations=1, num_frames=num_frames, decoder_device=vae.device)[0]
+ stills = torch_to_images(stills, in_dims='CHW', out_dims='HWC', in_min=recon_min, in_max=recon_max, always_rgb=True, to_numpy=True)
+
+ # generate the video frames and image grid from the stills
+ # - TODO: this needs to be fixed to not use logvar, but rather the representations or distributions themselves
+ # - TODO: should this not use `visualize_dataset_traversal`?
+ frames = make_animated_image_grid(stills, pad=4, border=True, bg_color=None)
+ image = make_image_grid(stills.reshape(-1, *stills.shape[2:]), num_cols=stills.shape[1], pad=4, border=True, bg_color=None)
+
+ # done
+ return stills, frames, image
+
+
+
+# TODO: FIX THIS ERROR:
+# [2022-03-12 11:44:59,661][experiment.util.run_utils][ERROR] - exiting: experiment error | [Errno 12] Cannot allocate memory
+# Traceback (most recent call last):
+# File "/home-mscluster/nmichlo/workspace/research/disent/experiment/run.py", line 462, in hydra_main
+# run_action(cfg)
+# File "/home-mscluster/nmichlo/workspace/research/disent/experiment/run.py", line 378, in run_action
+# action(cfg)
+# File "/home-mscluster/nmichlo/workspace/research/disent/experiment/run.py", line 351, in action_train
+# trainer.fit(framework, datamodule=datamodule)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
+# self._call_and_handle_interrupt(
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
+# return trainer_fn(*args, **kwargs)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
+# self._run(model, ckpt_path=ckpt_path)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
+# self._dispatch()
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
+# self.training_type_plugin.start_training(self)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
+# self._results = trainer.run_stage()
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
+# return self._run_train()
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1312, in _run_train
+# self.fit_loop.run()
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
+# self.advance(*args, **kwargs)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
+# self.epoch_loop.run(data_fetcher)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
+# self.advance(*args, **kwargs)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 221, in advance
+# self.trainer.call_hook("on_batch_end")
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1477, in call_hook
+# callback_fx(*args, **kwargs)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 163, in on_batch_end
+# callback.on_batch_end(self, self.lightning_module)
+# File "/home-mscluster/nmichlo/workspace/research/disent/disent/util/lightning/callbacks/_callbacks_base.py", line 57, in on_batch_end
+# self.do_step(trainer, pl_module)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
+# return func(*args, **kwargs)
+# File "/home-mscluster/nmichlo/workspace/research/disent/disent/util/lightning/callbacks/_callback_vis_latents.py", line 165, in do_step
+# if self._wandb_mode in ('vid', 'both'): wandb_items[f'{self._mode}_vid'] = wandb.Video(np.transpose(animation, [0, 3, 1, 2]), fps=self._fps, format='mp4'),
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/wandb/sdk/data_types.py", line 1232, in __init__
+# self.encode()
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/wandb/sdk/data_types.py", line 1255, in encode
+# clip.write_videofile(filename, **kwargs)
+# File "", line 2, in write_videofile
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/decorators.py", line 54, in requires_duration
+# return f(clip, *a, **k)
+# File "", line 2, in write_videofile
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/decorators.py", line 135, in use_clip_fps_by_default
+# return f(clip, *new_a, **new_kw)
+# File "", line 2, in write_videofile
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/decorators.py", line 22, in convert_masks_to_RGB
+# return f(clip, *a, **k)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/video/VideoClip.py", line 300, in write_videofile
+# ffmpeg_write_video(self, filename, fps, codec,
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/video/io/ffmpeg_writer.py", line 213, in ffmpeg_write_video
+# with FFMPEG_VideoWriter(filename, clip.size, fps, codec = codec,
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/video/io/ffmpeg_writer.py", line 129, in __init__
+# self.proc = sp.Popen(cmd, **popen_params)
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/subprocess.py", line 951, in __init__
+# self._execute_child(args, executable, preexec_fn, close_fds,
+# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/subprocess.py", line 1754, in _execute_child
+# self.pid = _posixsubprocess.fork_exec(
+# OSError: [Errno 12] Cannot allocate memory
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/util/lightning/callbacks/_callbacks_base.py b/disent/util/lightning/callbacks/_callbacks_base.py
index db380ec1..3385b8f7 100644
--- a/disent/util/lightning/callbacks/_callbacks_base.py
+++ b/disent/util/lightning/callbacks/_callbacks_base.py
@@ -24,6 +24,8 @@
import logging
import time
+from typing import Optional
+
import pytorch_lightning as pl
@@ -37,7 +39,8 @@
class BaseCallbackPeriodic(pl.Callback):
- def __init__(self, every_n_steps=None, begin_first_step=False):
+ def __init__(self, every_n_steps: Optional[int] = None, begin_first_step: bool = False):
+ assert (every_n_steps is None) or (isinstance(every_n_steps, int) and every_n_steps > 0), f'`every_n_steps` must be None or an integer greater than zero, got: {repr(every_n_steps)}'
self.every_n_steps = every_n_steps
self.begin_first_step = begin_first_step
@@ -59,7 +62,8 @@ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
class BaseCallbackTimed(pl.Callback):
- def __init__(self, interval=10):
+ def __init__(self, interval: float = 10):
+ assert interval > 0, f'The interval must be greater than zero, got: {repr(interval)}'
self._last_time = 0
self._interval = interval
self._start_time = time.time()
diff --git a/disent/util/lightning/callbacks/_callbacks_vae.py b/disent/util/lightning/callbacks/_callbacks_vae.py
deleted file mode 100644
index e846b86d..00000000
--- a/disent/util/lightning/callbacks/_callbacks_vae.py
+++ /dev/null
@@ -1,639 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-import logging
-import warnings
-from typing import Callable
-from typing import List
-from typing import Literal
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import pytorch_lightning as pl
-import torch
-
-from pytorch_lightning.trainer.supporters import CombinedLoader
-from torch.utils.data.dataloader import default_collate
-from tqdm import tqdm
-
-import disent.metrics
-import disent.util.strings.colors as c
-from disent.dataset import DisentDataset
-from disent.dataset.data import GroundTruthData
-from disent.frameworks.ae import Ae
-from disent.frameworks.helper.reconstructions import make_reconstruction_loss
-from disent.frameworks.helper.reconstructions import ReconLossHandler
-from disent.frameworks.vae import Vae
-from disent.util.function import wrapped_partial
-from disent.util.iters import chunked
-from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic
-from disent.util.lightning.logger_util import log_metrics
-from disent.util.lightning.logger_util import wb_log_metrics
-from disent.util.lightning.logger_util import wb_log_reduced_summaries
-from disent.util.profiling import Timer
-from disent.util.seeds import TempNumpySeed
-from disent.util.visualize.plot import plt_subplots_imshow
-from disent.util.visualize.vis_model import latent_cycle_grid_animation
-from disent.util.visualize.vis_util import make_image_grid
-
-# TODO: wandb and matplotlib are not in requirements
-import matplotlib.pyplot as plt
-import wandb
-
-
-log = logging.getLogger(__name__)
-
-
-# ========================================================================= #
-# Helper Functions #
-# ========================================================================= #
-
-
-def _get_dataset_and_vae(trainer: pl.Trainer, pl_module: pl.LightningModule, unwrap_groundtruth: bool = False) -> (DisentDataset, Ae):
- # TODO: improve handling!
- assert isinstance(pl_module, Ae), f'{pl_module.__class__} is not an instance of {Ae}'
- # get dataset
- if hasattr(trainer, 'datamodule') and (trainer.datamodule is not None):
- assert hasattr(trainer.datamodule, 'dataset_train_noaug') # TODO: this is for experiments, another way of handling this should be added
- dataset = trainer.datamodule.dataset_train_noaug
- elif hasattr(trainer, 'train_dataloader') and (trainer.train_dataloader is not None):
- if isinstance(trainer.train_dataloader, CombinedLoader):
- dataset = trainer.train_dataloader.loaders.dataset
- else:
- raise RuntimeError(f'invalid trainer.train_dataloader: {trainer.train_dataloader}')
- else:
- raise RuntimeError('could not retrieve dataset! please report this...')
- # check dataset
- assert isinstance(dataset, DisentDataset), f'retrieved dataset is not an {DisentDataset.__name__}'
- # unwarp dataset
- if unwrap_groundtruth:
- if dataset.is_wrapped_gt_data:
- old_dataset, dataset = dataset, dataset.unwrapped_disent_dataset()
- warnings.warn(f'Unwrapped ground truth dataset returned! {type(old_dataset.data).__name__} -> {type(dataset.data).__name__}')
- # done checks
- return dataset, pl_module
-
-
-# ========================================================================= #
-# Vae Framework Callbacks #
-# ========================================================================= #
-
-# helper
-def _to_dmat(
- size: int,
- i_a: np.ndarray,
- i_b: np.ndarray,
- dists: Union[torch.Tensor, np.ndarray],
-) -> np.ndarray:
- if isinstance(dists, torch.Tensor):
- dists = dists.detach().cpu().numpy()
- # checks
- assert i_a.ndim == 1
- assert i_a.shape == i_b.shape
- assert i_a.shape == dists.shape
- # compute
- dmat = np.zeros([size, size], dtype='float32')
- dmat[i_a, i_b] = dists
- dmat[i_b, i_a] = dists
- return dmat
-
-
-_AE_DIST_NAMES = ('x', 'z', 'x_recon')
-_VAE_DIST_NAMES = ('x', 'z', 'kl', 'x_recon')
-
-
-@torch.no_grad()
-def _get_dists_ae(ae: Ae, x_a: torch.Tensor, x_b: torch.Tensor):
- # feed forware
- z_a, z_b = ae.encode(x_a), ae.encode(x_b)
- r_a, r_b = ae.decode(z_a), ae.decode(z_b)
- # distances
- return [
- ae.recon_handler.compute_pairwise_loss(x_a, x_b),
- torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist
- ae.recon_handler.compute_pairwise_loss(r_a, r_b),
- ]
-
-
-@torch.no_grad()
-def _get_dists_vae(vae: Vae, x_a: torch.Tensor, x_b: torch.Tensor):
- from torch.distributions import kl_divergence
- # feed forward
- (z_post_a, z_prior_a), (z_post_b, z_prior_b) = vae.encode_dists(x_a), vae.encode_dists(x_b)
- z_a, z_b = z_post_a.mean, z_post_b.mean
- r_a, r_b = vae.decode(z_a), vae.decode(z_b)
- # dists
- kl_ab = 0.5 * kl_divergence(z_post_a, z_post_b) + 0.5 * kl_divergence(z_post_b, z_post_a)
- # distances
- return [
- vae.recon_handler.compute_pairwise_loss(x_a, x_b),
- torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist
- vae.recon_handler._pairwise_reduce(kl_ab),
- vae.recon_handler.compute_pairwise_loss(r_a, r_b),
- ]
-
-
-def _get_dists_fn(model: Ae) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], Sequence[Sequence[float]]]]]:
- # get aggregate function
- if isinstance(model, Vae):
- dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model)
- elif isinstance(model, Ae):
- dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model)
- else:
- dists_names, dists_fn = None, None
- return dists_names, dists_fn
-
-
-@torch.no_grad()
-def _collect_dists_subbatches(dists_fn: Callable[[object, object], Sequence[Sequence[float]]], batch: torch.Tensor, i_a: np.ndarray, i_b: np.ndarray, batch_size: int = 64):
- # feed forward
- results = []
- for idxs in chunked(np.stack([i_a, i_b], axis=-1), chunk_size=batch_size):
- ia, ib = idxs.T
- x_a, x_b = batch[ia], batch[ib]
- # feed forward
- data = dists_fn(x_a, x_b)
- results.append(data)
- return [torch.cat(r, dim=0) for r in zip(*results)]
-
-
-def _compute_and_collect_dists(
- dataset: DisentDataset,
- dists_fn,
- dists_names: Sequence[str],
- traversal_repeats: int = 100,
- batch_size: int = 32,
- include_gt_factor_dists: bool = True,
- transform_batch: Callable[[object], object] = None,
- data_mode: str = 'input',
-) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]:
- assert traversal_repeats > 0
- gt_data = dataset.gt_data
- # generate
- f_grid = []
- # generate
- for f_idx, f_size in enumerate(gt_data.factor_sizes):
- # save for the current factor (traversal_repeats, len(names), len(i_a))
- f_dists = []
- # upper triangle excluding diagonal
- i_a, i_b = np.triu_indices(f_size, k=1)
- # repeat over random traversals
- for i in range(traversal_repeats):
- # get random factor traversal
- factors = gt_data.sample_random_factor_traversal(f_idx=f_idx)
- indices = gt_data.pos_to_idx(factors)
- # load data
- batch = dataset.dataset_batch_from_indices(indices, data_mode)
- if transform_batch is not None:
- batch = transform_batch(batch)
- # feed forward & compute dists -- (len(names), len(i_a))
- dists = _collect_dists_subbatches(dists_fn=dists_fn, batch=batch, i_a=i_a, i_b=i_b, batch_size=batch_size)
- assert len(dists) == len(dists_names)
- # distances
- f_dists.append(dists)
- # aggregate all dists into distances matrices for current factor
- f_dmats = [
- _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=torch.stack(dists, dim=0).mean(dim=0))
- for dists in zip(*f_dists)
- ]
- # handle factors
- if include_gt_factor_dists:
- i_dmat = _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=np.abs(factors[i_a] - factors[i_b]).sum(axis=-1))
- f_dmats = [i_dmat, *f_dmats]
- # append data
- f_grid.append(f_dmats)
- # handle factors
- if include_gt_factor_dists:
- dists_names = ('factors', *dists_names)
- # done
- return tuple(dists_names), f_grid
-
-
-def compute_factor_distances(
- dataset: DisentDataset,
- dists_fn,
- dists_names: Sequence[str],
- traversal_repeats: int = 100,
- batch_size: int = 32,
- include_gt_factor_dists: bool = True,
- transform_batch: Callable[[object], object] = None,
- seed: Optional[int] = 777,
- data_mode: str = 'input',
-) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]:
- # log this callback
- gt_data = dataset.gt_data
- log.info(f'| {gt_data.name} - computing factor distances...')
- # compute various distances matrices for each factor
- with Timer() as timer, TempNumpySeed(seed):
- dists_names, f_grid = _compute_and_collect_dists(
- dataset=dataset,
- dists_fn=dists_fn,
- dists_names=dists_names,
- traversal_repeats=traversal_repeats,
- batch_size=batch_size,
- include_gt_factor_dists=include_gt_factor_dists,
- transform_batch=transform_batch,
- data_mode=data_mode,
- )
- # log this callback!
- log.info(f'| {gt_data.name} - computed factor distances! time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST}')
- return dists_names, f_grid
-
-
-def plt_factor_distances(
- gt_data: GroundTruthData,
- f_grid: List[List[np.ndarray]],
- dists_names: Sequence[str],
- title: str,
- plt_block_size: float = 1.25,
- plt_transpose: bool = False,
- plt_cmap='Blues',
-):
- # plot information
- imshow_kwargs = dict(cmap=plt_cmap)
- figsize = (plt_block_size*len(f_grid[0]), plt_block_size * gt_data.num_factors)
- # plot!
- if not plt_transpose:
- fig, axs = plt_subplots_imshow(grid=f_grid, col_labels=dists_names, row_labels=gt_data.factor_names, figsize=figsize, title=title, imshow_kwargs=imshow_kwargs)
- else:
- fig, axs = plt_subplots_imshow(grid=list(zip(*f_grid)), col_labels=gt_data.factor_names, row_labels=dists_names, figsize=figsize[::-1], title=title, imshow_kwargs=imshow_kwargs)
- # done
- return fig, axs
-
-
-class VaeGtDistsLoggingCallback(BaseCallbackPeriodic):
-
- def __init__(
- self,
- seed: Optional[int] = 7777,
- every_n_steps: Optional[int] = None,
- traversal_repeats: int = 100,
- begin_first_step: bool = False,
- plt_block_size: float = 1.25,
- plt_show: bool = False,
- plt_transpose: bool = False,
- log_wandb: bool = True,
- batch_size: int = 128,
- include_factor_dists: bool = True,
- ):
- assert traversal_repeats > 0
- self._traversal_repeats = traversal_repeats
- self._seed = seed
- self._plt_block_size = plt_block_size
- self._plt_show = plt_show
- self._log_wandb = log_wandb
- self._include_gt_factor_dists = include_factor_dists
- self._transpose_plot = plt_transpose
- self._batch_size = batch_size
- super().__init__(every_n_steps, begin_first_step)
-
- @torch.no_grad()
- def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
- # get dataset and vae framework from trainer and module
- dataset, vae = _get_dataset_and_vae(trainer, pl_module, unwrap_groundtruth=True)
- # exit early
- if not dataset.is_ground_truth:
- log.warning(f'cannot run {self.__class__.__name__} over non-ground-truth data, skipping!')
- return
- # get aggregate function
- dists_names, dists_fn = _get_dists_fn(vae)
- if (dists_names is None) or (dists_fn is None):
- log.warning(f'cannot run {self.__class__.__name__}, unsupported model type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}')
- return
- # compute various distances matrices for each factor
- dists_names, f_grid = compute_factor_distances(
- dataset=dataset,
- dists_fn=dists_fn,
- dists_names=dists_names,
- traversal_repeats=self._traversal_repeats,
- batch_size=self._batch_size,
- include_gt_factor_dists=self._include_gt_factor_dists,
- transform_batch=lambda batch: batch.to(vae.device),
- seed=self._seed,
- data_mode='input',
- )
- # plot these results
- fig, axs = plt_factor_distances(
- gt_data=dataset.gt_data,
- f_grid=f_grid,
- dists_names=dists_names,
- title=f'{vae.__class__.__name__}: {dataset.gt_data.name.capitalize()} Distances',
- plt_block_size=self._plt_block_size,
- plt_transpose=self._transpose_plot,
- plt_cmap='Blues',
- )
- # show the plot
- if self._plt_show:
- plt.show()
- # log the plot to wandb
- if self._log_wandb:
- wb_log_metrics(trainer.logger, {
- 'factor_distances': wandb.Image(fig)
- })
-
-
-def _normalize_min_max_mean_std_to_min_max(recon_min, recon_max, recon_mean, recon_std) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
- # check recon_min and recon_max
- if (recon_min is not None) or (recon_max is not None):
- if (recon_mean is not None) or (recon_std is not None):
- raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both')
- if (recon_min is None) or (recon_max is None):
- raise ValueError('both recon_min & recon_max must be specified')
- # check strings
- if isinstance(recon_min, str) or isinstance(recon_max, str):
- if not (isinstance(recon_min, str) and isinstance(recon_max, str)):
- raise ValueError('both recon_min & recon_max must be "auto" if one is "auto"')
- return None, None
- # check recon_mean and recon_std
- elif (recon_mean is not None) or (recon_std is not None):
- if (recon_min is not None) or (recon_max is not None):
- raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both')
- if (recon_mean is None) or (recon_std is None):
- raise ValueError('both recon_mean & recon_std must be specified')
- # set values:
- # | ORIG: [0, 1]
- # | TRANSFORM: (x - mean) / std -> [(0-mean)/std, (1-mean)/std]
- # | REVERT: (x - min) / (max - min) -> [0, 1]
- # | min=(0-mean)/std, max=(1-mean)/std
- recon_mean, recon_std = np.array(recon_mean, dtype='float32'), np.array(recon_std, dtype='float32')
- recon_min = np.divide(0 - recon_mean, recon_std)
- recon_max = np.divide(1 - recon_mean, recon_std)
- # set defaults
- if recon_min is None: recon_min = 0.0
- if recon_max is None: recon_max = 0.0
- # change type
- recon_min = np.array(recon_min)
- recon_max = np.array(recon_max)
- assert recon_min.ndim in (0, 1)
- assert recon_max.ndim in (0, 1)
- # checks
- assert np.all(recon_min < np.all(recon_max)), f'recon_min={recon_min} must be less than recon_max={recon_max}'
- return recon_min, recon_max
-
-
-class VaeLatentCycleLoggingCallback(BaseCallbackPeriodic):
-
- def __init__(
- self,
- seed: Optional[int] = 7777,
- every_n_steps: Optional[int] = None,
- begin_first_step: bool = False,
- num_frames: int = 17,
- mode: str = 'fitted_gaussian_cycle',
- wandb_mode: str = 'both',
- wandb_fps: int = 4,
- plt_show: bool = False,
- plt_block_size: float = 1.0,
- # recon_min & recon_max
- recon_min: Optional[Union[int, Literal['auto']]] = None, # scale data in this range [min, max] to [0, 1]
- recon_max: Optional[Union[int, Literal['auto']]] = None, # scale data in this range [min, max] to [0, 1]
- recon_mean: Optional[Union[Tuple[float, ...], float]] = None, # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1]
- recon_std: Optional[Union[Tuple[float, ...], float]] = None, # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1]
- ):
- super().__init__(every_n_steps, begin_first_step)
- self.seed = seed
- self.mode = mode
- self.plt_show = plt_show
- self.plt_block_size = plt_block_size
- self._wandb_mode = wandb_mode
- self._num_frames = num_frames
- self._fps = wandb_fps
- # checks
- assert wandb_mode in {'none', 'img', 'vid', 'both'}, f'invalid wandb_mode={repr(wandb_mode)}, must be one of: ("none", "img", "vid", "both")'
- # normalize
- self._recon_min, self._recon_max = _normalize_min_max_mean_std_to_min_max(
- recon_min=recon_min,
- recon_max=recon_max,
- recon_mean=recon_mean,
- recon_std=recon_std,
- )
-
-
- def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
- # get dataset and vae framework from trainer and module
- dataset, vae = _get_dataset_and_vae(trainer, pl_module, unwrap_groundtruth=True)
-
- # TODO: should this not use `visualize_dataset_traversal`?
-
- with torch.no_grad():
- # get random sample of z_means and z_logvars for computing the range of values for the latent_cycle
- with TempNumpySeed(self.seed):
- batch = dataset.dataset_sample_batch(64, mode='input').to(vae.device)
-
- # get representations
- if isinstance(vae, Vae):
- # variational auto-encoder
- ds_posterior, ds_prior = vae.encode_dists(batch)
- zs_mean, zs_logvar = ds_posterior.mean, torch.log(ds_posterior.variance)
- elif isinstance(vae, Ae):
- # auto-encoder
- zs_mean = vae.encode(batch)
- zs_logvar = torch.ones_like(zs_mean)
- else:
- log.warning(f'cannot run {self.__class__.__name__}, unsupported type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}')
- return
-
- # get min and max if auto
- if (self._recon_min is None) or (self._recon_max is None):
- if self._recon_min is None: self._recon_min = float(torch.min(batch).cpu())
- if self._recon_max is None: self._recon_max = float(torch.max(batch).cpu())
- log.info(f'auto visualisation min: {self._recon_min} and max: {self._recon_max} obtained from {len(batch)} samples')
-
- # produce latent cycle grid animation
- # TODO: this needs to be fixed to not use logvar, but rather the representations or distributions themselves
- animation, stills = latent_cycle_grid_animation(
- vae.decode, zs_mean, zs_logvar,
- mode=self.mode, num_frames=self._num_frames, decoder_device=vae.device, tensor_style_channels=False, return_stills=True,
- to_uint8=True, recon_min=self._recon_min, recon_max=self._recon_max,
- )
- image = make_image_grid(stills.reshape(-1, *stills.shape[2:]), num_cols=stills.shape[1], pad=4)
-
- # log video -- none, img, vid, both
- wandb_items = {}
- if self._wandb_mode in ('img', 'both'): wandb_items[f'{self.mode}_img'] = wandb.Image(image)
- if self._wandb_mode in ('vid', 'both'): wandb_items[f'{self.mode}_vid'] = wandb.Video(np.transpose(animation, [0, 3, 1, 2]), fps=self._fps, format='mp4'),
- wb_log_metrics(trainer.logger, wandb_items)
-
- # log locally
- if self.plt_show:
- fig, ax = plt.subplots(1, 1, figsize=(self.plt_block_size*stills.shape[1], self.plt_block_size*stills.shape[0]))
- ax.imshow(image)
- ax.axis('off')
- fig.tight_layout()
- plt.show()
-
-
-def _normalized_numeric_metrics(items: dict):
- results = {}
- for k, v in items.items():
- if isinstance(v, (float, int)):
- results[k] = v
- else:
- try:
- results[k] = float(v)
- except:
- log.warning(f'SKIPPED: metric with key: {repr(k)}, result has invalid type: {type(v)} with value: {repr(v)}')
- return results
-
-
-class VaeMetricLoggingCallback(BaseCallbackPeriodic):
-
- def __init__(
- self,
- step_end_metrics: Optional[Sequence[str]] = None,
- train_end_metrics: Optional[Sequence[str]] = None,
- every_n_steps: Optional[int] = None,
- begin_first_step: bool = False,
- ):
- super().__init__(every_n_steps, begin_first_step)
- self.step_end_metrics = step_end_metrics if step_end_metrics else []
- self.train_end_metrics = train_end_metrics if train_end_metrics else []
- assert isinstance(self.step_end_metrics, list)
- assert isinstance(self.train_end_metrics, list)
- assert self.step_end_metrics or self.train_end_metrics, 'No metrics given to step_end_metrics or train_end_metrics'
-
- def _compute_metrics_and_log(self, trainer: pl.Trainer, pl_module: pl.LightningModule, metrics: list, is_final=False):
- # get dataset and vae framework from trainer and module
- dataset, vae = _get_dataset_and_vae(trainer, pl_module, unwrap_groundtruth=True)
- # check if we need to skip
- # TODO: dataset needs to be able to handle wrapped datasets!
- if not dataset.is_ground_truth:
- warnings.warn(f'{dataset.__class__.__name__} is not an instance of {GroundTruthData.__name__}. Skipping callback: {self.__class__.__name__}!')
- return
- # compute all metrics
- for metric in metrics:
- pad = max(7+len(k) for k in disent.metrics.DEFAULT_METRICS) # I know this is a magic variable... im just OCD
- if is_final:
- log.info(f'| {metric.__name__:<{pad}} - computing...')
- with Timer() as timer:
- scores = metric(dataset, lambda x: vae.encode(x.to(vae.device)))
- metric_results = ' '.join(f'{k}{c.GRY}={c.lMGT}{v:.3f}{c.RST}' for k, v in scores.items())
- log.info(f'| {metric.__name__:<{pad}} - time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST} - {metric_results}')
-
- # log to trainer
- prefix = 'final_metric' if is_final else 'epoch_metric'
- prefixed_scores = {f'{prefix}/{k}': v for k, v in scores.items()}
- log_metrics(trainer.logger, _normalized_numeric_metrics(prefixed_scores))
-
- # log summary for WANDB
- # this is kinda hacky... the above should work for parallel coordinate plots
- wb_log_reduced_summaries(trainer.logger, prefixed_scores, reduction='max')
-
- def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
- if self.step_end_metrics:
- log.debug('Computing Epoch Metrics:')
- with Timer() as timer:
- self._compute_metrics_and_log(trainer, pl_module, metrics=self.step_end_metrics, is_final=False)
- log.debug(f'Computed Epoch Metrics! {timer.pretty}')
-
- def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
- if self.train_end_metrics:
- log.debug('Computing Final Metrics...')
- with Timer() as timer:
- self._compute_metrics_and_log(trainer, pl_module, metrics=self.train_end_metrics, is_final=True)
- log.debug(f'Computed Final Metrics! {timer.pretty}')
-
-
-# class VaeLatentCorrelationLoggingCallback(BaseCallbackPeriodic):
-#
-# def __init__(self, repeats_per_factor=10, every_n_steps=None, begin_first_step=False):
-# super().__init__(every_n_steps=every_n_steps, begin_first_step=begin_first_step)
-# self._repeats_per_factor = repeats_per_factor
-#
-# def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
-# # get dataset and vae framework from trainer and module
-# dataset, vae = _get_dataset_and_vae(trainer, pl_module)
-# # check if we need to skip
-# if not dataset.is_ground_truth:
-# warnings.warn(f'{dataset.__class__.__name__} is not an instance of {GroundTruthData.__name__}. Skipping callback: {self.__class__.__name__}!')
-# return
-# # TODO: CONVERT THIS TO A METRIC!
-# # log the correspondence between factors and the latent space.
-# num_samples = np.sum(dataset.ground_truth_data.factor_sizes) * self._repeats_per_factor
-# factors = dataset.ground_truth_data.sample_factors(num_samples)
-# # encode observations of factors
-# zs = np.concatenate([
-# to_numpy(vae.encode(dataset.dataset_batch_from_factors(factor_batch, mode='input').to(vae.device)))
-# for factor_batch in iter_chunks(factors, 256)
-# ])
-# z_size = zs.shape[-1]
-#
-# # calculate correlation matrix
-# f_and_z = np.concatenate([factors.T, zs.T])
-# f_and_z_corr = np.corrcoef(f_and_z)
-# # get correlation submatricies
-# f_corr = f_and_z_corr[:z_size, :z_size] # upper left
-# z_corr = f_and_z_corr[z_size:, z_size:] # bottom right
-# fz_corr = f_and_z_corr[z_size:, :z_size] # upper right | y: z, x: f
-# # get maximum z correlations per factor
-# z_to_f_corr_maxs = np.max(np.abs(fz_corr), axis=0)
-# f_to_z_corr_maxs = np.max(np.abs(fz_corr), axis=1)
-# assert len(z_to_f_corr_maxs) == z_size
-# assert len(f_to_z_corr_maxs) == dataset.ground_truth_data.num_factors
-# # average correlation
-# ave_f_to_z_corr = f_to_z_corr_maxs.mean()
-# ave_z_to_f_corr = z_to_f_corr_maxs.mean()
-#
-# # print
-# log.info(f'ave latent correlation: {ave_z_to_f_corr}')
-# log.info(f'ave factor correlation: {ave_f_to_z_corr}')
-# # log everything
-# log_metrics(trainer.logger, {
-# 'metric.ave_latent_correlation': ave_z_to_f_corr,
-# 'metric.ave_factor_correlation': ave_f_to_z_corr,
-# })
-# # make sure we only log the heatmap to WandB
-# wb_log_metrics(trainer.logger, {
-# 'metric.correlation_heatmap': wandb.plots.HeatMap(
-# x_labels=[f'z{i}' for i in range(z_size)],
-# y_labels=list(dataset.ground_truth_data.factor_names),
-# matrix_values=fz_corr, show_text=False
-# ),
-# })
-#
-# NUM = 1
-# # generate traversal value graphs
-# for i in range(z_size):
-# correlation = np.abs(f_corr[i, :])
-# correlation[i] = 0
-# for j in np.argsort(correlation)[::-1][:NUM]:
-# if i == j:
-# continue
-# ix, iy = (i, j) # if i < j else (j, i)
-# plt.scatter(zs[:, ix], zs[:, iy])
-# plt.title(f'z{ix}-vs-z{iy}')
-# plt.xlabel(f'z{ix}')
-# plt.ylabel(f'z{iy}')
-#
-# # wandb.log({f"chart.correlation.z{ix}-vs-z{iy}": plt})
-# # make sure we only log to WANDB
-# wb_log_metrics(trainer.logger, {f"chart.correlation.z{ix}-vs-max-corr": plt})
-
-
-# ========================================================================= #
-# END #
-# ========================================================================= #
diff --git a/disent/util/lightning/callbacks/_helper.py b/disent/util/lightning/callbacks/_helper.py
new file mode 100644
index 00000000..6d9883af
--- /dev/null
+++ b/disent/util/lightning/callbacks/_helper.py
@@ -0,0 +1,152 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import warnings
+from typing import Union
+
+import pytorch_lightning as pl
+from pytorch_lightning.trainer.supporters import CombinedLoader
+
+from disent.dataset import DisentDataset
+from disent.frameworks.ae import Ae
+from disent.frameworks.vae import Vae
+
+
+# ========================================================================= #
+# Helper Functions #
+# ========================================================================= #
+
+
+def _get_dataset_and_ae_like(trainer_or_dataset: pl.Trainer, pl_module: pl.LightningModule, unwrap_groundtruth: bool = False) -> (DisentDataset, Union[Ae, Vae]):
+ assert isinstance(pl_module, (Ae, Vae)), f'{pl_module.__class__} is not an instance of {Ae} or {Vae}'
+ # get dataset
+ if isinstance(trainer_or_dataset, pl.Trainer):
+ trainer = trainer_or_dataset
+ if hasattr(trainer, 'datamodule') and (trainer.datamodule is not None):
+ assert hasattr(trainer.datamodule, 'dataset_train_noaug') # TODO: this is for experiments, another way of handling this should be added
+ dataset = trainer.datamodule.dataset_train_noaug
+ elif hasattr(trainer, 'train_dataloader') and (trainer.train_dataloader is not None):
+ if isinstance(trainer.train_dataloader, CombinedLoader):
+ dataset = trainer.train_dataloader.loaders.dataset
+ else:
+ raise RuntimeError(f'invalid trainer.train_dataloader: {trainer.train_dataloader}')
+ else:
+ raise RuntimeError('could not retrieve dataset! please report this...')
+ else:
+ dataset = trainer_or_dataset
+ # check dataset
+ assert isinstance(dataset, DisentDataset), f'retrieved dataset is not an {DisentDataset.__name__}'
+ # unwrap dataset
+ if unwrap_groundtruth:
+ if dataset.is_wrapped_gt_data:
+ old_dataset, dataset = dataset, dataset.unwrapped_shallow_copy()
+ warnings.warn(f'Unwrapped ground truth dataset returned! {type(old_dataset.data).__name__} -> {type(dataset.data).__name__}')
+ # done checks
+ return dataset, pl_module
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
+
+
+# class VaeLatentCorrelationLoggingCallback(BaseCallbackPeriodic):
+#
+# def __init__(self, repeats_per_factor=10, every_n_steps=None, begin_first_step=False):
+# super().__init__(every_n_steps=every_n_steps, begin_first_step=begin_first_step)
+# self._repeats_per_factor = repeats_per_factor
+#
+# def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+# # get dataset and vae framework from trainer and module
+# dataset, vae = _get_dataset_and_vae(trainer, pl_module)
+# # check if we need to skip
+# if not dataset.is_ground_truth:
+# warnings.warn(f'{dataset.__class__.__name__} is not an instance of {GroundTruthData.__name__}. Skipping callback: {self.__class__.__name__}!')
+# return
+# # TODO: CONVERT THIS TO A METRIC!
+# # log the correspondence between factors and the latent space.
+# num_samples = np.sum(dataset.ground_truth_data.factor_sizes) * self._repeats_per_factor
+# factors = dataset.ground_truth_data.sample_factors(num_samples)
+# # encode observations of factors
+# zs = np.concatenate([
+# to_numpy(vae.encode(dataset.dataset_batch_from_factors(factor_batch, mode='input').to(vae.device)))
+# for factor_batch in iter_chunks(factors, 256)
+# ])
+# z_size = zs.shape[-1]
+#
+# # calculate correlation matrix
+# f_and_z = np.concatenate([factors.T, zs.T])
+# f_and_z_corr = np.corrcoef(f_and_z)
+# # get correlation submatricies
+# f_corr = f_and_z_corr[:z_size, :z_size] # upper left
+# z_corr = f_and_z_corr[z_size:, z_size:] # bottom right
+# fz_corr = f_and_z_corr[z_size:, :z_size] # upper right | y: z, x: f
+# # get maximum z correlations per factor
+# z_to_f_corr_maxs = np.max(np.abs(fz_corr), axis=0)
+# f_to_z_corr_maxs = np.max(np.abs(fz_corr), axis=1)
+# assert len(z_to_f_corr_maxs) == z_size
+# assert len(f_to_z_corr_maxs) == dataset.ground_truth_data.num_factors
+# # average correlation
+# ave_f_to_z_corr = f_to_z_corr_maxs.mean()
+# ave_z_to_f_corr = z_to_f_corr_maxs.mean()
+#
+# # print
+# log.info(f'ave latent correlation: {ave_z_to_f_corr}')
+# log.info(f'ave factor correlation: {ave_f_to_z_corr}')
+# # log everything
+# log_metrics(trainer.logger, {
+# 'metric.ave_latent_correlation': ave_z_to_f_corr,
+# 'metric.ave_factor_correlation': ave_f_to_z_corr,
+# })
+# # make sure we only log the heatmap to WandB
+# wb_log_metrics(trainer.logger, {
+# 'metric.correlation_heatmap': wandb.plots.HeatMap(
+# x_labels=[f'z{i}' for i in range(z_size)],
+# y_labels=list(dataset.ground_truth_data.factor_names),
+# matrix_values=fz_corr, show_text=False
+# ),
+# })
+#
+# NUM = 1
+# # generate traversal value graphs
+# for i in range(z_size):
+# correlation = np.abs(f_corr[i, :])
+# correlation[i] = 0
+# for j in np.argsort(correlation)[::-1][:NUM]:
+# if i == j:
+# continue
+# ix, iy = (i, j) # if i < j else (j, i)
+# plt.scatter(zs[:, ix], zs[:, iy])
+# plt.title(f'z{ix}-vs-z{iy}')
+# plt.xlabel(f'z{ix}')
+# plt.ylabel(f'z{iy}')
+#
+# # wandb.log({f"chart.correlation.z{ix}-vs-z{iy}": plt})
+# # make sure we only log to WANDB
+# wb_log_metrics(trainer.logger, {f"chart.correlation.z{ix}-vs-max-corr": plt})
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/util/lightning/logger_util.py b/disent/util/lightning/logger_util.py
index 0daec37e..c99fcce1 100644
--- a/disent/util/lightning/logger_util.py
+++ b/disent/util/lightning/logger_util.py
@@ -34,6 +34,11 @@
log = logging.getLogger(__name__)
+# TODO: convert these functions into generalised log_img, log_vid,
+# log_number, log_text functions that support the different backends.
+# - wandb
+# - disk backend
+# - comet ml
# ========================================================================= #
# Logger Utils #
diff --git a/disent/util/math/dither.py b/disent/util/math/dither.py
index 4ae2d524..20229cb1 100644
--- a/disent/util/math/dither.py
+++ b/disent/util/math/dither.py
@@ -176,6 +176,7 @@ def _is_power_2(num: int):
@functools.lru_cache()
def _normalize_axis(ndim: int, axis: Optional[Sequence[int]]) -> np.ndarray:
# TODO: this functionality may be duplicated
+ # -- similar to np.normalize_axis_tuple(...)
# defaults
if axis is None:
axis = np.arange(ndim)
diff --git a/disent/util/math/integer.py b/disent/util/math/integer.py
new file mode 100644
index 00000000..eee92707
--- /dev/null
+++ b/disent/util/math/integer.py
@@ -0,0 +1,53 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+
+# ========================================================================= #
+# Working With Arbitrary Precision Integers #
+# ========================================================================= #
+
+
+def gcd(a: int, b: int) -> int:
+ """
+ Compute the greatest common divisor of a and b
+ TODO: not actually sure if this returns the correct values for zero or negative inputs?
+ """
+ assert isinstance(a, int), f'number must be an int, got: {type(a)}'
+ assert isinstance(b, int), f'number must be an int, got: {type(b)}'
+ while b > 0:
+ a, b = b, a % b
+ return a
+
+
+def lcm(a: int, b: int) -> int:
+ """
+ Compute the lowest common multiple of a and b
+ TODO: not actually sure if this returns the correct values for zero or negative inputs?
+ """
+ return (a * b) // gcd(a, b)
+
+
+# ========================================================================= #
+# End #
+# ========================================================================= #
diff --git a/disent/util/visualize/plot.py b/disent/util/visualize/plot.py
index 826594c2..6c9d3de4 100644
--- a/disent/util/visualize/plot.py
+++ b/disent/util/visualize/plot.py
@@ -44,6 +44,20 @@
log = logging.getLogger(__name__)
+# ========================================================================= #
+# vars #
+# ========================================================================= #
+
+
+_TORCH_NORMAL_TYPES = {torch.float16, torch.float32, torch.float64}
+
+# torch.complex32 exists in 1.10, but was disabled in 1.11, and planned to be added in 1.12 again
+if torch.version.__version__.startswith('1.11.'):
+ _TORCH_COMPLEX_TYPES = {torch.complex64}
+else:
+ _TORCH_COMPLEX_TYPES = {torch.complex32, torch.complex64}
+
+
# ========================================================================= #
# images #
# ========================================================================= #
@@ -59,11 +73,11 @@ def to_img(x: torch.Tensor, scale=False, to_cpu=True, move_channels=True) -> tor
def to_imgs(x: torch.Tensor, scale=False, to_cpu=True, move_channels=True) -> torch.Tensor:
# (..., C, H, W)
assert x.ndim >= 3, 'image must have 3 or more dimensions: (..., C, H, W)'
- assert x.dtype in {torch.float16, torch.float32, torch.float64, torch.complex32, torch.complex64}, f'unsupported dtype: {x.dtype}'
+ assert (x.dtype in _TORCH_NORMAL_TYPES) or (x.dtype in _TORCH_COMPLEX_TYPES), f'unsupported dtype: {x.dtype}'
# no gradient
with torch.no_grad():
# imaginary to real
- if x.dtype in {torch.complex32, torch.complex64}:
+ if x.dtype in _TORCH_COMPLEX_TYPES:
x = torch.abs(x)
# scale images
if scale:
@@ -310,11 +324,10 @@ def visualize_dataset_traversal(
# TODO: this is kinda hacky, maybe rather add a check?
# TODO: can this be moved into the `output_wandb` if statement?
# - animations glitch out if they do not have 3 channels
+ assert grid.ndim == 5, f'invalid number of dimensions, must be 5, got: {grid.ndim}'
if grid.shape[-1] == 1:
grid = grid.repeat(3, axis=-1)
-
- assert grid.ndim == 5
- assert grid.shape[-1] in (1, 3)
+ assert grid.shape[-1] in (1, 3), f'invalid number of channels, must be 1 or 3, got shape: {grid.shape}. Note that the dataset or augment if specified should output HWC images, not CHW images!'
# generate visuals
image = make_image_grid(np.concatenate(grid, axis=0), pad=pad, border=border, bg_color=bg_color, num_cols=num_frames)
diff --git a/disent/util/visualize/vis_img.py b/disent/util/visualize/vis_img.py
new file mode 100644
index 00000000..874bd0d3
--- /dev/null
+++ b/disent/util/visualize/vis_img.py
@@ -0,0 +1,419 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+
+import warnings
+from functools import lru_cache
+from numbers import Number
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+from numpy.core.multiarray import normalize_axis_index
+from PIL import Image
+
+
+# ========================================================================= #
+# Type Hints #
+# ========================================================================= #
+
+
+# from torch.testing._internal.common_utils import numpy_to_torch_dtype_dict
+_NP_TO_TORCH_DTYPE = {
+ np.dtype('bool'): torch.bool,
+ np.dtype('uint8'): torch.uint8,
+ np.dtype('int8'): torch.int8,
+ np.dtype('int16'): torch.int16,
+ np.dtype('int32'): torch.int32,
+ np.dtype('int64'): torch.int64,
+ np.dtype('float16'): torch.float16,
+ np.dtype('float32'): torch.float32,
+ np.dtype('float64'): torch.float64,
+ np.dtype('complex64'): torch.complex64,
+ np.dtype('complex128'): torch.complex128
+}
+
+
+MinMaxHint = Union[Number, Tuple[Number, ...], np.ndarray]
+
+
+@lru_cache()
+def _dtype_min_max(dtype: torch.dtype) -> Tuple[Union[float, int], Union[float, int]]:
+ """Get the min and max values for a dtype"""
+ dinfo = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
+ return dinfo.min, dinfo.max
+
+
+@lru_cache()
+def _check_image_dtype(dtype: torch.dtype):
+ """Check that a dtype can hold image values"""
+ # check that the datatype is within the right range -- this is not actually necessary if the below is correct!
+ dmin, dmax = _dtype_min_max(dtype)
+ imin, imax = (0, 1) if dtype.is_floating_point else (0, 255)
+ assert (dmin <= imin) and (imax <= dmax), f'The dtype: {repr(dtype)} with range [{dmin}, {dmax}] cannot store image values in the range [{imin}, {imax}]'
+ # check the datatype is allowed
+ if dtype not in _ALLOWED_DTYPES:
+ raise TypeError(f'The dtype: {repr(dtype)} is not allowed, must be one of: {list(_ALLOWED_DTYPES)}')
+ # return the min and max values
+ return imin, imax
+
+
+# ========================================================================= #
+# Image Helper Functions #
+# ========================================================================= #
+
+
+def torch_image_has_valid_range(tensor: torch.Tensor, check_mode: Optional[str] = None) -> bool:
+ """
+ Check that the range of values in the image is correct!
+ """
+ if check_mode not in {'error', 'warn', 'bool', None}:
+ raise KeyError(f'invalid check_mode: {repr(check_mode)}')
+ # get the range for the dtype
+ imin, imax = _check_image_dtype(tensor.dtype)
+ # get the values
+ m = tensor.amin().cpu().numpy()
+ M = tensor.amax().cpu().numpy()
+ if (m < imin) or (imax < M):
+ if check_mode == 'error':
+ raise ValueError(f'images value range: [{m}, {M}] is outside of the required range: [{imin}, {imax}], for dtype: {repr(tensor.dtype)}')
+ elif check_mode == 'warn':
+ warnings.warn(f'images value range: [{m}, {M}] is outside of the required range: [{imin}, {imax}], for dtype: {repr(tensor.dtype)}')
+ return False
+ return True
+
+
+@torch.no_grad()
+def torch_image_clamp(tensor: torch.Tensor, clamp_mode: str = 'warn') -> torch.Tensor:
+ """
+ Clamp the image based on the dtype
+ Valid `clamp_mode`s are {'warn', 'error', 'clamp'}
+ """
+ # check range of values
+ if clamp_mode in ('warn', 'error'):
+ torch_image_has_valid_range(tensor, check_mode=clamp_mode)
+ elif clamp_mode != 'clamp':
+ raise KeyError(f'invalid clamp mode: {repr(clamp_mode)}')
+ # get the range for the dtype
+ imin, imax = _check_image_dtype(tensor.dtype)
+ # clamp!
+ return torch.clamp(tensor, imin, imax)
+
+
+@torch.no_grad()
+def torch_image_to_dtype(tensor: torch.Tensor, out_dtype: torch.dtype):
+ """
+ Convert an image to the specified dtype
+ - Scaling is automatically performed based on the input and output dtype
+ Floats should be in the range [0, 1], integers should be in the range [0, 255]
+ - if precision will be lost (), then the values are clamped!
+ """
+ _check_image_dtype(tensor.dtype)
+ _check_image_dtype(out_dtype)
+ # check scale
+ torch_image_has_valid_range(tensor, check_mode='error')
+ # convert
+ if tensor.dtype.is_floating_point and (not out_dtype.is_floating_point):
+ # [float -> int] -- cast after scaling
+ return torch.clamp(tensor * 255, 0, 255).to(out_dtype)
+ elif (not tensor.dtype.is_floating_point) and out_dtype.is_floating_point:
+ # [int -> float] -- cast before scaling
+ return torch.clamp(tensor.to(out_dtype) / 255, 0, 1)
+ else:
+ # [int -> int] | [float -> float]
+ return tensor.to(out_dtype)
+
+
+@torch.no_grad()
+def torch_image_normalize_channels(
+ tensor: torch.Tensor,
+ in_min: MinMaxHint,
+ in_max: MinMaxHint,
+ channel_dim: int = -1,
+ out_dtype: Optional[torch.dtype] = None
+):
+ if out_dtype is None:
+ out_dtype = tensor.dtype
+ # check dtypes
+ _check_image_dtype(out_dtype)
+ assert out_dtype.is_floating_point, f'out_dtype must be a floating point, got: {repr(out_dtype)}'
+ # get norm values padded to the dimension of the channel
+ in_min, in_max = _torch_channel_broadcast_scale_values(in_min, in_max, in_dtype=tensor.dtype, dim=channel_dim, ndim=tensor.ndim)
+ # convert
+ tensor = tensor.to(out_dtype)
+ in_min = torch.as_tensor(in_min, dtype=tensor.dtype, device=tensor.device)
+ in_max = torch.as_tensor(in_max, dtype=tensor.dtype, device=tensor.device)
+ # warn if the values are the same
+ if torch.any(in_min == in_max):
+ m = in_min.cpu().detach().numpy()
+ M = in_min.cpu().detach().numpy()
+ warnings.warn(f'minimum: {m} and maximum: {M} values are the same, scaling values to zero.')
+ # handle equal values
+ divisor = in_max - in_min
+ divisor[divisor == 0] = 1
+ # normalize
+ return (tensor - in_min) / divisor
+
+
+# ========================================================================= #
+# Argument Helper #
+# ========================================================================= #
+
+
+# float16 doesnt always work, rather convert to float32 first
+_ALLOWED_DTYPES = {
+ torch.float32, torch.float64,
+ torch.uint8,
+ torch.int, torch.int16, torch.int32, torch.int64,
+ torch.long,
+}
+
+
+@lru_cache()
+def _torch_to_images_normalise_args(in_tensor_shape: Tuple[int, ...], in_tensor_dtype: torch.dtype, in_dims: str, out_dims: str, in_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype]):
+ # check types
+ if not isinstance(in_dims, str): raise TypeError(f'in_dims must be of type: {str}, but got: {type(in_dims)}')
+ if not isinstance(out_dims, str): raise TypeError(f'out_dims must be of type: {str}, but got: {type(out_dims)}')
+ # normalise dim names
+ in_dims = in_dims.upper()
+ out_dims = out_dims.upper()
+ # check dim values
+ if sorted(in_dims) != sorted('CHW'): raise KeyError(f'in_dims contains the symbols: {repr(in_dims)}, must contain only permutations of: {repr("CHW")}')
+ if sorted(out_dims) != sorted('CHW'): raise KeyError(f'out_dims contains the symbols: {repr(out_dims)}, must contain only permutations of: {repr("CHW")}')
+ # get dimension indices
+ in_c_dim = in_dims.index('C') - len(in_dims)
+ out_c_dim = out_dims.index('C') - len(out_dims)
+ transpose_indices = tuple(in_dims.index(c) - len(in_dims) for c in out_dims)
+ # check image tensor
+ if len(in_tensor_shape) < 3:
+ raise ValueError(f'images must have 3 or more dimensions corresponding to: (..., {", ".join(in_dims)}), but got shape: {in_tensor_shape}')
+ if in_tensor_shape[in_c_dim] not in (1, 3):
+ raise ValueError(f'images do not have the correct number of channels for dim "C", required: 1 or 3. Input format is (..., {", ".join(in_dims)}), but got shape: {in_tensor_shape}')
+ # get default values
+ if in_dtype is None: in_dtype = in_tensor_dtype
+ if out_dtype is None: out_dtype = in_dtype
+ # check dtypes allowed
+ if in_dtype not in _ALLOWED_DTYPES: raise TypeError(f'in_dtype is not allowed, got: {repr(in_dtype)} must be one of: {list(_ALLOWED_DTYPES)}')
+ if out_dtype not in _ALLOWED_DTYPES: raise TypeError(f'out_dtype is not allowed, got: {repr(out_dtype)} must be one of: {list(_ALLOWED_DTYPES)}')
+ # done!
+ return transpose_indices, in_dtype, out_dtype, out_c_dim
+
+
+def _torch_channel_broadcast_scale_values(
+ in_min: MinMaxHint,
+ in_max: MinMaxHint,
+ in_dtype: torch.dtype,
+ dim: int,
+ ndim: int,
+) -> Tuple[List[Number], List[Number]]:
+ return __torch_channel_broadcast_scale_values(
+ in_min=tuple(np.array(in_min).reshape(-1).tolist()), # TODO: this is slow?
+ in_max=tuple(np.array(in_max).reshape(-1).tolist()), # TODO: this is slow?
+ in_dtype=in_dtype,
+ dim=dim,
+ ndim=ndim,
+ )
+
+@lru_cache()
+@torch.no_grad()
+def __torch_channel_broadcast_scale_values(
+ in_min: MinMaxHint,
+ in_max: MinMaxHint,
+ in_dtype: torch.dtype,
+ dim: int,
+ ndim: int,
+) -> Tuple[List[Number], List[Number]]:
+ # get the default values
+ in_min: np.ndarray = np.array((0.0 if in_dtype.is_floating_point else 0.0) if (in_min is None) else in_min)
+ in_max: np.ndarray = np.array((1.0 if in_dtype.is_floating_point else 255.0) if (in_max is None) else in_max)
+ # add missing axes
+ if in_min.ndim == 0: in_min = in_min[None]
+ if in_max.ndim == 0: in_max = in_max[None]
+ # checks
+ assert in_min.ndim == 1
+ assert in_max.ndim == 1
+ assert np.all(in_min <= in_max), f'min values are not <= the max values: {in_min} !<= {in_max}'
+ # normalize dim
+ dim = normalize_axis_index(dim, ndim=ndim)
+ # pad dim
+ r_pad = ndim - (dim + 1)
+ if r_pad > 0:
+ in_min = in_min[(...,) + ((None,)*r_pad)]
+ in_max = in_max[(...,) + ((None,)*r_pad)]
+ # done!
+ return in_min.tolist(), in_max.tolist()
+
+
+# ========================================================================= #
+# Image Conversion #
+# ========================================================================= #
+
+
+@torch.no_grad()
+def torch_to_images(
+ tensor: torch.Tensor,
+ in_dims: str = 'CHW', # we always treat numpy by default as HWC, and torch.Tensor as CHW
+ out_dims: str = 'HWC',
+ in_dtype: Optional[torch.dtype] = None,
+ out_dtype: Optional[torch.dtype] = torch.uint8,
+ clamp_mode: str = 'warn', # clamp, warn, error
+ always_rgb: bool = False,
+ in_min: Optional[MinMaxHint] = None,
+ in_max: Optional[MinMaxHint] = None,
+ to_numpy: bool = False,
+) -> Union[torch.Tensor, np.ndarray]:
+ """
+ Convert a batch of image-like tensors to images.
+ A batch in this case consists of an arbitrary number of dimensions of a tensor,
+ with the last 3 dimensions making up the actual images.
+
+ Process:
+ 1. check input dtype
+ 2. move axis
+ 3. normalize
+ 4. clamp values
+ 5. auto scale and convert
+ 6. convert to rgb
+ 7. check output dtype
+
+ example:
+ Convert a tensor of non-normalised images (..., C, H, W) to a
+ tensor of normalised and clipped images (..., H, W, C).
+ - integer dtypes are expected to be in the range [0, 255]
+ - float dtypes are expected to be in the range [0, 1]
+
+ # TODO: add support for uneven in/out dims, eg. in_dims="HW", out_dims="HWC"
+ """
+ # 0.a. check tensor
+ if not isinstance(tensor, torch.Tensor):
+ raise TypeError(f'images must be of type: {torch.Tensor}, got: {type(tensor)}')
+ # 0.b. get arguments
+ transpose_indices, in_dtype, out_dtype, out_c_dim = _torch_to_images_normalise_args(
+ in_tensor_shape=tuple(tensor.shape),
+ in_tensor_dtype=tensor.dtype,
+ in_dims=in_dims,
+ out_dims=out_dims,
+ in_dtype=in_dtype,
+ out_dtype=out_dtype,
+ )
+ # 1. check input dtype
+ if in_dtype != tensor.dtype:
+ raise TypeError(f'images dtype: {repr(tensor.dtype)} does not match in_dtype: {repr(in_dtype)}')
+ # 2. move axes
+ tensor = tensor.permute(*(i-tensor.ndim for i in range(tensor.ndim-3)), *transpose_indices)
+ # 3. normalise
+ if (in_min is not None) or (in_max is not None):
+ norm_dtype = (out_dtype if out_dtype.is_floating_point else torch.float32)
+ tensor = torch_image_normalize_channels(tensor, in_min=in_min, in_max=in_max, channel_dim=out_c_dim, out_dtype=norm_dtype)
+ # 4. clamp
+ tensor = torch_image_clamp(tensor, clamp_mode=clamp_mode)
+ # 5. auto scale and convert
+ tensor = torch_image_to_dtype(tensor, out_dtype=out_dtype)
+ # 6. convert to rgb
+ if always_rgb:
+ if tensor.shape[out_c_dim] == 1:
+ tensor = torch.repeat_interleave(tensor, 3, dim=out_c_dim) # torch.repeat is like np.tile, torch.repeat_interleave is like np.repeat
+ # 7. check output dtype
+ if out_dtype != tensor.dtype:
+ raise RuntimeError(f'[THIS IS A BUG!]: After conversion, images tensor dtype: {repr(tensor.dtype)} does not match out_dtype: {repr(in_dtype)}')
+ if torch.any(torch.isnan(tensor)):
+ raise RuntimeError('[THIS IS A BUG!]: After conversion, images contain NaN values!')
+ # convert to numpy
+ if to_numpy:
+ return tensor.detach().cpu().numpy()
+ return tensor
+
+
+def numpy_to_images(
+ ndarray: np.ndarray,
+ in_dims: str = 'HWC', # we always treat numpy by default as HWC, and torch.Tensor as CHW
+ out_dims: str = 'HWC',
+ in_dtype: Optional[Union[str, np.dtype]] = None,
+ out_dtype: Optional[Union[str, np.dtype]] = np.dtype('uint8'),
+ clamp_mode: str = 'warn', # clamp, warn, error
+ always_rgb: bool = False,
+ in_min: Optional[MinMaxHint] = None,
+ in_max: Optional[MinMaxHint] = None,
+) -> np.ndarray:
+ """
+ Convert a batch of image-like arrays to images.
+ A batch in this case consists of an arbitrary number of dimensions of an array,
+ with the last 3 dimensions making up the actual images.
+ - See the docs for: torch_to_images(...)
+ """
+ # convert numpy dtypes to torch
+ if in_dtype is not None: in_dtype = _NP_TO_TORCH_DTYPE[np.dtype(in_dtype)]
+ if out_dtype is not None: out_dtype = _NP_TO_TORCH_DTYPE[np.dtype(out_dtype)]
+ # convert back
+ array = torch_to_images(
+ tensor=torch.from_numpy(ndarray),
+ in_dims=in_dims,
+ out_dims=out_dims,
+ in_dtype=in_dtype,
+ out_dtype=out_dtype,
+ clamp_mode=clamp_mode,
+ always_rgb=always_rgb,
+ in_min=in_min,
+ in_max=in_max,
+ to_numpy=True,
+ )
+ # done!
+ return array
+
+
+def numpy_to_pil_images(
+ ndarray: np.ndarray,
+ in_dims: str = 'HWC', # we always treat numpy by default as HWC, and torch.Tensor as CHW
+ clamp_mode: str = 'warn',
+ always_rgb: bool = False,
+ in_min: Optional[MinMaxHint] = None,
+ in_max: Optional[MinMaxHint] = None,
+) -> Union[np.ndarray]:
+ """
+ Convert a numpy array containing images (..., C, H, W) to an array of PIL images (...,)
+ """
+ imgs = numpy_to_images(
+ ndarray=ndarray,
+ in_dims=in_dims,
+ out_dims='HWC',
+ in_dtype=None,
+ out_dtype='uint8',
+ clamp_mode=clamp_mode,
+ always_rgb=always_rgb,
+ in_min=in_min,
+ in_max=in_max,
+ )
+ # all the cases (even ndim == 3)... bravo numpy, bravo!
+ images = [Image.fromarray(imgs[idx]) for idx in np.ndindex(imgs.shape[:-3])]
+ images = np.array(images, dtype=object).reshape(imgs.shape[:-3])
+ # done
+ return images
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/disent/util/visualize/vis_model.py b/disent/util/visualize/vis_latents.py
similarity index 69%
rename from disent/util/visualize/vis_model.py
rename to disent/util/visualize/vis_latents.py
index b7735ce9..72b88d25 100644
--- a/disent/util/visualize/vis_model.py
+++ b/disent/util/visualize/vis_latents.py
@@ -23,13 +23,13 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
import logging
+from typing import Callable
+
import numpy as np
import torch
from disent.util import to_numpy
from disent.util.visualize import vis_util
-from disent.util.visualize.vis_util import make_animated_image_grid
-from disent.util.visualize.vis_util import reconstructions_to_images
log = logging.getLogger(__name__)
@@ -46,6 +46,7 @@
# CHANGES: #
# - extracted from original code #
# - was not split into functions in this was #
+# - TODO: convert these functions to torch.Tensors #
# ========================================================================= #
@@ -98,56 +99,56 @@ def _z_minmax_interval_cycle(base_z, z_means, z_logvars, z_idx, num_frames):
}
+def make_latent_zs_cycle(
+ base_z: torch.Tensor,
+ z_means: torch.Tensor,
+ z_logvars: torch.Tensor,
+ z_idx: int,
+ num_frames: int,
+ mode: str = 'minmax_interval_cycle',
+) -> torch.Tensor:
+ # get mode
+ if mode not in _LATENT_CYCLE_MODES_MAP:
+ raise KeyError(f'Unsupported mode: {repr(mode)} not in {set(_LATENT_CYCLE_MODES_MAP)}')
+ z_gen_func = _LATENT_CYCLE_MODES_MAP[mode]
+ # checks
+ assert base_z.ndim == 1
+ assert base_z.shape == z_means.shape[1:]
+ assert z_means.ndim == z_logvars.ndim == 2
+ assert z_means.shape == z_logvars.shape
+ assert len(z_means) > 1, f'not enough representations to average, number of z_means should be greater than 1, got: {z_means.shape}'
+ # make cycle
+ z_cycle = z_gen_func(to_numpy(base_z), to_numpy(z_means), to_numpy(z_logvars), z_idx, num_frames)
+ return torch.from_numpy(z_cycle)
+
+
# ========================================================================= #
# Visualise Latent Cycles #
# ========================================================================= #
-# TODO: this function should not convert output to images, it should just be
-# left as is. That way we don't need to pass in the recon_min and recon_max
-def latent_cycle(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_animations=4, num_frames=20, decoder_device=None, recon_min=0., recon_max=1.) -> np.ndarray:
- assert len(z_means) > 1 and len(z_logvars) > 1, 'not enough samples to average'
- # convert
- z_means, z_logvars = to_numpy(z_means), to_numpy(z_logvars)
- # get mode
- if mode not in _LATENT_CYCLE_MODES_MAP:
- raise KeyError(f'Unsupported mode: {repr(mode)} not in {set(_LATENT_CYCLE_MODES_MAP)}')
- z_gen_func = _LATENT_CYCLE_MODES_MAP[mode]
+# TODO: this should be moved into the VAE and AE classes
+def make_decoded_latent_cycles(
+ decoder_func: Callable[[torch.Tensor], torch.Tensor],
+ z_means: torch.Tensor,
+ z_logvars: torch.Tensor,
+ mode: str = 'minmax_interval_cycle',
+ num_animations: int = 4,
+ num_frames: int = 20,
+ decoder_device=None,
+) -> torch.Tensor:
+ # generate multiple latent traversal visualisations
animations = []
- for i, base_z in enumerate(z_means[:num_animations]):
+ for i in range(num_animations):
frames = []
- for j in range(z_means.shape[1]):
- z = z_gen_func(base_z, z_means, z_logvars, j, num_frames)
+ for z_idx in range(z_means.shape[1]):
+ z = make_latent_zs_cycle(z_means[i], z_means, z_logvars, z_idx, num_frames, mode=mode)
z = torch.as_tensor(z, device=decoder_device)
frames.append(decoder_func(z))
- animations.append(frames)
- return reconstructions_to_images(animations, recon_min=recon_min, recon_max=recon_max)
-
-
-def latent_cycle_grid_animation(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_frames=21, pad=4, border=True, bg_color=0.5, decoder_device=None, tensor_style_channels=True, always_rgb=True, return_stills=False, to_uint8=False, recon_min=0., recon_max=1.) -> np.ndarray:
- # produce latent cycle animation & merge frames
- stills = latent_cycle(decoder_func, z_means, z_logvars, mode=mode, num_animations=1, num_frames=num_frames, decoder_device=decoder_device, recon_min=recon_min, recon_max=recon_max)[0]
- # check and add missing channel if needed (convert greyscale to rgb images)
- if always_rgb:
- assert stills.shape[-1] in {1, 3}, f'Invalid number of image channels: {stills.shape} ({stills.shape[-1]})'
- if stills.shape[-1] == 1:
- stills = np.repeat(stills, 3, axis=-1)
- # create animation
- frames = make_animated_image_grid(stills, pad=pad, border=border, bg_color=bg_color)
- # move channels to end
- if tensor_style_channels:
- if return_stills:
- stills = np.transpose(stills, [0, 1, 4, 2, 3])
- frames = np.transpose(frames, [0, 3, 1, 2])
- # convert to uint8
- if to_uint8:
- if return_stills:
- stills = np.clip(stills*255, 0, 255).astype('uint8')
- frames = np.clip(frames*255, 0, 255).astype('uint8')
- # done!
- if return_stills:
- return frames, stills
- return frames
+ animations.append(torch.stack(frames, dim=0))
+ animations = torch.stack(animations, dim=0)
+ # return everything
+ return animations # (num_animations, z_size, num_frames, C, H, W)
# ========================================================================= #
diff --git a/disent/util/visualize/vis_util.py b/disent/util/visualize/vis_util.py
index 7d46abe2..0385481a 100644
--- a/disent/util/visualize/vis_util.py
+++ b/disent/util/visualize/vis_util.py
@@ -24,7 +24,11 @@
import logging
import warnings
+from functools import lru_cache
from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
from typing import Union
import numpy as np
@@ -52,7 +56,7 @@
}
-def make_image_grid(images, pad=8, border=True, bg_color=None, num_cols=None):
+def make_image_grid(images: Sequence[np.ndarray], pad: int = 8, border: bool = True, bg_color=None, num_cols: Optional[int] = None):
"""
Convert a list of images into a single image that is a grid of those images.
:param images: list of input images, all the same size: (I, H, W, C) or (I, H, W)
@@ -65,12 +69,12 @@ def make_image_grid(images, pad=8, border=True, bg_color=None, num_cols=None):
# first, second, third channels are the (H, W, C)
# get image sizes
img_shape, ndim = np.array(images[0].shape), images[0].ndim
- assert ndim == 2 or ndim == 3, 'images have wrong number of channels'
+ assert ndim == 2 or ndim == 3, f'images have wrong number of channels: {img_shape}'
assert np.all(img_shape == img.shape for img in images), 'Images are not the same shape!'
# get image size and channels
img_size = img_shape[:2]
if ndim == 3:
- assert (img_shape[2] == 1) or (img_shape[2] == 3), 'Invalid number of channels for an image.'
+ assert (img_shape[2] == 1) or (img_shape[2] == 3), f'Invalid number of channels for an image: {img_shape}'
# get bg color
if bg_color is None:
bg_color = _BG_COLOR_DTYPE_MAP[images[0].dtype]
@@ -90,7 +94,7 @@ def make_image_grid(images, pad=8, border=True, bg_color=None, num_cols=None):
return grid
-def make_animated_image_grid(list_of_animated_images, pad=8, border=True, bg_color=None, num_cols=None):
+def make_animated_image_grid(list_of_animated_images: Sequence[np.ndarray], pad: int = 8, border: bool = True, bg_color=None, num_cols: Optional[int] = None):
"""
:param list_of_animated_images: list of input images, with the second dimension the number of frames: : (I, F, H, W, C) or (I, F, H, W)
:param pad: the number of pixels between images
@@ -114,7 +118,7 @@ def make_animated_image_grid(list_of_animated_images, pad=8, border=True, bg_col
# ========================================================================= #
-def _get_grid_size(n, num_cols=None):
+def _get_grid_size(n: int, num_cols: Optional[int] = None):
"""
Determine the number of rows and columns, given the total number of elements n.
- if num_cols is None: rows x cols is as square as possible
@@ -135,7 +139,7 @@ def _get_grid_size(n, num_cols=None):
# ========================================================================= #
-def _get_interval_factor_traversal(factor_size, num_frames, start_index=0):
+def _get_interval_factor_traversal(factor_size: int, num_frames: int, start_index: int = 0):
"""
Cycles through the state space in a single cycle.
eg. num_indices=5, num_frames=7 returns: [0,1,1,2,3,3,4]
@@ -147,29 +151,51 @@ def _get_interval_factor_traversal(factor_size, num_frames, start_index=0):
return grid
-def _get_cycle_factor_traversal(factor_size, num_frames):
+def _get_cycle_factor_traversal(factor_size: int, num_frames: int, start_index: int = 0):
"""
Cycles through the state space in a single cycle.
eg. num_indices=5, num_frames=7 returns: [0,1,3,4,3,2,1]
eg. num_indices=4, num_frames=7 returns: [0,1,2,3,2,2,0]
"""
- grid = _get_interval_factor_traversal(factor_size=factor_size, num_frames=num_frames)
+ assert start_index == 0, f'cycle factor traversal mode only supports start_index==0, got: {repr(start_index)}'
+ grid = _get_interval_factor_traversal(factor_size=factor_size, num_frames=num_frames, start_index=0)
grid = np.concatenate([grid[0::2], grid[1::2][::-1]])
return grid
+def _get_cycle_factor_traversal_from_start(factor_size: int, num_frames: int, start_index: int = 0, ends: bool = False):
+ all_idxs = np.array([
+ *range(start_index, factor_size - (1 if ends else 0)),
+ *reversed(range(0, factor_size)),
+ *range(1 if ends else 0, start_index),
+ ])
+ selected_idxs = _get_interval_factor_traversal(factor_size=len(all_idxs), num_frames=num_frames, start_index=0)
+ grid = all_idxs[selected_idxs]
+ # checks
+ assert all_idxs[0] == start_index, 'Please report this bug!'
+ assert grid[0] == start_index, 'Please report this bug!'
+ return grid
+
+
+def _get_cycle_factor_traversal_from_start_ends(factor_size: int, num_frames: int, start_index: int = 0, ends: bool = True):
+ return _get_cycle_factor_traversal_from_start(factor_size=factor_size, num_frames=num_frames, start_index=start_index, ends=ends)
+
+
+
_FACTOR_TRAVERSALS = {
'interval': _get_interval_factor_traversal,
'cycle': _get_cycle_factor_traversal,
+ 'cycle_from_start': _get_cycle_factor_traversal_from_start,
+ 'cycle_from_start_ends': _get_cycle_factor_traversal_from_start_ends,
}
-def get_idx_traversal(factor_size, num_frames, mode='interval'):
+def get_idx_traversal(factor_size: int, num_frames: int, mode: str = 'interval', start_index: int = 0):
try:
traversal_fn = _FACTOR_TRAVERSALS[mode]
except KeyError:
raise KeyError(f'Invalid factor traversal mode: {repr(mode)}')
- return traversal_fn(factor_size=factor_size, num_frames=num_frames)
+ return traversal_fn(factor_size=factor_size, num_frames=num_frames, start_index=start_index)
# ========================================================================= #
@@ -213,295 +239,5 @@ def cycle_interval(starting_value, num_frames, min_val, max_val):
# ========================================================================= #
-# Conversion/Util #
-# ========================================================================= #
-
-
-# TODO: this functionality is duplicated elsewhere!
-# TODO: similar functions exist: output_image, to_img, to_imgs, reconstructions_to_images
-def reconstructions_to_images(
- recon,
- mode: str = 'float',
- moveaxis: bool = True,
- recon_min: Union[float, List[float]] = 0.0,
- recon_max: Union[float, List[float]] = 1.0,
- warn_if_clipped: bool = True,
-) -> Union[np.ndarray, Image.Image]:
- """
- Convert a batch of reconstructions to images.
- A batch in this case consists of an arbitrary number of dimensions of an array,
- with the last 3 dimensions making up the actual image. For example: (..., channels, size, size)
-
- NOTE: This function might not be efficient for large amounts of
- data due to assertions and initial recursive conversions to a numpy array.
-
- NOTE: kornia has a similar function!
- """
- img = to_numpy(recon)
- # checks
- assert img.ndim >= 3
- assert img.dtype in (np.float32, np.float64)
- # move channels axis
- if moveaxis:
- img = np.moveaxis(img, -3, -1)
- # check min and max
- recon_min = np.array(recon_min)
- recon_max = np.array(recon_max)
- assert recon_min.shape == recon_max.shape
- assert recon_min.ndim in (0, 1) # supports channels or glbal min . max
- # scale image
- img = (img - recon_min) / (recon_max - recon_min)
- # check image bounds
- if warn_if_clipped:
- m, M = np.min(img), np.max(img)
- if m < 0 or M > 1:
- log.warning(f'images with range [{m}, {M}] have been clipped to the range [0, 1]')
- # do clipping
- img = np.clip(img, 0, 1)
- # convert
- if mode == 'float':
- return img
- elif mode == 'int':
- return np.uint8(img * 255)
- elif mode == 'pil':
- img = np.uint8(img * 255)
- # all the cases (even ndim == 3)... bravo numpy, bravo!
- images = [Image.fromarray(img[idx]) for idx in np.ndindex(img.shape[:-3])]
- images = np.array(images, dtype=object).reshape(img.shape[:-3])
- return images
- else:
- raise KeyError(f'Invalid mode: {repr(mode)} not in { {"float", "int", "pil"} }')
-
-
-# ========================================================================= #
-# Image Util #
-# ========================================================================= #
-
-
-# TODO: something like this should replace reconstructions_to_images above!
-# def torch_image_clamp(tensor: torch.Tensor, clamp_mode='warn') -> torch.Tensor:
-# # get dtype max value
-# dtype_M = 1 if tensor.dtype.is_floating_point else 255
-# # handle different modes
-# if clamp_mode in ('warn', 'error'):
-# m, M = tensor.min().cpu().numpy(), tensor.max().cpu().numpy()
-# if (0 < m) or (M > dtype_M):
-# if clamp_mode == 'warn':
-# warnings.warn(f'image values are out of bounds, expected values in the range: [0, {dtype_M}], received values in the range: {[m, M]}')
-# else:
-# raise ValueError(f'image values are out of bounds, expected values in the range: [0, {dtype_M}], received values in the range: {[m, M]}')
-# elif clamp_mode != 'clamp':
-# raise KeyError(f'invalid clamp mode: {repr(clamp_mode)}')
-# # clamp values
-# return torch.clamp(tensor, 0, dtype_M)
-#
-#
-# # float16 doesnt always work, rather convert to float32 first
-# _ALLOWED_DTYPES = {
-# torch.float32, torch.float64,
-# torch.uint8,
-# torch.int, torch.int16, torch.int32, torch.int64,
-# torch.long,
-# }
-#
-#
-# @lru_cache()
-# def _torch_to_images_normalise_args(in_tensor_shape: Tuple[int, ...], in_tensor_dtype: torch.dtype, in_dims: str, out_dims: str, in_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype]):
-# # check types
-# if not isinstance(in_dims, str): raise TypeError(f'in_dims must be of type: {str}, but got: {type(in_dims)}')
-# if not isinstance(out_dims, str): raise TypeError(f'out_dims must be of type: {str}, but got: {type(out_dims)}')
-# # normalise dim names
-# in_dims = in_dims.upper()
-# out_dims = out_dims.upper()
-# # check dim values
-# if sorted(in_dims) != sorted('CHW'): raise KeyError(f'in_dims contains the symbols: {repr(in_dims)}, must contain only permutations of: {repr("CHW")}')
-# if sorted(out_dims) != sorted('CHW'): raise KeyError(f'out_dims contains the symbols: {repr(out_dims)}, must contain only permutations of: {repr("CHW")}')
-# # get dimension indices
-# in_c_dim = in_dims.index('C') - len(in_dims)
-# transpose_indices = tuple(in_dims.index(c) - len(in_dims) for c in out_dims)
-# # check image tensor
-# if len(in_tensor_shape) < 3:
-# raise ValueError(f'images must have 3 or more dimensions corresponding to: (..., {", ".join(in_dims)}), but got shape: {in_tensor_shape}')
-# if in_tensor_shape[in_c_dim] not in (1, 3):
-# raise ValueError(f'images do not have the correct number of channels for dim "C", required: 1 or 3. Input format is (..., {", ".join(in_dims)}), but got shape: {in_tensor_shape}')
-# # get default values
-# if in_dtype is None: in_dtype = in_tensor_dtype
-# if out_dtype is None: out_dtype = in_dtype
-# # check dtypes allowed
-# if in_dtype not in _ALLOWED_DTYPES: raise TypeError(f'in_dtype is not allowed, got: {repr(in_dtype)} must be one of: {list(_ALLOWED_DTYPES)}')
-# if out_dtype not in _ALLOWED_DTYPES: raise TypeError(f'out_dtype is not allowed, got: {repr(out_dtype)} must be one of: {list(_ALLOWED_DTYPES)}')
-# # done!
-# return transpose_indices, in_dtype, out_dtype
-#
-#
-# def torch_to_images(
-# tensor: torch.Tensor,
-# in_dims: str = 'CHW',
-# out_dims: str = 'HWC',
-# in_dtype: Optional[torch.dtype] = None,
-# out_dtype: Optional[torch.dtype] = torch.uint8,
-# clamp_mode: str = 'warn', # clamp, warn, error
-# ) -> torch.Tensor:
-# """
-# Convert a batch of image-like tensors to images.
-# A batch in this case consists of an arbitrary number of dimensions of a tensor,
-# with the last 3 dimensions making up the actual images.
-#
-# example:
-# Convert a tensor of non-normalised images (..., C, H, W) to a
-# tensor of normalised and clipped images (..., H, W, C).
-# - integer dtypes are expected to be in the range [0, 255]
-# - float dtypes are expected to be in the range [0, 1]
-# """
-# if not isinstance(tensor, torch.Tensor):
-# raise TypeError(f'images must be of type: {torch.Tensor}, got: {type(tensor)}')
-# # check arguments
-# transpose_indices, in_dtype, out_dtype = _torch_to_images_normalise_args(
-# in_tensor_shape=tuple(tensor.shape), in_tensor_dtype=tensor.dtype,
-# in_dims=in_dims, out_dims=out_dims,
-# in_dtype=in_dtype, out_dtype=out_dtype,
-# )
-# # check inputs
-# if in_dtype != tensor.dtype:
-# raise TypeError(f'images dtype: {repr(tensor.dtype)} does not match in_dtype: {repr(in_dtype)}')
-# # convert images
-# with torch.no_grad():
-# # check that input values are in the correct range
-# # move axes
-# tensor = tensor.permute(*(i-tensor.ndim for i in range(tensor.ndim-3)), *transpose_indices)
-# # convert outputs
-# if in_dtype != out_dtype:
-# if in_dtype.is_floating_point and (not out_dtype.is_floating_point):
-# tensor = (tensor * 255).to(out_dtype)
-# elif (not in_dtype.is_floating_point) and out_dtype.is_floating_point:
-# tensor = tensor.to(out_dtype) / 255
-# else:
-# tensor = tensor.to(out_dtype)
-# # clamp
-# tensor = torch_image_clamp(tensor, clamp_mode=clamp_mode)
-# # check outputs
-# if out_dtype != tensor.dtype: # pragma: no cover
-# raise RuntimeError(f'[THIS IS A BUG! PLEASE REPORT THIS!]: After conversion, images tensor dtype: {repr(tensor.dtype)} does not match out_dtype: {repr(in_dtype)}')
-# # done
-# return tensor
-#
-#
-# def numpy_to_images(
-# ndarray: np.ndarray,
-# in_dims: str = 'CHW',
-# out_dims: str = 'HWC',
-# in_dtype: Optional[str, np.dtype] = None,
-# out_dtype: Optional[str, np.dtype] = np.dtype('uint8'),
-# clamp_mode: str = 'warn', # clamp, warn, error
-# ) -> np.ndarray:
-# """
-# Convert a batch of image-like arrays to images.
-# A batch in this case consists of an arbitrary number of dimensions of an array,
-# with the last 3 dimensions making up the actual images.
-# - See the docs for: torch_to_imgs(...)
-# """
-# # convert numpy dtypes to torch
-# if in_dtype is not None: in_dtype = getattr(torch, np.dtype(in_dtype).name)
-# if out_dtype is not None: out_dtype = getattr(torch, np.dtype(out_dtype).name)
-# # convert back
-# tensor = torch_to_images(tensor=torch.from_numpy(ndarray), in_dims=in_dims, out_dims=out_dims, in_dtype=in_dtype, out_dtype=out_dtype, clamp_mode=clamp_mode)
-# # done!
-# return tensor.numpy()
-#
-#
-# def numpy_to_pil_images(ndarray: np.ndarray, in_dims: str = 'CHW', clamp_mode: str = 'warn'):
-# """
-# Convert a numpy array containing images (..., C, H, W) to an array of PIL images (...,)
-# """
-# imgs = numpy_to_images(ndarray=ndarray, in_dims=in_dims, out_dims='HWC', in_dtype=None, out_dtype='uint8', clamp_mode=clamp_mode)
-# # all the cases (even ndim == 3)... bravo numpy, bravo!
-# images = [Image.fromarray(imgs[idx]) for idx in np.ndindex(imgs.shape[:-3])]
-# images = np.array(images, dtype=object).reshape(imgs.shape[:-3])
-# # done
-# return images
-#
-#
-# def test_torch_to_imgs():
-# inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32)
-# inp_uint8 = (inp_float * 127 + 63).to(torch.uint8)
-# # check runs
-# out = torch_to_imgs(inp_float)
-# assert out.dtype == torch.uint8
-# out = torch_to_imgs(inp_uint8)
-# assert out.dtype == torch.uint8
-# out = torch_to_imgs(inp_float, in_dtype=None, out_dtype=None)
-# assert out.dtype == inp_float.dtype
-# out = torch_to_imgs(inp_uint8, in_dtype=None, out_dtype=None)
-# assert out.dtype == inp_uint8.dtype
-#
-#
-# def test_torch_to_imgs_permutations():
-# inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32)
-# inp_uint8 = (inp_float * 127 + 63).to(torch.uint8)
-#
-# # general checks
-# def check_all(inputs, in_dtype=None):
-# float_results, int_results = [], []
-# for out_dtype in _ALLOWED_DTYPES:
-# out = torch_to_imgs(inputs, in_dtype=in_dtype, out_dtype=out_dtype)
-# (float_results if out_dtype.is_floating_point else int_results).append(torch.stack([
-# out.min().to(torch.float64), out.max().to(torch.float64), out.mean(dtype=torch.float64)
-# ]))
-# for a, b in zip(float_results[:-1], float_results[1:]): assert torch.allclose(a, b)
-# for a, b in zip(int_results[:-1], int_results[1:]): assert torch.allclose(a, b)
-#
-# # check type permutations
-# check_all(inp_float, torch.float32)
-# check_all(inp_uint8, torch.uint8)
-#
-#
-# def test_torch_to_imgs_preserve_type():
-# for dtype in _ALLOWED_DTYPES:
-# tensor = (torch.rand(8, 3, 64, 64) * (1 if dtype.is_floating_point else 255)).to(dtype)
-# out = torch_to_imgs(tensor, in_dtype=dtype, out_dtype=dtype, clamp=True)
-# assert out.dtype == dtype
-#
-#
-# def test_torch_to_imgs_args():
-# inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32)
-#
-# # check tensor
-# with pytest.raises(TypeError, match="images tensor must be of type"):
-# torch_to_imgs(tensor=None)
-# with pytest.raises(ValueError, match='dim "C", required: 1 or 3'):
-# torch_to_imgs(tensor=torch.rand(8, 2, 16, 16, dtype=torch.float32))
-# with pytest.raises(ValueError, match='dim "C", required: 1 or 3'):
-# torch_to_imgs(tensor=torch.rand(8, 16, 16, 3, dtype=torch.float32))
-# with pytest.raises(ValueError, match='images tensor must have 3 or more dimensions corresponding to'):
-# torch_to_imgs(tensor=torch.rand(16, 16, dtype=torch.float32))
-#
-# # check dims
-# with pytest.raises(TypeError, match="in_dims must be of type"):
-# torch_to_imgs(inp_float, in_dims=None)
-# with pytest.raises(TypeError, match="out_dims must be of type"):
-# torch_to_imgs(inp_float, out_dims=None)
-# with pytest.raises(KeyError, match="in_dims contains the symbols: 'INVALID', must contain only permutations of: 'CHW'"):
-# torch_to_imgs(inp_float, in_dims='INVALID')
-# with pytest.raises(KeyError, match="out_dims contains the symbols: 'INVALID', must contain only permutations of: 'CHW'"):
-# torch_to_imgs(inp_float, out_dims='INVALID')
-# with pytest.raises(KeyError, match="in_dims contains the symbols: 'CHWW', must contain only permutations of: 'CHW'"):
-# torch_to_imgs(inp_float, in_dims='CHWW')
-# with pytest.raises(KeyError, match="out_dims contains the symbols: 'CHWW', must contain only permutations of: 'CHW'"):
-# torch_to_imgs(inp_float, out_dims='CHWW')
-#
-# # check dtypes
-# with pytest.raises(TypeError, match="images tensor dtype: torch.float32 does not match in_dtype: torch.uint8"):
-# torch_to_imgs(inp_float, in_dtype=torch.uint8)
-# with pytest.raises(TypeError, match='in_dtype is not allowed'):
-# torch_to_imgs(inp_float, in_dtype=torch.complex64)
-# with pytest.raises(TypeError, match='out_dtype is not allowed'):
-# torch_to_imgs(inp_float, out_dtype=torch.complex64)
-# with pytest.raises(TypeError, match='in_dtype is not allowed'):
-# torch_to_imgs(inp_float, in_dtype=torch.float16)
-# with pytest.raises(TypeError, match='out_dtype is not allowed'):
-# torch_to_imgs(inp_float, out_dtype=torch.float16)
-
-
-# ========================================================================= #
-# END #
+# END #
# ========================================================================= #
diff --git a/docs/img/traversals/traversal-transpose__cars3d.jpg b/docs/img/traversals/traversal-transpose__cars3d.jpg
new file mode 100644
index 00000000..80bac07e
Binary files /dev/null and b/docs/img/traversals/traversal-transpose__cars3d.jpg differ
diff --git a/docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-100.jpg b/docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-100.jpg
new file mode 100644
index 00000000..7db78492
Binary files /dev/null and b/docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-100.jpg differ
diff --git a/docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-100.jpg b/docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-100.jpg
new file mode 100644
index 00000000..1dae74d1
Binary files /dev/null and b/docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-100.jpg differ
diff --git a/docs/img/traversals/traversal-transpose__dsprites.jpg b/docs/img/traversals/traversal-transpose__dsprites.jpg
new file mode 100644
index 00000000..58a292b7
Binary files /dev/null and b/docs/img/traversals/traversal-transpose__dsprites.jpg differ
diff --git a/docs/img/traversals/traversal-transpose__shapes3d.jpg b/docs/img/traversals/traversal-transpose__shapes3d.jpg
new file mode 100644
index 00000000..b8cc2a2f
Binary files /dev/null and b/docs/img/traversals/traversal-transpose__shapes3d.jpg differ
diff --git a/docs/img/traversals/traversal-transpose__smallnorb.jpg b/docs/img/traversals/traversal-transpose__smallnorb.jpg
new file mode 100644
index 00000000..80043344
Binary files /dev/null and b/docs/img/traversals/traversal-transpose__smallnorb.jpg differ
diff --git a/docs/img/traversals/traversal-transpose__xy-object-shaded.jpg b/docs/img/traversals/traversal-transpose__xy-object-shaded.jpg
new file mode 100644
index 00000000..4204acf6
Binary files /dev/null and b/docs/img/traversals/traversal-transpose__xy-object-shaded.jpg differ
diff --git a/docs/img/traversals/traversal-transpose__xy-object.jpg b/docs/img/traversals/traversal-transpose__xy-object.jpg
new file mode 100644
index 00000000..c50ae6d6
Binary files /dev/null and b/docs/img/traversals/traversal-transpose__xy-object.jpg differ
diff --git a/docs/img/xy-object-traversal.png b/docs/img/xy-object-traversal.png
deleted file mode 100644
index a26524ac..00000000
Binary files a/docs/img/xy-object-traversal.png and /dev/null differ
diff --git a/experiment/config/augment/basic.yaml b/experiment/config/augment/basic.yaml
deleted file mode 100644
index 2f44ce41..00000000
--- a/experiment/config/augment/basic.yaml
+++ /dev/null
@@ -1,43 +0,0 @@
-name: basic
-
-augment_cls:
- _target_: torchvision.transforms.RandomOrder
- transforms:
- - _target_: kornia.augmentation.ColorJitter
- p: 0.5
- brightness: 0.25
- contrast: 0.25
- saturation: 0
- hue: 0.15
-
-# THIS IS BUGGY ON BATCH
-# - _target_: kornia.augmentation.RandomSharpness
-# p: 0.5
-# sharpness: 0.5
-
- - _target_: kornia.augmentation.RandomCrop
- p: 0.5
- size: [64, 64]
- padding: 8
-
- - _target_: kornia.augmentation.RandomPerspective
- p: 0.5
- distortion_scale: 0.15
-
- - _target_: kornia.augmentation.RandomRotation
- p: 0.5
- degrees: 9
-
-# - _target_: kornia.augmentation.RandomResizedCrop
-# p: 0.5
-# size: [64, 64]
-# scale: [0.95, 1.05]
-# ratio: [0.95, 1.05]
-
-# THIS REPLACES MOST OF THE ABOVE BUT IT IS BUGGY ON BATCH
-# - _target_: kornia.augmentation.RandomAffine
-# p: 0.5
-# degrees: 10
-# translate: [0.14, 0.14]
-# scale: [0.95, 1.05]
-# shear: 5
diff --git a/experiment/config/augment/example.yaml b/experiment/config/augment/example.yaml
new file mode 100644
index 00000000..8c4ce29e
--- /dev/null
+++ b/experiment/config/augment/example.yaml
@@ -0,0 +1,8 @@
+name: basic
+
+augment_cls:
+ _target_: torchvision.transforms.ColorJitter
+ brightness: 0.1
+ contrast: 0.1
+ saturation: 0.1
+ hue: 0.1
diff --git a/experiment/config/config.yaml b/experiment/config/config.yaml
index bb62ad0c..8267d491 100644
--- a/experiment/config/config.yaml
+++ b/experiment/config/config.yaml
@@ -1,4 +1,5 @@
defaults:
+ - _self_ # defaults lists override entries from this file!
# data
- sampling: default__bb
- dataset: xyobject
@@ -18,8 +19,8 @@ defaults:
- run_location: local
- run_launcher: local
- run_action: train
- # entries in this file override entries from default lists
- - _self_
+ # experiment
+ - run_plugins: default
settings:
job:
@@ -27,21 +28,20 @@ settings:
project: 'DELETE'
name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}'
seed: NULL
-
framework:
beta: 0.0316
recon_loss: mse
- loss_reduction: mean # TODO: this should be renamed to `loss_mode="scaled"` or `enable_loss_scaling=True"` or `enable_beta_scaling`
-
+ loss_reduction: mean # beta scaling
framework_opt:
latent_distribution: normal # only used by VAEs
-
model:
z_size: 25
weight_init: 'xavier_normal' # xavier_normal, default
-
dataset:
batch_size: 256
-
optimizer:
lr: 1e-3
+# TODO: https://pytorch-lightning.readthedocs.io/en/stable/common/weights_loading.html
+# checkpoint:
+# load_checkpoint: NULL # NULL or string
+# save_checkpoint: FALSE # boolean, save at end of run -- more advanced checkpointing can be done with a callback!
diff --git a/experiment/config/config_test.yaml b/experiment/config/config_test.yaml
index 63a66c8a..a24df0c6 100644
--- a/experiment/config/config_test.yaml
+++ b/experiment/config/config_test.yaml
@@ -1,11 +1,12 @@
defaults:
+ - _self_ # defaults lists override entries from this file!
# data
- sampling: default__bb
- dataset: xyobject
- - augment: none
+ - augment: example
# system
- framework: betavae
- - model: vae_conv64
+ - model: linear
# training
- optimizer: adam
- schedule: beta_cyclic
@@ -18,8 +19,8 @@ defaults:
- run_location: local_cpu
- run_launcher: local
- run_action: train
- # entries in this file override entries from default lists
- - _self_
+ # experiment
+ - run_plugins: default
settings:
job:
@@ -27,21 +28,16 @@ settings:
project: 'invalid'
name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}'
seed: NULL
-
framework:
beta: 0.0316
recon_loss: mse
- loss_reduction: mean # TODO: this should be renamed to `loss_mode="scaled"` or `enable_loss_scaling=True"` or `enable_beta_scaling`
-
+ loss_reduction: mean # beta scaling
framework_opt:
latent_distribution: normal # only used by VAEs
-
model:
z_size: 25
weight_init: 'xavier_normal' # xavier_normal, default
-
dataset:
batch_size: 5
-
optimizer:
lr: 1e-3
diff --git a/experiment/config/dataset/cars3d.yaml b/experiment/config/dataset/cars3d.yaml
index 9cb9055e..99e04de6 100644
--- a/experiment/config/dataset/cars3d.yaml
+++ b/experiment/config/dataset/cars3d.yaml
@@ -4,13 +4,12 @@ defaults:
name: cars3d
data:
- _target_: disent.dataset.data.Cars3dData
+ _target_: disent.dataset.data.Cars3d64Data
data_root: ${dsettings.storage.data_root}
prepare: ${dsettings.dataset.prepare}
transform:
_target_: disent.dataset.transform.ToImgTensorF32
- size: 64
mean: ${dataset.meta.vis_mean}
std: ${dataset.meta.vis_std}
diff --git a/experiment/config/dataset/monte_rollouts.yaml b/experiment/config/dataset/monte_rollouts.yaml
deleted file mode 100644
index 682eeb54..00000000
--- a/experiment/config/dataset/monte_rollouts.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-defaults:
- - _data_type_: episodes
-
-name: monte_rollouts
-
-data:
- _target_: disent.dataset.data.EpisodesDownloadZippedPickledData
- required_file: ${dsettings.storage.data_root}/episodes/monte.pkl
- download_url: 'https://raw.githubusercontent.com/nmichlo/uploads/main/monte_key.tar.xz'
- prepare: ${dsettings.dataset.prepare}
-
-transform:
- _target_: disent.dataset.transform.ToImgTensorF32
- size: [64, 64] # slightly squashed?
- mean: ${dataset.meta.vis_mean}
- std: ${dataset.meta.vis_std}
-
-meta:
- x_shape: [3, 64, 64] # [3, 210, 160]
- vis_mean: "${exit:EXITING... please compute the vis_mean and vis_std}"
- vis_std: "${exit:EXITING... please compute the vis_mean and vis_std}"
diff --git a/experiment/config/dataset/smallnorb.yaml b/experiment/config/dataset/smallnorb.yaml
index 9dfbb8ec..6ffa56fc 100644
--- a/experiment/config/dataset/smallnorb.yaml
+++ b/experiment/config/dataset/smallnorb.yaml
@@ -4,14 +4,13 @@ defaults:
name: smallnorb
data:
- _target_: disent.dataset.data.SmallNorbData
+ _target_: disent.dataset.data.SmallNorb64Data
data_root: ${dsettings.storage.data_root}
prepare: ${dsettings.dataset.prepare}
is_test: False
transform:
_target_: disent.dataset.transform.ToImgTensorF32
- size: 64
mean: ${dataset.meta.vis_mean}
std: ${dataset.meta.vis_std}
diff --git a/experiment/config/framework/_input_mode_/triple.yaml b/experiment/config/framework/_input_mode_/triplet.yaml
similarity index 79%
rename from experiment/config/framework/_input_mode_/triple.yaml
rename to experiment/config/framework/_input_mode_/triplet.yaml
index 40f44980..13d895bb 100644
--- a/experiment/config/framework/_input_mode_/triple.yaml
+++ b/experiment/config/framework/_input_mode_/triplet.yaml
@@ -1,3 +1,3 @@
# controlled by the framework's defaults list
-name: triple
+name: triplet
num: 3
diff --git a/experiment/config/framework/adagvae_minimal_os.yaml b/experiment/config/framework/adagvae_minimal_os.yaml
index e7bcbf13..e5631398 100644
--- a/experiment/config/framework/adagvae_minimal_os.yaml
+++ b/experiment/config/framework/adagvae_minimal_os.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Beta-VAE
beta: ${settings.framework.beta}
diff --git a/experiment/config/framework/adavae.yaml b/experiment/config/framework/adavae.yaml
index d234dcbf..b0c77133 100644
--- a/experiment/config/framework/adavae.yaml
+++ b/experiment/config/framework/adavae.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Beta-VAE
beta: ${settings.framework.beta}
# adavae
diff --git a/experiment/config/framework/adavae_os.yaml b/experiment/config/framework/adavae_os.yaml
index d8054386..f9225746 100644
--- a/experiment/config/framework/adavae_os.yaml
+++ b/experiment/config/framework/adavae_os.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Beta-VAE
beta: ${settings.framework.beta}
# adavae
diff --git a/experiment/config/framework/ae.yaml b/experiment/config/framework/ae.yaml
index cd410b6b..c1fc7f27 100644
--- a/experiment/config/framework/ae.yaml
+++ b/experiment/config/framework/ae.yaml
@@ -9,7 +9,7 @@ cfg:
recon_loss: ${settings.framework.recon_loss}
loss_reduction: ${settings.framework.loss_reduction}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
diff --git a/experiment/config/framework/betatcvae.yaml b/experiment/config/framework/betatcvae.yaml
index 32005237..ed4ccf73 100644
--- a/experiment/config/framework/betatcvae.yaml
+++ b/experiment/config/framework/betatcvae.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Beta-TcVae
beta: ${settings.framework.beta}
diff --git a/experiment/config/framework/betavae.yaml b/experiment/config/framework/betavae.yaml
index eb0f1540..031992cb 100644
--- a/experiment/config/framework/betavae.yaml
+++ b/experiment/config/framework/betavae.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Beta-VAE
beta: ${settings.framework.beta}
diff --git a/experiment/config/framework/dfcvae.yaml b/experiment/config/framework/dfcvae.yaml
index 9a242f1d..6426366c 100644
--- a/experiment/config/framework/dfcvae.yaml
+++ b/experiment/config/framework/dfcvae.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Beta-VAE
beta: ${settings.framework.beta}
# dfcvae
diff --git a/experiment/config/framework/dipvae.yaml b/experiment/config/framework/dipvae.yaml
index 4efebcf4..8bd42c05 100644
--- a/experiment/config/framework/dipvae.yaml
+++ b/experiment/config/framework/dipvae.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Beta-VAE
beta: ${settings.framework.beta}
# DIP-VAE
diff --git a/experiment/config/framework/infovae.yaml b/experiment/config/framework/infovae.yaml
index e9b8234b..3867b2d4 100644
--- a/experiment/config/framework/infovae.yaml
+++ b/experiment/config/framework/infovae.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Info-VAE
# info vae is not based off beta vae, but with
# the correct parameter choice this can equal the beta vae
diff --git a/experiment/config/framework/tae.yaml b/experiment/config/framework/tae.yaml
index 39fffd61..c5279846 100644
--- a/experiment/config/framework/tae.yaml
+++ b/experiment/config/framework/tae.yaml
@@ -9,7 +9,7 @@ cfg:
recon_loss: ${settings.framework.recon_loss}
loss_reduction: ${settings.framework.loss_reduction}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
# tvae: triplet stuffs
diff --git a/experiment/config/framework/tvae.yaml b/experiment/config/framework/tvae.yaml
index 601c52d8..baf9d5a4 100644
--- a/experiment/config/framework/tvae.yaml
+++ b/experiment/config/framework/tvae.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
# Beta-VAE
beta: ${settings.framework.beta}
# tvae: triplet stuffs
diff --git a/experiment/config/framework/vae.yaml b/experiment/config/framework/vae.yaml
index 31ec83c4..4c4d022f 100644
--- a/experiment/config/framework/vae.yaml
+++ b/experiment/config/framework/vae.yaml
@@ -11,11 +11,10 @@ cfg:
# base vae
latent_distribution: ${settings.framework_opt.latent_distribution}
# disable various components
- disable_decoder: FALSE
+ detach_decoder: FALSE
disable_reg_loss: FALSE
disable_rec_loss: FALSE
disable_aug_loss: FALSE
- disable_posterior_scale: NULL
meta:
model_z_multiplier: 2
diff --git a/experiment/config/metrics/standard.yaml b/experiment/config/metrics/standard.yaml
deleted file mode 100644
index 49ba02de..00000000
--- a/experiment/config/metrics/standard.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-metric_list:
- - mig: {}
- - sap: {}
- - dci:
- every_n_steps: 7200
- on_final: TRUE
- - factor_vae:
- every_n_steps: 7200
- on_final: TRUE
-
-# these are the default settings, these can be placed in the list above
-default_on_final: TRUE
-default_on_train: TRUE
-default_every_n_steps: 2400
-default_begin_first_step: FALSE
diff --git a/experiment/config/model/linear.yaml b/experiment/config/model/linear.yaml
index 30e1ed1e..d63408c4 100644
--- a/experiment/config/model/linear.yaml
+++ b/experiment/config/model/linear.yaml
@@ -1,12 +1,18 @@
name: linear
-encoder_cls:
- _target_: disent.model.ae.EncoderLinear
- x_shape: ${dataset.meta.x_shape}
- z_size: ${settings.model.z_size}
- z_multiplier: ${framework.meta.model_z_multiplier}
-
-decoder_cls:
- _target_: disent.model.ae.DecoderLinear
- x_shape: ${dataset.meta.x_shape}
- z_size: ${settings.model.z_size}
+model_cls:
+ # weight initialisation
+ _target_: disent.nn.weights.init_model_weights
+ mode: ${settings.model.weight_init}
+ model:
+ # auto-encoder
+ _target_: disent.model.AutoEncoder
+ encoder:
+ _target_: disent.model.ae.EncoderLinear
+ x_shape: ${dataset.meta.x_shape}
+ z_size: ${settings.model.z_size}
+ z_multiplier: ${framework.meta.model_z_multiplier}
+ decoder:
+ _target_: disent.model.ae.DecoderLinear
+ x_shape: ${dataset.meta.x_shape}
+ z_size: ${settings.model.z_size}
diff --git a/experiment/config/model/norm_conv64.yaml b/experiment/config/model/norm_conv64.yaml
index 3068055b..29db4b6d 100644
--- a/experiment/config/model/norm_conv64.yaml
+++ b/experiment/config/model/norm_conv64.yaml
@@ -1,23 +1,29 @@
name: norm_conv64
-encoder_cls:
- _target_: disent.model.ae.EncoderConv64Norm
- x_shape: ${data.meta.x_shape}
- z_size: ${settings.model.z_size}
- z_multiplier: ${framework.meta.model_z_multiplier}
- activation: ${model.activation}
- norm: ${model.norm}
- norm_pre_act: ${model.norm_pre_act}
+model_cls:
+ # weight initialisation
+ _target_: disent.nn.weights.init_model_weights
+ mode: ${settings.model.weight_init}
+ model:
+ # auto-encoder
+ _target_: disent.model.AutoEncoder
+ encoder:
+ _target_: disent.model.ae.EncoderConv64Norm
+ x_shape: ${data.meta.x_shape}
+ z_size: ${settings.model.z_size}
+ z_multiplier: ${framework.meta.model_z_multiplier}
+ activation: ${model.meta.activation}
+ norm: ${model.meta.norm}
+ norm_pre_act: ${model.meta.norm_pre_act}
+ decoder:
+ _target_: disent.model.ae.DecoderConv64Norm
+ x_shape: ${data.meta.x_shape}
+ z_size: ${settings.model.z_size}
+ activation: ${model.meta.activation}
+ norm: ${model.meta.norm}
+ norm_pre_act: ${model.meta.norm_pre_act}
-decoder_cls:
- _target_: disent.model.ae.DecoderConv64Norm
- x_shape: ${data.meta.x_shape}
- z_size: ${settings.model.z_size}
- activation: ${model.activation}
- norm: ${model.norm}
- norm_pre_act: ${model.norm_pre_act}
-
-# vars
-activation: swish # leaky_relu, relu
-norm: layer # batch, instance, layer, layer_chn, none
-norm_pre_act: TRUE
+meta:
+ activation: swish # leaky_relu, relu
+ norm: layer # batch, instance, layer, layer_chn, none
+ norm_pre_act: TRUE
diff --git a/experiment/config/model/vae_conv64.yaml b/experiment/config/model/vae_conv64.yaml
index a05d00c0..b5e674e9 100644
--- a/experiment/config/model/vae_conv64.yaml
+++ b/experiment/config/model/vae_conv64.yaml
@@ -1,12 +1,18 @@
name: vae_conv64
-encoder_cls:
- _target_: disent.model.ae.EncoderConv64
- x_shape: ${dataset.meta.x_shape}
- z_size: ${settings.model.z_size}
- z_multiplier: ${framework.meta.model_z_multiplier}
-
-decoder_cls:
- _target_: disent.model.ae.DecoderConv64
- x_shape: ${dataset.meta.x_shape}
- z_size: ${settings.model.z_size}
+model_cls:
+ # weight initialisation
+ _target_: disent.nn.weights.init_model_weights
+ mode: ${settings.model.weight_init}
+ model:
+ # auto-encoder
+ _target_: disent.model.AutoEncoder
+ encoder:
+ _target_: disent.model.ae.EncoderConv64
+ x_shape: ${dataset.meta.x_shape}
+ z_size: ${settings.model.z_size}
+ z_multiplier: ${framework.meta.model_z_multiplier}
+ decoder:
+ _target_: disent.model.ae.DecoderConv64
+ x_shape: ${dataset.meta.x_shape}
+ z_size: ${settings.model.z_size}
diff --git a/experiment/config/model/vae_fc.yaml b/experiment/config/model/vae_fc.yaml
index 6a68f40d..69c23a63 100644
--- a/experiment/config/model/vae_fc.yaml
+++ b/experiment/config/model/vae_fc.yaml
@@ -1,12 +1,18 @@
name: vae_fc
-encoder_cls:
- _target_: disent.model.ae.EncoderFC
- x_shape: ${dataset.meta.x_shape}
- z_size: ${settings.model.z_size}
- z_multiplier: ${framework.meta.model_z_multiplier}
-
-decoder_cls:
- _target_: disent.model.ae.DecoderFC
- x_shape: ${dataset.meta.x_shape}
- z_size: ${settings.model.z_size}
+model_cls:
+ # weight initialisation
+ _target_: disent.nn.weights.init_model_weights
+ mode: ${settings.model.weight_init}
+ model:
+ # auto-encoder
+ _target_: disent.model.AutoEncoder
+ encoder:
+ _target_: disent.model.ae.EncoderFC
+ x_shape: ${dataset.meta.x_shape}
+ z_size: ${settings.model.z_size}
+ z_multiplier: ${framework.meta.model_z_multiplier}
+ decoder:
+ _target_: disent.model.ae.DecoderFC
+ x_shape: ${dataset.meta.x_shape}
+ z_size: ${settings.model.z_size}
diff --git a/experiment/config/run_callbacks/all.yaml b/experiment/config/run_callbacks/all.yaml
index fd4c88a6..a8f24320 100644
--- a/experiment/config/run_callbacks/all.yaml
+++ b/experiment/config/run_callbacks/all.yaml
@@ -1,42 +1,22 @@
# @package _global_
callbacks:
+
latent_cycle:
+ _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
seed: 7777
every_n_steps: 3600
- mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
begin_first_step: TRUE
+ mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
+ log_wandb: ${logging.wandb.enabled}
+ recon_mean: ${dataset.meta.vis_mean}
+ recon_std: ${dataset.meta.vis_std}
gt_dists:
+ _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
seed: 7777
every_n_steps: 3600
traversal_repeats: 100
begin_first_step: TRUE
-
-# correlation:
-# repeats_per_factor: 10
-# every_n_steps: 7200
-
-# latent_cycle:
-# _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
-# seed: 7777
-# every_n_steps: 7200
-# begin_first_step: FALSE
-# mode: 'minmax_interval_cycle' # minmax_interval_cycle, fitted_gaussian_cycle
-# recon_min: ${dataset.vis_min}
-# recon_max: ${dataset.vis_max}
-# recon_mean: ${dataset.vis_mean}
-# recon_std: ${dataset.vis_std}
-#
-# gt_dists:
-# _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
-# seed: 7777
-# every_n_steps: 7200
-# begin_first_step: TRUE
-# traversal_repeats: 100
-# plt_block_size: 1.25
-# plt_show: False
-# plt_transpose: False
-# log_wandb: TRUE
-# batch_size: ${cfg.dataset.batch_size}
-# include_factor_dists: TRUE
+ log_wandb: ${logging.wandb.enabled}
+ batch_size: ${settings.dataset.batch_size}
diff --git a/experiment/config/run_callbacks/test.yaml b/experiment/config/run_callbacks/test.yaml
index 255cda8a..3e680425 100644
--- a/experiment/config/run_callbacks/test.yaml
+++ b/experiment/config/run_callbacks/test.yaml
@@ -1,18 +1,22 @@
# @package _global_
callbacks:
+
latent_cycle:
+ _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
seed: 7777
every_n_steps: 3
- mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
begin_first_step: FALSE
+ mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
+ log_wandb: ${logging.wandb.enabled}
+ recon_mean: ${dataset.meta.vis_mean}
+ recon_std: ${dataset.meta.vis_std}
gt_dists:
+ _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
seed: 7777
every_n_steps: 4
traversal_repeats: 3
begin_first_step: FALSE
-
-# correlation:
-# repeats_per_factor: 3
-# every_n_steps: 5
+ log_wandb: ${logging.wandb.enabled}
+ batch_size: ${settings.dataset.batch_size}
diff --git a/experiment/config/run_callbacks/vis.yaml b/experiment/config/run_callbacks/vis.yaml
index 2087779c..a8f24320 100644
--- a/experiment/config/run_callbacks/vis.yaml
+++ b/experiment/config/run_callbacks/vis.yaml
@@ -1,14 +1,22 @@
# @package _global_
callbacks:
+
latent_cycle:
+ _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
seed: 7777
every_n_steps: 3600
- mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
begin_first_step: TRUE
+ mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
+ log_wandb: ${logging.wandb.enabled}
+ recon_mean: ${dataset.meta.vis_mean}
+ recon_std: ${dataset.meta.vis_std}
gt_dists:
+ _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
seed: 7777
every_n_steps: 3600
traversal_repeats: 100
begin_first_step: TRUE
+ log_wandb: ${logging.wandb.enabled}
+ batch_size: ${settings.dataset.batch_size}
diff --git a/experiment/config/run_callbacks/vis_debug.yaml b/experiment/config/run_callbacks/vis_debug.yaml
new file mode 100644
index 00000000..3791ae53
--- /dev/null
+++ b/experiment/config/run_callbacks/vis_debug.yaml
@@ -0,0 +1,22 @@
+# @package _global_
+
+callbacks:
+
+ latent_cycle:
+ _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
+ seed: 7777
+ every_n_steps: 600
+ begin_first_step: TRUE
+ mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
+ log_wandb: ${logging.wandb.enabled}
+ recon_mean: ${dataset.meta.vis_mean}
+ recon_std: ${dataset.meta.vis_std}
+
+ gt_dists:
+ _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
+ seed: 7777
+ every_n_steps: 600
+ traversal_repeats: 50
+ begin_first_step: FALSE
+ log_wandb: ${logging.wandb.enabled}
+ batch_size: ${settings.dataset.batch_size}
diff --git a/experiment/config/run_callbacks/vis_fast.yaml b/experiment/config/run_callbacks/vis_fast.yaml
index 3df24a15..4c1b3802 100644
--- a/experiment/config/run_callbacks/vis_fast.yaml
+++ b/experiment/config/run_callbacks/vis_fast.yaml
@@ -1,14 +1,22 @@
# @package _global_
callbacks:
+
latent_cycle:
+ _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
seed: 7777
every_n_steps: 1800
- mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
begin_first_step: TRUE
+ mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
+ log_wandb: ${logging.wandb.enabled}
+ recon_mean: ${dataset.meta.vis_mean}
+ recon_std: ${dataset.meta.vis_std}
gt_dists:
+ _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
seed: 7777
every_n_steps: 1800
traversal_repeats: 100
begin_first_step: TRUE
+ log_wandb: ${logging.wandb.enabled}
+ batch_size: ${settings.dataset.batch_size}
diff --git a/experiment/config/run_callbacks/vis_quick.yaml b/experiment/config/run_callbacks/vis_quick.yaml
new file mode 100644
index 00000000..d37d8805
--- /dev/null
+++ b/experiment/config/run_callbacks/vis_quick.yaml
@@ -0,0 +1,22 @@
+# @package _global_
+
+callbacks:
+
+ latent_cycle:
+ _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
+ seed: 7777
+ every_n_steps: 600
+ begin_first_step: TRUE
+ mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
+ log_wandb: ${logging.wandb.enabled}
+ recon_mean: ${dataset.meta.vis_mean}
+ recon_std: ${dataset.meta.vis_std}
+
+ gt_dists:
+ _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
+ seed: 7777
+ every_n_steps: 1800
+ traversal_repeats: 50
+ begin_first_step: FALSE
+ log_wandb: ${logging.wandb.enabled}
+ batch_size: ${settings.dataset.batch_size}
diff --git a/experiment/config/run_callbacks/vis_skip_first.yaml b/experiment/config/run_callbacks/vis_skip_first.yaml
new file mode 100644
index 00000000..883ffb0d
--- /dev/null
+++ b/experiment/config/run_callbacks/vis_skip_first.yaml
@@ -0,0 +1,22 @@
+# @package _global_
+
+callbacks:
+
+ latent_cycle:
+ _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
+ seed: 7777
+ every_n_steps: 3600
+ begin_first_step: FALSE
+ mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
+ log_wandb: ${logging.wandb.enabled}
+ recon_mean: ${dataset.meta.vis_mean}
+ recon_std: ${dataset.meta.vis_std}
+
+ gt_dists:
+ _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
+ seed: 7777
+ every_n_steps: 3600
+ traversal_repeats: 100
+ begin_first_step: FALSE
+ log_wandb: ${logging.wandb.enabled}
+ batch_size: ${settings.dataset.batch_size}
diff --git a/experiment/config/run_callbacks/vis_slow.yaml b/experiment/config/run_callbacks/vis_slow.yaml
index 83f0516b..6dca0c0f 100644
--- a/experiment/config/run_callbacks/vis_slow.yaml
+++ b/experiment/config/run_callbacks/vis_slow.yaml
@@ -1,14 +1,22 @@
# @package _global_
callbacks:
+
latent_cycle:
+ _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback
seed: 7777
every_n_steps: 7200
- mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
begin_first_step: TRUE
+ mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle'
+ log_wandb: ${logging.wandb.enabled}
+ recon_mean: ${dataset.meta.vis_mean}
+ recon_std: ${dataset.meta.vis_std}
gt_dists:
+ _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback
seed: 7777
every_n_steps: 7200
traversal_repeats: 100
begin_first_step: TRUE
+ log_wandb: ${logging.wandb.enabled}
+ batch_size: ${settings.dataset.batch_size}
diff --git a/experiment/config/run_length/longmed.yaml b/experiment/config/run_length/longmed.yaml
new file mode 100644
index 00000000..7ffd07a5
--- /dev/null
+++ b/experiment/config/run_length/longmed.yaml
@@ -0,0 +1,5 @@
+# @package _global_
+
+trainer:
+ max_epochs: 86400
+ max_steps: 86400
diff --git a/experiment/config/run_location/local.yaml b/experiment/config/run_location/local.yaml
index cc788508..8c2f249e 100644
--- a/experiment/config/run_location/local.yaml
+++ b/experiment/config/run_location/local.yaml
@@ -2,22 +2,21 @@
dsettings:
trainer:
- cuda: NULL # `NULL` tries to use CUDA if it is available. `TRUE` forces cuda to be used!
+ cuda: NULL # `NULL` tries to use CUDA if it is available, otherwise defaulting to the CPU
storage:
logs_dir: 'logs'
data_root: '/tmp/${oc.env:USER}/datasets'
dataset:
- gpu_augment: FALSE
prepare: TRUE
try_in_memory: TRUE
-trainer:
+datamodule:
+ gpu_augment: FALSE
prepare_data_per_node: TRUE
-
-dataloader:
- num_workers: 8
- pin_memory: ${dsettings.trainer.cuda}
- batch_size: ${settings.dataset.batch_size}
+ dataloader:
+ num_workers: 8
+ pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster!
+ batch_size: ${settings.dataset.batch_size}
hydra:
job:
diff --git a/experiment/config/run_location/local_cpu.yaml b/experiment/config/run_location/local_cpu.yaml
index d1bcd61a..07ccfb84 100644
--- a/experiment/config/run_location/local_cpu.yaml
+++ b/experiment/config/run_location/local_cpu.yaml
@@ -2,22 +2,21 @@
dsettings:
trainer:
- cuda: FALSE
+ cuda: FALSE # The job will only use the CPU
storage:
logs_dir: 'logs'
data_root: '/tmp/${oc.env:USER}/datasets'
dataset:
- gpu_augment: FALSE
prepare: TRUE
try_in_memory: TRUE
-trainer:
+datamodule:
+ gpu_augment: FALSE
prepare_data_per_node: TRUE
-
-dataloader:
- num_workers: 8
- pin_memory: ${dsettings.trainer.cuda}
- batch_size: ${settings.dataset.batch_size}
+ dataloader:
+ num_workers: 8
+ pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster!
+ batch_size: ${settings.dataset.batch_size}
hydra:
job:
diff --git a/experiment/config/run_location/cluster.yaml b/experiment/config/run_location/local_gpu.yaml
similarity index 59%
rename from experiment/config/run_location/cluster.yaml
rename to experiment/config/run_location/local_gpu.yaml
index 9194493f..c166c948 100644
--- a/experiment/config/run_location/cluster.yaml
+++ b/experiment/config/run_location/local_gpu.yaml
@@ -2,26 +2,21 @@
dsettings:
trainer:
- cuda: NULL # auto-detect cuda, some nodes may be configured incorrectly
+ cuda: TRUE # `TRUE` forces cuda to be used. The job fails if cuda is not available!
storage:
logs_dir: 'logs'
data_root: '/tmp/${oc.env:USER}/datasets'
dataset:
- gpu_augment: FALSE
prepare: TRUE
try_in_memory: TRUE
- launcher:
- partition: batch
- array_parallelism: 16
- exclude: "mscluster93,mscluster94,mscluster97,mscluster99"
-trainer:
+datamodule:
+ gpu_augment: FALSE
prepare_data_per_node: TRUE
-
-dataloader:
- num_workers: 8
- pin_memory: ${dsettings.trainer.cuda}
- batch_size: ${settings.dataset.batch_size}
+ dataloader:
+ num_workers: 8
+ pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster!
+ batch_size: ${settings.dataset.batch_size}
hydra:
job:
diff --git a/experiment/config/run_location/stampede_shr.yaml b/experiment/config/run_location/stampede_shr.yaml
deleted file mode 100644
index 05ba8626..00000000
--- a/experiment/config/run_location/stampede_shr.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-# @package _global_
-
-dsettings:
- trainer:
- cuda: NULL # auto-detect cuda, some nodes are not configured correctly
- storage:
- logs_dir: 'logs'
- data_root: '${oc.env:HOME}/downloads/datasets' # WE NEED TO BE VERY CAREFUL ABOUT USING A SHARED DRIVE
- dataset:
- gpu_augment: FALSE
- prepare: FALSE # WE MUST PREPARE DATA MANUALLY BEFOREHAND
- try_in_memory: TRUE
- launcher:
- partition: stampede
- array_parallelism: 16
- exclude: "mscluster93,mscluster94,mscluster97,mscluster99"
-
-trainer:
- prepare_data_per_node: TRUE
-
-dataloader:
- num_workers: 16
- pin_memory: ${dsettings.trainer.cuda}
- batch_size: ${settings.dataset.batch_size}
-
-hydra:
- job:
- name: 'disent'
- run:
- dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}'
- sweep:
- dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}'
- subdir: '${hydra.job.id}' # hydra.job.id is not available for dir
diff --git a/experiment/config/run_location/stampede_tmp.yaml b/experiment/config/run_location/stampede_tmp.yaml
deleted file mode 100644
index d596b335..00000000
--- a/experiment/config/run_location/stampede_tmp.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-# @package _global_
-
-dsettings:
- trainer:
- cuda: NULL # auto-detect cuda, some nodes are not configured correctly
- storage:
- logs_dir: 'logs'
- data_root: '/tmp/${oc.env:USER}/datasets'
- dataset:
- gpu_augment: FALSE
- prepare: TRUE
- try_in_memory: TRUE
- launcher:
- partition: stampede
- array_parallelism: 16
- exclude: "mscluster93,mscluster94,mscluster97,mscluster99"
-
-trainer:
- prepare_data_per_node: TRUE
-
-dataloader:
- num_workers: 16
- pin_memory: ${dsettings.trainer.cuda}
- batch_size: ${settings.dataset.batch_size}
-
-hydra:
- job:
- name: 'disent'
- run:
- dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}'
- sweep:
- dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}'
- subdir: '${hydra.job.id}' # hydra.job.id is not available for dir
diff --git a/experiment/config/run_logging/none.yaml b/experiment/config/run_logging/none.yaml
index 2a99d3c4..a9c656ff 100644
--- a/experiment/config/run_logging/none.yaml
+++ b/experiment/config/run_logging/none.yaml
@@ -6,13 +6,14 @@ defaults:
trainer:
log_every_n_steps: 50
- flush_logs_every_n_steps: 100
- progress_bar_refresh_rate: 0 # disable the builtin progress bar
+ enable_progress_bar: FALSE # disable the builtin progress bar
callbacks:
progress:
+ _target_: disent.util.lightning.callbacks.LoggerProgressCallback
interval: 5
logging:
wandb:
enabled: FALSE
+ logger: NULL
diff --git a/experiment/config/run_logging/wandb.yaml b/experiment/config/run_logging/wandb.yaml
index a72cb887..53c7f12c 100644
--- a/experiment/config/run_logging/wandb.yaml
+++ b/experiment/config/run_logging/wandb.yaml
@@ -6,19 +6,26 @@ defaults:
trainer:
log_every_n_steps: 100
- flush_logs_every_n_steps: 200
- progress_bar_refresh_rate: 0 # disable the builtin progress bar
+ enable_progress_bar: FALSE # disable the builtin progress bar
callbacks:
progress:
+ _target_: disent.util.lightning.callbacks.LoggerProgressCallback
interval: 15
logging:
wandb:
enabled: TRUE
+ logger:
+ _target_: pytorch_lightning.loggers.WandbLogger
offline: FALSE
- entity: '${settings.job.user}'
- project: '${settings.job.project}'
- name: '${settings.job.name}'
- group: null
+ entity: ${settings.job.user}
+ project: ${settings.job.project}
+ name: ${settings.job.name}
+ group: NULL
tags: []
+ save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd
+ # https://docs.wandb.ai/guides/track/launch#init-start-error
+ settings:
+ _target_: wandb.Settings
+ start_method: "fork" # fork: linux/macos, thread: google colab
diff --git a/experiment/config/run_logging/wandb_fast.yaml b/experiment/config/run_logging/wandb_fast.yaml
index 75c305ee..f0f978b9 100644
--- a/experiment/config/run_logging/wandb_fast.yaml
+++ b/experiment/config/run_logging/wandb_fast.yaml
@@ -6,19 +6,26 @@ defaults:
trainer:
log_every_n_steps: 50
- flush_logs_every_n_steps: 100
- progress_bar_refresh_rate: 0 # disable the builtin progress bar
+ enable_progress_bar: FALSE # disable the builtin progress bar
callbacks:
progress:
+ _target_: disent.util.lightning.callbacks.LoggerProgressCallback
interval: 5
logging:
wandb:
enabled: TRUE
+ logger:
+ _target_: pytorch_lightning.loggers.WandbLogger
offline: FALSE
- entity: '${settings.job.user}'
- project: '${settings.job.project}'
- name: '${settings.job.name}'
- group: null
+ entity: ${settings.job.user}
+ project: ${settings.job.project}
+ name: ${settings.job.name}
+ group: NULL
tags: []
+ save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd
+ # https://docs.wandb.ai/guides/track/launch#init-start-error
+ settings:
+ _target_: wandb.Settings
+ start_method: "fork" # fork: linux/macos, thread: google colab
diff --git a/experiment/config/run_logging/wandb_fast_offline.yaml b/experiment/config/run_logging/wandb_fast_offline.yaml
index 372e480a..0095a2c7 100644
--- a/experiment/config/run_logging/wandb_fast_offline.yaml
+++ b/experiment/config/run_logging/wandb_fast_offline.yaml
@@ -6,19 +6,26 @@ defaults:
trainer:
log_every_n_steps: 50
- flush_logs_every_n_steps: 100
- progress_bar_refresh_rate: 0 # disable the builtin progress bar
+ enable_progress_bar: FALSE # disable the builtin progress bar
callbacks:
progress:
+ _target_: disent.util.lightning.callbacks.LoggerProgressCallback
interval: 5
logging:
wandb:
enabled: TRUE
+ logger:
+ _target_: pytorch_lightning.loggers.WandbLogger
offline: TRUE
- entity: '${settings.job.user}'
- project: '${settings.job.project}'
- name: '${settings.job.name}'
- group: null
+ entity: ${settings.job.user}
+ project: ${settings.job.project}
+ name: ${settings.job.name}
+ group: NULL
tags: []
+ save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd
+ # https://docs.wandb.ai/guides/track/launch#init-start-error
+ settings:
+ _target_: wandb.Settings
+ start_method: "fork" # fork: linux/macos, thread: google colab
diff --git a/experiment/config/run_logging/wandb_slow.yaml b/experiment/config/run_logging/wandb_slow.yaml
index f7f4a49c..5718572f 100644
--- a/experiment/config/run_logging/wandb_slow.yaml
+++ b/experiment/config/run_logging/wandb_slow.yaml
@@ -6,19 +6,26 @@ defaults:
trainer:
log_every_n_steps: 200
- flush_logs_every_n_steps: 400
- progress_bar_refresh_rate: 0 # disable the builtin progress bar
+ enable_progress_bar: FALSE # disable the builtin progress bar
callbacks:
progress:
+ _target_: disent.util.lightning.callbacks.LoggerProgressCallback
interval: 30
logging:
wandb:
enabled: TRUE
+ logger:
+ _target_: pytorch_lightning.loggers.WandbLogger
offline: FALSE
- entity: '${settings.job.user}'
- project: '${settings.job.project}'
- name: '${settings.job.name}'
- group: null
+ entity: ${settings.job.user}
+ project: ${settings.job.project}
+ name: ${settings.job.name}
+ group: NULL
tags: []
+ save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd
+ # https://docs.wandb.ai/guides/track/launch#init-start-error
+ settings:
+ _target_: wandb.Settings
+ start_method: "fork" # fork: linux/macos, thread: google colab
diff --git a/experiment/config/run_plugins/default.yaml b/experiment/config/run_plugins/default.yaml
new file mode 100644
index 00000000..c608f5f3
--- /dev/null
+++ b/experiment/config/run_plugins/default.yaml
@@ -0,0 +1,6 @@
+# @package _global_
+
+# call the listed functions here before the experiment is started
+# - this can be used to register functions or metrics to the disent registry for example!
+experiment:
+ plugins: []
diff --git a/experiment/config/sampling/gt_dist__manhat_scaled.yaml b/experiment/config/sampling/gt_dist__manhat_scaled.yaml
index 5fb96993..5de529ab 100644
--- a/experiment/config/sampling/gt_dist__manhat_scaled.yaml
+++ b/experiment/config/sampling/gt_dist__manhat_scaled.yaml
@@ -4,5 +4,5 @@ defaults:
name: gt_dist__manhat_scaled
-triplet_sample_mode: "manhattan_scaled" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled
+triplet_sample_mode: 'manhattan_scaled' # random, factors, manhattan, manhattan_scaled, combined, combined_scaled
triplet_swap_chance: 0
diff --git a/experiment/config/schedule/adavae_down_all.yaml b/experiment/config/schedule/adavae_down_all.yaml
deleted file mode 100644
index 8d4ac7f5..00000000
--- a/experiment/config/schedule/adavae_down_all.yaml
+++ /dev/null
@@ -1,27 +0,0 @@
-name: averaging_decrease__all
-
-schedule_items:
- adat_triplet_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # ada triplet
- r_end: 0.0 # triplet
- adat_triplet_soft_scale:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # loss active
- r_end: 0.0 # loss inactive
- adat_triplet_share_scale: # reversed compared to others
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.5 # ada weighted triplet
- r_end: 1.0 # normal triplet
- ada_thresh_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # all averaged, should this not be 0.5 the recommended value
- r_end: 0.0 # none averaged
diff --git a/experiment/config/schedule/adavae_down_ratio.yaml b/experiment/config/schedule/adavae_down_ratio.yaml
deleted file mode 100644
index e3816fe1..00000000
--- a/experiment/config/schedule/adavae_down_ratio.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-name: averaging_decrease__ratio
-
-schedule_items:
- adat_triplet_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # ada triplet
- r_end: 0.0 # triplet
- adat_triplet_soft_scale:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # loss active
- r_end: 0.0 # loss inactive
- adat_triplet_share_scale: # reversed compared to others
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.5 # ada weighted triplet
- r_end: 1.0 # normal triplet
diff --git a/experiment/config/schedule/adavae_down_thresh.yaml b/experiment/config/schedule/adavae_down_thresh.yaml
deleted file mode 100644
index 0ca660ac..00000000
--- a/experiment/config/schedule/adavae_down_thresh.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-name: averaging_decrease__thresh
-
-schedule_items:
- ada_thresh_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # all averaged, should this not be 0.5 the recommended value
- r_end: 0.0 # none averaged
diff --git a/experiment/config/schedule/adavae_up_all.yaml b/experiment/config/schedule/adavae_up_all.yaml
deleted file mode 100644
index 58cdd98d..00000000
--- a/experiment/config/schedule/adavae_up_all.yaml
+++ /dev/null
@@ -1,27 +0,0 @@
-name: averaging_increase__all
-
-schedule_items:
- adat_triplet_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # triplet
- r_end: 1.0 # ada triplet
- adat_triplet_soft_scale:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # loss inactive
- r_end: 1.0 # loss active
- adat_triplet_share_scale: # reversed compared to others
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # normal triplet
- r_end: 0.5 # ada weighted triplet
- ada_thresh_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # none averaged
- r_end: 1.0 # all averaged, should this not be 0.5 the recommended value
diff --git a/experiment/config/schedule/adavae_up_all_full.yaml b/experiment/config/schedule/adavae_up_all_full.yaml
deleted file mode 100644
index 471a8da4..00000000
--- a/experiment/config/schedule/adavae_up_all_full.yaml
+++ /dev/null
@@ -1,27 +0,0 @@
-name: averaging_increase__all_full
-
-schedule_items:
- adat_triplet_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # triplet
- r_end: 1.0 # ada triplet
- adat_triplet_soft_scale:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # loss inactive
- r_end: 1.0 # loss active
- adat_triplet_share_scale: # reversed compared to others
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # normal triplet
- r_end: 0.0 # ada weighted triplet
- ada_thresh_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # none averaged
- r_end: 1.0 # all averaged, should this not be 0.5 the recommended value
diff --git a/experiment/config/schedule/adavae_up_ratio.yaml b/experiment/config/schedule/adavae_up_ratio.yaml
deleted file mode 100644
index 79e3e6c3..00000000
--- a/experiment/config/schedule/adavae_up_ratio.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-name: averaging_increase__ratio
-
-schedule_items:
- adat_triplet_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # triplet
- r_end: 1.0 # ada triplet
- adat_triplet_soft_scale:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # loss inactive
- r_end: 1.0 # loss active
- adat_triplet_share_scale: # reversed compared to others
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # normal triplet
- r_end: 0.5 # ada weighted triplet
diff --git a/experiment/config/schedule/adavae_up_ratio_full.yaml b/experiment/config/schedule/adavae_up_ratio_full.yaml
deleted file mode 100644
index 5b7fabe3..00000000
--- a/experiment/config/schedule/adavae_up_ratio_full.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-name: averaging_increase__ratio_full
-
-schedule_items:
- adat_triplet_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # triplet
- r_end: 1.0 # ada triplet
- adat_triplet_soft_scale:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # loss inactive
- r_end: 1.0 # loss active
- adat_triplet_share_scale: # reversed compared to others
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 1.0 # normal triplet
- r_end: 0.0 # ada weighted triplet
diff --git a/experiment/config/schedule/adavae_up_thresh.yaml b/experiment/config/schedule/adavae_up_thresh.yaml
deleted file mode 100644
index 7c012a32..00000000
--- a/experiment/config/schedule/adavae_up_thresh.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-name: averaging_increase__thresh
-
-schedule_items:
- ada_thresh_ratio:
- _target_: disent.schedule.LinearSchedule
- start_step: 0
- end_step: ${trainer.max_steps}
- r_start: 0.0 # none averaged
- r_end: 1.0 # all averaged, should this not be 0.5 the recommended value
diff --git a/experiment/run.py b/experiment/run.py
index bfe57545..c5827093 100644
--- a/experiment/run.py
+++ b/experiment/run.py
@@ -25,6 +25,8 @@
import logging
import os
from datetime import datetime
+from typing import Callable
+from typing import Optional
import hydra
import pytorch_lightning as pl
@@ -34,22 +36,19 @@
from omegaconf import DictConfig
from omegaconf import ListConfig
from omegaconf import OmegaConf
-from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning import Callback
+from pytorch_lightning.callbacks import ModelSummary
+from pytorch_lightning.loggers import LightningLoggerBase
-from disent import metrics
-from disent.frameworks import DisentConfigurable
+import disent.registry as R
from disent.frameworks import DisentFramework
-from disent.model import AutoEncoder
-from disent.nn.weights import init_model_weights
+from disent.util.lightning.callbacks import VaeMetricLoggingCallback
from disent.util.seeds import seed
-from disent.util.strings.fmt import make_box_str
from disent.util.strings import colors as c
-from disent.util.lightning.callbacks import LoggerProgressCallback
-from disent.util.lightning.callbacks import VaeMetricLoggingCallback
-from disent.util.lightning.callbacks import VaeLatentCycleLoggingCallback
-from disent.util.lightning.callbacks import VaeGtDistsLoggingCallback
+from disent.util.strings.fmt import make_box_str
+
from experiment.util.hydra_data import HydraDataModule
-from experiment.util.run_utils import log_error_and_exit
+from experiment.util.hydra_main import hydra_main
from experiment.util.run_utils import safe_unset_debug_logger
from experiment.util.run_utils import safe_unset_debug_trainer
from experiment.util.run_utils import set_debug_logger
@@ -64,6 +63,17 @@
# ========================================================================= #
+def hydra_register_disent_plugins(cfg):
+ # TODO: there should be a plugin mechanism for disent?
+ if cfg.experiment.plugins:
+ log.info('Running experiment plugins:')
+ for plugin in cfg.experiment.plugins:
+ log.info(f'* registering: {plugin}')
+ hydra.utils.instantiate(dict(_target_=plugin))
+ else:
+ log.info('No experiment plugins were listed. Register these under the `experiment.plugins` in the config, which lists targets of functions.')
+
+
def hydra_get_gpus(cfg) -> int:
use_cuda = cfg.dsettings.trainer.cuda
# check cuda values
@@ -85,7 +95,7 @@ def hydra_get_gpus(cfg) -> int:
def hydra_check_data_paths(cfg):
- prepare_data_per_node = cfg.trainer.prepare_data_per_node
+ prepare_data_per_node = cfg.datamodule.prepare_data_per_node
data_root = cfg.dsettings.storage.data_root
# check relative paths
if not os.path.isabs(data_root):
@@ -98,78 +108,32 @@ def hydra_check_data_paths(cfg):
)
if prepare_data_per_node:
log.error(
- f'trainer.prepare_data_per_node={repr(prepare_data_per_node)} but dsettings.storage.data_root='
+ f'datamodule.prepare_data_per_node={repr(prepare_data_per_node)} but dsettings.storage.data_root='
f'{repr(data_root)} is a relative path which may be an error! Try specifying an'
f' absolute path that is guaranteed to be unique from each node, eg. default_settings.storage.data_root=/tmp/dataset'
)
raise RuntimeError(f'default_settings.storage.data_root={repr(data_root)} is a relative path!')
-def hydra_make_logger(cfg):
- # make wandb logger
- backend = cfg.logging.wandb
- if backend.enabled:
- log.info('Initialising Weights & Biases Logger')
- return WandbLogger(
- offline=backend.offline,
- entity=backend.entity, # cometml: workspace
- project=backend.project, # cometml: project_name
- name=backend.name, # cometml: experiment_name
- group=backend.group, # experiment group
- tags=backend.tags, # experiment tags
- save_dir=hydra.utils.to_absolute_path(cfg.dsettings.storage.logs_dir), # relative to hydra's original cwd
- )
- # don't return a logger
- return None # LoggerCollection([...]) OR DummyLogger(...)
-
-
-def _callback_make_progress(cfg, callback_cfg):
- return LoggerProgressCallback(
- interval=callback_cfg.interval
- )
-
-
-def _callback_make_latent_cycle(cfg, callback_cfg):
- if cfg.logging.wandb.enabled:
- # checks
- if not (('vis_min' in cfg.dataset and 'vis_max' in cfg.dataset) or ('vis_mean' in cfg.dataset and 'vis_std' in cfg.dataset)):
- log.warning('dataset does not have visualisation ranges specified, set `vis_min` & `vis_max` OR `vis_mean` & `vis_std`')
- # this currently only supports WANDB logger
- return VaeLatentCycleLoggingCallback(
- seed = callback_cfg.seed,
- every_n_steps = callback_cfg.every_n_steps,
- begin_first_step = callback_cfg.begin_first_step,
- mode = callback_cfg.mode,
- # recon_min = cfg.data.meta.vis_min,
- # recon_max = cfg.data.meta.vis_max,
- recon_mean = cfg.dataset.meta.vis_mean,
- recon_std = cfg.dataset.meta.vis_std,
- )
+def hydra_check_data_meta(cfg):
+ # checks
+ if (cfg.dataset.meta.vis_mean is None) or (cfg.dataset.meta.vis_std is None):
+ log.warning(f'Dataset has no normalisation values... Are you sure this is correct?')
+ log.warning(f'* dataset.meta.vis_mean: {cfg.dataset.meta.vis_mean}')
+ log.warning(f'* dataset.meta.vis_std: {cfg.dataset.meta.vis_std}')
else:
- log.warning('latent_cycle callback is not being used because wandb is not enabled!')
- return None
-
-
-def _callback_make_gt_dists(cfg, callback_cfg):
- return VaeGtDistsLoggingCallback(
- seed = callback_cfg.seed,
- every_n_steps = callback_cfg.every_n_steps,
- traversal_repeats = callback_cfg.traversal_repeats,
- begin_first_step = callback_cfg.begin_first_step,
- plt_block_size = 1.25,
- plt_show = False,
- plt_transpose = False,
- log_wandb = True,
- batch_size = cfg.settings.dataset.batch_size,
- include_factor_dists = True,
- )
+ log.info(f'Dataset has normalisation values!')
+ log.info(f'* dataset.meta.vis_mean: {cfg.dataset.meta.vis_mean}')
+ log.info(f'* dataset.meta.vis_std: {cfg.dataset.meta.vis_std}')
-_CALLBACK_MAKERS = {
- 'progress': _callback_make_progress,
- 'latent_cycle': _callback_make_latent_cycle,
- 'gt_dists': _callback_make_gt_dists,
-}
+def hydra_make_logger(cfg) -> Optional[LightningLoggerBase]:
+ logger = hydra.utils.instantiate(cfg.logging.logger)
+ if logger:
+ log.info(f'Initialised Logger: {logger}')
+ else:
+ log.warning(f'No Logger Utilised!')
+ return logger
def hydra_get_callbacks(cfg) -> list:
@@ -177,21 +141,16 @@ def hydra_get_callbacks(cfg) -> list:
# add all callbacks
for name, item in cfg.callbacks.items():
# custom callback handling vs instantiation
- if '_target_' in item:
- name = f'{name} ({item._target_})'
- callback = hydra.utils.instantiate(item)
- else:
- callback = _CALLBACK_MAKERS[name](cfg, item)
+ callback = hydra.utils.instantiate(item)
+ assert isinstance(callback, Callback), f'instantiated callback is not an instance of {Callback}, got: {callback}'
# add to callbacks list
- if callback is not None:
- log.info(f'made callback: {name}')
- callbacks.append(callback)
- else:
- log.info(f'skipped callback: {name}')
+ log.info(f'made callback: {name} ({item._target_})')
+ callbacks.append(callback)
return callbacks
def hydra_get_metric_callbacks(cfg) -> list:
+ # TODO: simplify this, make better use of the config!
callbacks = []
# set default values used later
default_every_n_steps = cfg.metrics.default_every_n_steps
@@ -211,8 +170,8 @@ def hydra_get_metric_callbacks(cfg) -> list:
# check values
assert isinstance(metric, (dict, DictConfig)), f'settings for entry in metric list is not a dictionary, got type: {type(settings)} or value: {repr(settings)}'
# make metrics
- train_metric = [metrics.FAST_METRICS[name]] if settings.get('on_train', default_on_train) else None
- final_metric = [metrics.DEFAULT_METRICS[name]] if settings.get('on_final', default_on_final) else None
+ train_metric = [R.METRICS[name].compute_fast] if settings.get('on_train', default_on_train) else None
+ final_metric = [R.METRICS[name].compute] if settings.get('on_final', default_on_final) else None
# add the metric callback
if final_metric or train_metric:
callbacks.append(VaeMetricLoggingCallback(
@@ -224,54 +183,49 @@ def hydra_get_metric_callbacks(cfg) -> list:
return callbacks
-def hydra_register_schedules(module: DisentFramework, cfg):
- # check the type
- schedule_items = cfg.schedule.schedule_items
- assert isinstance(schedule_items, (dict, DictConfig)), f'`schedule.schedule_items` must be a dictionary, got type: {type(schedule_items)} with value: {repr(schedule_items)}'
- # add items
- if schedule_items:
- log.info(f'Registering Schedules:')
- for target, schedule in schedule_items.items():
- module.register_schedule(target, hydra.utils.instantiate(schedule), logging=True)
-
+def hydra_create_framework(cfg, gpu_batch_augment: Optional[Callable[[torch.Tensor], torch.Tensor]] = None) -> DisentFramework:
+ # create framework
+ assert str.endswith(cfg.framework.cfg['_target_'], '.cfg'), f'`cfg.framework.cfg._target_` does not end with ".cfg", got: {repr(cfg.framework.cfg["_target_"])}'
+ framework_cls = hydra.utils.get_class(cfg.framework.cfg['_target_'][:-len(".cfg")])
+ framework: DisentFramework = framework_cls(
+ model=hydra.utils.instantiate(cfg.model.model_cls),
+ cfg=hydra.utils.instantiate(cfg.framework.cfg, _convert_='all'), # DisentConfigurable -- convert all OmegaConf objects to python equivalents, eg. DictConfig -> dict
+ batch_augment=gpu_batch_augment,
+ )
-def hydra_create_and_update_framework_config(cfg) -> DisentConfigurable.cfg:
- # create framework config - this is also kinda hacky
- # - we need instantiate_recursive because of optimizer_kwargs,
- # otherwise the dictionary is left as an OmegaConf dict
- framework_cfg: DisentConfigurable.cfg = hydra.utils.instantiate(cfg.framework.cfg)
- # warn if some of the cfg variables were not overridden
- missing_keys = sorted(set(framework_cfg.get_keys()) - (set(cfg.framework.cfg.keys())))
+ # check if some cfg variables were not overridden
+ missing_keys = sorted(set(framework.cfg.get_keys()) - (set(cfg.framework.cfg.keys())))
if missing_keys:
log.warning(f'{c.RED}Framework {repr(cfg.framework.name)} is missing config keys for:{c.RST}')
for k in missing_keys:
log.warning(f'{c.RED}{repr(k)}{c.RST}')
- # return config
- return framework_cfg
-
-
-def hydra_create_framework(framework_cfg: DisentConfigurable.cfg, datamodule, cfg):
- # specific handling for experiment, this is HACKY!
- # - not supported normally, we need to instantiate to get the class (is there hydra support for this?)
- framework_cfg.optimizer = hydra.utils.get_class(framework_cfg.optimizer)
- framework_cfg.optimizer_kwargs = dict(framework_cfg.optimizer_kwargs)
- # get framework path
- assert str.endswith(cfg.framework.cfg._target_, '.cfg'), f'`cfg.framework.cfg._target_` does not end with ".cfg", got: {repr(cfg.framework.cfg._target_)}'
- framework_cls = hydra.utils.get_class(cfg.framework.cfg._target_[:-len(".cfg")])
- # create model
- model = AutoEncoder(
- encoder=hydra.utils.instantiate(cfg.model.encoder_cls),
- decoder=hydra.utils.instantiate(cfg.model.decoder_cls),
- )
- # initialise the model
- model = init_model_weights(model, mode=cfg.settings.model.weight_init)
- # create framework
- return framework_cls(
- model=model,
- cfg=framework_cfg,
- batch_augment=datamodule.batch_augment, # apply augmentations to batch on GPU which can be faster than via the dataloader
- )
+ # register schedules to the framework
+ schedule_items = cfg.schedule.schedule_items
+ assert isinstance(schedule_items, (dict, DictConfig)), f'`schedule.schedule_items` must be a dictionary, got type: {type(schedule_items)} with value: {repr(schedule_items)}'
+ if schedule_items:
+ log.info(f'Registering Schedules:')
+ for target, schedule in schedule_items.items():
+ framework.register_schedule(target, hydra.utils.instantiate(schedule), logging=True)
+
+ return framework
+
+
+def hydra_make_datamodule(cfg):
+ return HydraDataModule(
+ data = cfg.dataset.data, # from: dataset
+ transform = cfg.dataset.transform, # from: dataset
+ augment = cfg.augment.augment_cls, # from: augment
+ sampler = cfg.sampling._sampler_.sampler_cls, # from: sampling
+ # from: run_location
+ using_cuda = cfg.dsettings.trainer.cuda,
+ dataloader_kwargs = cfg.datamodule.dataloader,
+ augment_on_gpu = cfg.datamodule.gpu_augment,
+ prepare_data_per_node = cfg.datamodule.prepare_data_per_node,
+ # from: framework.meta
+ return_indices = cfg.framework.meta.get('requires_indices', False),
+ return_factors = cfg.framework.meta.get('requires_factors', False),
+ )
# ========================================================================= #
# ACTIONS #
@@ -284,15 +238,18 @@ def action_prepare_data(cfg: DictConfig):
log.info(f'Starting run at time: {time_string}')
# deterministic seed
seed(cfg.settings.job.seed)
+ # register plugins
+ hydra_register_disent_plugins(cfg)
# print useful info
log.info(f"Current working directory : {os.getcwd()}")
log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}")
# check data preparation
hydra_check_data_paths(cfg)
+ hydra_check_data_meta(cfg)
# print the config
log.info(f'Dataset Config Is:\n{make_box_str(OmegaConf.to_yaml({"dataset": cfg.dataset}))}')
# prepare data
- datamodule = HydraDataModule(cfg)
+ datamodule = hydra_make_datamodule(cfg)
datamodule.prepare_data()
@@ -303,8 +260,9 @@ def action_train(cfg: DictConfig):
log.info(f'Starting run at time: {time_string}')
# -~-~-~-~-~-~-~-~-~-~-~-~- #
-
# cleanup from old runs:
+ # -~-~-~-~-~-~-~-~-~-~-~-~- #
+
try:
safe_unset_debug_trainer()
safe_unset_debug_logger()
@@ -313,75 +271,76 @@ def action_train(cfg: DictConfig):
pass
# -~-~-~-~-~-~-~-~-~-~-~-~- #
-
- # deterministic seed
- seed(cfg.settings.job.seed)
-
- # -~-~-~-~-~-~-~-~-~-~-~-~- #
- # INITIALISE & SETDEFAULT IN CONFIG
+ # SETUP
# -~-~-~-~-~-~-~-~-~-~-~-~- #
# create trainer loggers & callbacks & initialise error messages
logger = set_debug_logger(hydra_make_logger(cfg))
+ # deterministic seed
+ seed(cfg.settings.job.seed)
+ # register plugins
+ hydra_register_disent_plugins(cfg)
# print useful info
log.info(f"Current working directory : {os.getcwd()}")
log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}")
-
- # check CUDA setting
+ # checks
gpus = hydra_get_gpus(cfg)
-
- # check data preparation
hydra_check_data_paths(cfg)
+ hydra_check_data_meta(cfg)
- # TRAINER CALLBACKS
- callbacks = [
- *hydra_get_callbacks(cfg),
- *hydra_get_metric_callbacks(cfg),
- ]
+ # -~-~-~-~-~-~-~-~-~-~-~-~- #
+ # INITIALISE
+ # -~-~-~-~-~-~-~-~-~-~-~-~- #
# HYDRA MODULES
- datamodule = HydraDataModule(cfg)
- framework_cfg = hydra_create_and_update_framework_config(cfg)
- framework = hydra_create_framework(framework_cfg, datamodule, cfg)
-
- # register schedules
- hydra_register_schedules(framework, cfg)
+ datamodule = hydra_make_datamodule(cfg)
+ framework = hydra_create_framework(cfg, gpu_batch_augment=datamodule.gpu_batch_augment)
+ # trainer default kwargs
# Setup Trainer
trainer = set_debug_trainer(pl.Trainer(
+ # cannot override these
logger=logger,
- callbacks=callbacks,
gpus=gpus,
- # we do this here too so we don't run the final
- # metrics, even through we check for it manually.
- terminate_on_nan=True,
- # TODO: re-enable this in future... something is not compatible
- # with saving/checkpointing models + allow enabling from the
- # config. Seems like something cannot be pickled?
- checkpoint_callback=False,
- # additional trainer kwargs
- **cfg.trainer,
+ callbacks=[
+ *hydra_get_callbacks(cfg),
+ *hydra_get_metric_callbacks(cfg),
+ ModelSummary(max_depth=2), # override default ModelSummary
+ ],
+ # additional kwargs from the config
+ **{
+ **dict(
+ detect_anomaly=False, # this should only be enabled for debugging torch and finding NaN values, slows down execution, not by much though?
+ enable_checkpointing=False, # TODO: enable this in future
+ ),
+ **cfg.trainer, # overrides
+ }
))
# -~-~-~-~-~-~-~-~-~-~-~-~- #
- # BEGIN TRAINING
+ # DEBUG
# -~-~-~-~-~-~-~-~-~-~-~-~- #
# get config sections
print_cfg, boxed_pop = dict(cfg), lambda *keys: make_box_str(OmegaConf.to_yaml({k: print_cfg.pop(k) for k in keys} if keys else print_cfg))
+ cfg_str_exp = boxed_pop('action', 'experiment')
cfg_str_logging = boxed_pop('logging', 'callbacks', 'metrics')
- cfg_str_dataset = boxed_pop('dataset', 'sampling', 'augment')
+ cfg_str_dataset = boxed_pop('dataset', 'datamodule', 'sampling', 'augment')
cfg_str_system = boxed_pop('framework', 'model', 'schedule')
cfg_str_settings = boxed_pop('dsettings', 'settings')
cfg_str_other = boxed_pop()
# print config sections
- log.info(f'Final Config For Action: {cfg.action}\n\nLOGGING:{cfg_str_logging}\nDATASET:{cfg_str_dataset}\nSYSTEM:{cfg_str_system}\nTRAINER:{cfg_str_other}\nSETTINGS:{cfg_str_settings}')
+ log.info(f'Final Config For Action: {cfg.action}\n\nEXPERIMENT:{cfg_str_exp}\nLOGGING:{cfg_str_logging}\nDATASET:{cfg_str_dataset}\nSYSTEM:{cfg_str_system}\nTRAINER:{cfg_str_other}\nSETTINGS:{cfg_str_settings}')
- # save hparams TODO: is this a pytorch lightning bug? The trainer should automatically save these if hparams is set?
+ # -~-~-~-~-~-~-~-~-~-~-~-~- #
+ # BEGIN TRAINING
+ # -~-~-~-~-~-~-~-~-~-~-~-~- #
+
+ # save hparams
framework.hparams.update(cfg)
if trainer.logger:
- trainer.logger.log_hyperparams(framework.hparams)
+ trainer.logger.log_hyperparams(framework.hparams) # TODO: is this a pytorch lightning bug? The trainer should automatically save these if hparams is set?
# fit the model
# -- if an error/signal occurs while pytorch lightning is
@@ -389,15 +348,14 @@ def action_train(cfg: DictConfig):
trainer.fit(framework, datamodule=datamodule)
# -~-~-~-~-~-~-~-~-~-~-~-~- #
-
# cleanup this run
+ # -~-~-~-~-~-~-~-~-~-~-~-~- #
+
try:
wandb.finish()
except:
pass
- # -~-~-~-~-~-~-~-~-~-~-~-~- #
-
# available actions
ACTIONS = {
@@ -422,42 +380,13 @@ def run_action(cfg: DictConfig):
# ========================================================================= #
-# path to root directory containing configs
-CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'config'))
-# root config existing inside `CONFIG_ROOT`, with '.yaml' appended.
-CONFIG_NAME = 'config'
-
-
if __name__ == '__main__':
-
- # register a custom OmegaConf resolver that allows us to put in a ${exit:msg} that exits the program
- # - if we don't register this, the program will still fail because we have an unknown
- # resolver. This just prettifies the output.
- class ConfigurationError(Exception):
- pass
-
- def _error_resolver(msg: str):
- raise ConfigurationError(msg)
-
- OmegaConf.register_new_resolver('exit', _error_resolver)
-
- @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
- def hydra_main(cfg: DictConfig):
- try:
- run_action(cfg)
- except Exception as e:
- log_error_and_exit(err_type='experiment error', err_msg=str(e), exc_info=True)
- except:
- log_error_and_exit(err_type='experiment error', err_msg='', exc_info=True)
-
- try:
- hydra_main()
- except KeyboardInterrupt as e:
- log_error_and_exit(err_type='interrupted', err_msg=str(e), exc_info=False)
- except Exception as e:
- log_error_and_exit(err_type='hydra error', err_msg=str(e), exc_info=True)
- except:
- log_error_and_exit(err_type='hydra error', err_msg='', exc_info=True)
+ # launch the action
+ hydra_main(
+ callback=run_action,
+ config_name='config',
+ log_level=logging.INFO,
+ )
# ========================================================================= #
diff --git a/experiment/util/_hydra_searchpath_plugin_/hydra_plugins/searchpath_plugin.py b/experiment/util/_hydra_searchpath_plugin_/hydra_plugins/searchpath_plugin.py
new file mode 100644
index 00000000..5a14483e
--- /dev/null
+++ b/experiment/util/_hydra_searchpath_plugin_/hydra_plugins/searchpath_plugin.py
@@ -0,0 +1,33 @@
+"""
+This file is currently hacky, and hopefully temporary! See:
+https://github.com/facebookresearch/hydra/issues/2001
+"""
+
+import logging
+import os
+
+from hydra.core.config_search_path import ConfigSearchPath
+from hydra.plugins.search_path_plugin import SearchPathPlugin
+
+
+log = logging.getLogger(__name__)
+
+
+class DisentExperimentSearchPathPlugin(SearchPathPlugin):
+
+ def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
+ from experiment.util.hydra_main import _DISENT_CONFIG_DIRS
+ # find paths
+ paths = [
+ *os.environ.get('DISENT_CONFIGS_PREPEND', '').split(';'),
+ *_DISENT_CONFIG_DIRS,
+ *os.environ.get('DISENT_CONFIGS_APPEND', '').split(';'),
+ ]
+ # print information
+ log.info(f' [disent-search-path-plugin]: Activated hydra plugin: {self.__class__.__name__}')
+ log.info(f' [disent-search-path-plugin]: To register more search paths, adjust the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables!')
+ # add paths
+ for path in paths:
+ if path:
+ log.info(f' [disent-search-path] - {repr(path)}')
+ search_path.append(provider='disent-searchpath-plugin', path=os.path.abspath(path))
diff --git a/experiment/util/hydra_data.py b/experiment/util/hydra_data.py
index daa4c191..2bf575b6 100644
--- a/experiment/util/hydra_data.py
+++ b/experiment/util/hydra_data.py
@@ -24,6 +24,9 @@
import logging
import warnings
+from typing import Any
+from typing import Dict
+from typing import Optional
import hydra
import torch.utils.data
@@ -81,37 +84,54 @@
class HydraDataModule(pl.LightningDataModule):
- def __init__(self, hparams: DictConfig):
+ def __init__(
+ self,
+ data: Dict[str, Any], # = dataset.data
+ sampler: Dict[str, Any], # = sampling._sampler_.sampler_cls
+ transform: Optional[Dict[str, Any]] = None, # = dataset.transform
+ augment: Optional[Dict[str, Any]] = None, # = augment.augment_cls
+ dataloader_kwargs: Optional[Dict[str, Any]] = None, # = dataloader
+ augment_on_gpu: bool = False, # = dsettings.dataset.gpu_augment
+ using_cuda: Optional[bool] = False, # = self.hparams.dsettings.trainer.cuda
+ prepare_data_per_node: bool = True, # DataHooks.prepare_data_per_node
+ return_indices: bool = False, # = framework.meta.requires_indices
+ return_factors: bool = False, # = framework.meta.requires_factors
+ ):
super().__init__()
- # support pytorch lightning < 1.4
- if not hasattr(self, 'hparams'):
- self.hparams = DictConfig(hparams)
- else:
- self.hparams.update(hparams)
+ # OVERRIDE:
+ self.prepare_data_per_node = prepare_data_per_node
+ # save hparams
+ self.save_hyperparameters()
+ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
# transform: prepares data from datasets
- self.data_transform = hydra.utils.instantiate(self.hparams.dataset.transform)
+ self.data_transform = hydra.utils.instantiate(transform)
assert (self.data_transform is None) or callable(self.data_transform)
# input_transform_aug: augment data for inputs, then apply input_transform
- self.input_transform = hydra.utils.instantiate(self.hparams.augment.augment_cls)
- assert (self.input_transform is None) or callable(self.input_transform)
+ self.input_transform = hydra.utils.instantiate(augment)
+ assert (self.input_transform is None) or callable(self.input_transform) # should be: `Callable[[torch.Tensor], torch.Tensor]`
+ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
# batch_augment: augments transformed data for inputs, should be applied across a batch
# which version of the dataset we need to use if GPU augmentation is enabled or not.
# - corresponds to below in train_dataloader()
- if self.hparams.dsettings.dataset.gpu_augment:
- # TODO: this is outdated!
- self.batch_augment = DisentDatasetTransform(transform=self.input_transform)
- warnings.warn('`gpu_augment=True` is outdated and may no longer be equivalent to `gpu_augment=False`')
+ if augment_on_gpu:
+ self._gpu_batch_augment = DisentDatasetTransform(transform=self.input_transform)
+ warnings.warn('`augment_on_gpu=True` is outdated and may no longer be equivalent to `augment_on_gpu=False`')
else:
- self.batch_augment = None
+ self._gpu_batch_augment = None
+ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
# datasets initialised in setup()
self.dataset_train_noaug: DisentDataset = None
self.dataset_train_aug: DisentDataset = None
+ @property
+ def gpu_batch_augment(self) -> Optional[DisentDatasetTransform]:
+ return self._gpu_batch_augment
+
def prepare_data(self) -> None:
# *NB* Do not set model parameters here.
# - Instantiate data once to download and prepare if needed.
# - trainer.prepare_data_per_node affects this functions behavior per node.
- data = dict(self.hparams.dataset.data)
+ data = dict(self.hparams.data)
if 'in_memory' in data:
del data['in_memory']
# create the data
@@ -124,11 +144,11 @@ def prepare_data(self) -> None:
def setup(self, stage=None) -> None:
# ground truth data
log.info(f'Data - Instance')
- data = hydra.utils.instantiate(self.hparams.dataset.data)
+ data = hydra.utils.instantiate(self.hparams.data)
# Wrap the data for the framework some datasets need triplets, pairs, etc.
# Augmentation is done inside the frameworks so that it can be done on the GPU, otherwise things are very slow.
- self.dataset_train_noaug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampling._sampler_.sampler_cls), transform=self.data_transform, augment=None)
- self.dataset_train_aug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampling._sampler_.sampler_cls), transform=self.data_transform, augment=self.input_transform)
+ self.dataset_train_noaug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampler), transform=self.data_transform, augment=None, return_indices=self.hparams.return_indices, return_factors=self.hparams.return_factors)
+ self.dataset_train_aug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampler), transform=self.data_transform, augment=self.input_transform, return_indices=self.hparams.return_indices, return_factors=self.hparams.return_factors)
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Training Dataset:
@@ -149,21 +169,25 @@ def train_dataloader(self):
"""
# Select which version of the dataset we need to use if GPU augmentation is enabled or not.
# - corresponds to above in __init__()
- if self.hparams.dsettings.dataset.gpu_augment:
+ if self.hparams.augment_on_gpu:
dataset = self.dataset_train_noaug
else:
dataset = self.dataset_train_aug
+ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
# get default kwargs
default_kwargs = {
'shuffle': True,
# This should usually be TRUE if cuda is enabled.
# About 20% faster with the xysquares dataset, RTX 2060 Rev. A, and Intel i7-3930K
- 'pin_memory': self.hparams.dsettings.trainer.cuda
+ 'pin_memory': self.hparams.using_cuda,
}
# get config kwargs
- kwargs = self.hparams.dataloader
+ kwargs = self.hparams.dataloader_kwargs
+ if not kwargs:
+ kwargs = {}
# check required keys
if ('batch_size' not in kwargs) or ('num_workers' not in kwargs):
raise KeyError(f'`dataset.dataloader` must contain keys: ["batch_size", "num_workers"], got: {sorted(kwargs.keys())}')
+ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ #
# create dataloader
return torch.utils.data.DataLoader(dataset=dataset, **{**default_kwargs, **kwargs})
diff --git a/experiment/util/hydra_main.py b/experiment/util/hydra_main.py
new file mode 100644
index 00000000..ce55611b
--- /dev/null
+++ b/experiment/util/hydra_main.py
@@ -0,0 +1,229 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import logging
+import os
+import subprocess
+import sys
+from pathlib import Path
+from typing import Callable
+from typing import List
+from typing import NoReturn
+from typing import Optional
+
+import hydra
+from omegaconf import OmegaConf
+from omegaconf import DictConfig
+
+from experiment.util.path_utils import get_current_experiment_number
+from experiment.util.path_utils import make_current_experiment_dir
+from experiment.util.run_utils import log_error_and_exit
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# VARS #
+# ========================================================================= #
+
+
+# experiment/util/_hydra_searchpath_plugin_
+PLUGIN_NAMESPACE = os.path.abspath(os.path.join(__file__, '..', '_hydra_searchpath_plugin_'))
+
+# experiment/config
+EXP_CONFIG_DIR = os.path.abspath(os.path.join(__file__, '../..', 'config'))
+
+# list of configs
+_DISENT_CONFIG_DIRS: List[str] = None
+
+
+# ========================================================================= #
+# PATCHING #
+# ========================================================================= #
+
+
+def register_searchpath_plugin(
+ search_dir_main: str = EXP_CONFIG_DIR,
+ search_dirs_prepend: Optional[List[str]] = None,
+ search_dirs_append: Optional[List[str]] = None,
+):
+ """
+ Patch Hydra:
+ 1. sets the default search path to `experiment/config`
+ 2. add to the search path with the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables
+ NOTE: --config-dir has lower priority than all these, --config-path has higher priority.
+
+ This function can safely be called multiple times
+ - unless other functions modify these same libs which is unlikely!
+ """
+ # normalise the config paths
+ if search_dirs_prepend is None: search_dirs_prepend = []
+ if search_dirs_append is None: search_dirs_append = []
+ assert isinstance(search_dirs_prepend, (tuple, list)) and all((isinstance(d, str) and d) for d in search_dirs_prepend), f'`search_dirs_prepend` must be a list or tuple of non-empty path strings to directories, got: {repr(search_dirs_prepend)}'
+ assert isinstance(search_dirs_append, (tuple, list)) and all((isinstance(d, str) and d) for d in search_dirs_append), f'`search_dirs_append` must be a list or tuple of non-empty path strings to directories, got: {repr(search_dirs_append)}'
+ assert isinstance(search_dir_main, str) and search_dir_main, f'`search_dir_main` must be a non-empty path string to a directory, got: {repr(search_dir_main)}'
+ # get dirs
+ config_dirs = [*search_dirs_prepend, search_dir_main, *search_dirs_append]
+
+ # check that it is the same as what has previously been registered, otherwise set the directories!
+ global _DISENT_CONFIG_DIRS
+ if _DISENT_CONFIG_DIRS is None:
+ _DISENT_CONFIG_DIRS = config_dirs
+ else:
+ assert _DISENT_CONFIG_DIRS == config_dirs, f'Config dirs have already been registered, on additional calls, registered dirs must be the same as previously values!\n- existing: {_DISENT_CONFIG_DIRS}\n- registered: {config_dirs}'
+
+ # register the experiment's search path plugin with disent, using hydras auto-detection
+ # of folders named `hydra_plugins` contained insided `namespace packages` or rather
+ # packages that are in the `PYTHONPATH` or `sys.path`
+ # 1. sets the default search path to those registered above
+ # 2. add to the search path with the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables
+ # NOTE: --config-dir has lower priority than all these, --config-path has higher priority.
+ if PLUGIN_NAMESPACE not in sys.path:
+ sys.path.insert(0, PLUGIN_NAMESPACE)
+
+
+def register_hydra_resolvers():
+ """
+ Patch OmegaConf, enabling various config resolvers:
+ - enable the ${exit:} resolver for omegaconf/hydra
+ - enable the ${exp_num:} and ${exp_dir:,} resolvers to detect the experiment number
+ - enable the ${fmt:,} resolver that wraps `str.format`
+ - enable the ${abspath:} resolver that wraps `hydra.utils.to_absolute_path` formatting relative paths in relation to the original working directory
+ - enable the ${rsync_dir:/,/} resolver that returns `/`, but first rsync's the two directories!
+
+ This function can safely be called multiple times
+ - unless other functions modify these same libs which is unlikely!
+ """
+
+
+ # register a custom OmegaConf resolver that allows us to put in a ${exit:msg} that exits the program
+ # - if we don't register this, the program will still fail because we have an unknown
+ # resolver. This just prettifies the output.
+ if not OmegaConf.has_resolver('exit'):
+ class ConfigurationError(Exception):
+ pass
+ # resolver function
+ def _error_resolver(msg: str):
+ raise ConfigurationError(msg)
+ # patch omegaconf for hydra
+ OmegaConf.register_new_resolver('exit', _error_resolver)
+
+ # register a custom OmegaConf resolver that allows us to get the next experiment number from a directory
+ # - ${run_num:} returns the current experiment number
+ if not OmegaConf.has_resolver('exp_num'):
+ OmegaConf.register_new_resolver('exp_num', get_current_experiment_number)
+ # - ${run_dir:,} returns the current experiment folder with the name appended
+ if not OmegaConf.has_resolver('exp_dir'):
+ OmegaConf.register_new_resolver('exp_dir', make_current_experiment_dir)
+
+ # register a function that pads an integer to a specified length
+ # - ${fmt:"{:04d}",42} -> "0042"
+ if not OmegaConf.has_resolver('fmt'):
+ OmegaConf.register_new_resolver('fmt', str.format)
+
+ # register hydra helper functions
+ # - ${abspath:} convert a relative path to an abs path using the original hydra working directory, not the changed experiment dir.
+ if not OmegaConf.has_resolver('abspath'):
+ OmegaConf.register_new_resolver('abspath', hydra.utils.to_absolute_path)
+
+ # registry copy directory function
+ # - useful if datasets are already prepared on a shared drive and need to be copied to a temp drive for example!
+ if not OmegaConf.has_resolver('rsync_dir'):
+ def rsync_dir(src: str, dst: str) -> str:
+ src, dst = Path(src), Path(dst)
+ # checks
+ assert src.name and src.is_absolute(), f'src path must be absolute and not the root: {repr(str(src))}'
+ assert dst.name and dst.is_absolute(), f'dst path must be absolute and not the root: {repr(str(dst))}'
+ assert src.name == dst.name, f'src and dst paths must point to dirs with the same names: src.name={repr(src.name)}, dst.name={repr(dst.name)}'
+ # synchronize dirs
+ logging.info(f'rsync files:\n- src={repr(str(src))}\n- dst={repr(str(dst))}')
+ # create the parent dir and copy files into the parent
+ dst.parent.mkdir(parents=True, exist_ok=True)
+ returncode = subprocess.Popen(['rsync', '-avh', str(src), str(dst.parent)]).wait()
+ if returncode != 0:
+ raise RuntimeError('Failed to rsync files!')
+ # return the destination dir
+ return str(dst)
+ # REGISTER
+ OmegaConf.register_new_resolver('rsync_dir', rsync_dir)
+
+
+# ========================================================================= #
+# RUN HYDRA #
+# ========================================================================= #
+
+
+def patch_hydra(
+ # config search path
+ search_dir_main: str = EXP_CONFIG_DIR,
+ search_dirs_prepend: Optional[List[str]] = None,
+ search_dirs_append: Optional[List[str]] = None,
+):
+ # Patch Hydra and OmegaConf:
+ register_searchpath_plugin(search_dir_main=search_dir_main, search_dirs_prepend=search_dirs_prepend, search_dirs_append=search_dirs_append)
+ register_hydra_resolvers()
+
+
+def hydra_main(
+ callback: Callable[[DictConfig], NoReturn],
+ config_name: str = 'config',
+ # config search path
+ search_dir_main: str = EXP_CONFIG_DIR,
+ search_dirs_prepend: Optional[List[str]] = None,
+ search_dirs_append: Optional[List[str]] = None,
+ # logging
+ log_level: Optional[int] = logging.INFO,
+ log_exc_info_callback: bool = True,
+ log_exc_info_hydra: bool = False,
+):
+ # manually set log level before hydra initialises!
+ if log_level is not None:
+ logging.basicConfig(level=log_level)
+
+ # Patch Hydra and OmegaConf:
+ patch_hydra(search_dir_main=search_dir_main, search_dirs_prepend=search_dirs_prepend, search_dirs_append=search_dirs_append)
+
+ @hydra.main(config_path=None, config_name=config_name)
+ def _hydra_main(cfg: DictConfig):
+ try:
+ callback(cfg)
+ except Exception as e:
+ log_error_and_exit(err_type='experiment error', err_msg=str(e), exc_info=log_exc_info_callback)
+ except:
+ log_error_and_exit(err_type='experiment error', err_msg='', exc_info=log_exc_info_callback)
+
+ try:
+ _hydra_main()
+ except KeyboardInterrupt as e:
+ log_error_and_exit(err_type='interrupted', err_msg=str(e), exc_info=False)
+ except Exception as e:
+ log_error_and_exit(err_type='hydra error', err_msg=str(e), exc_info=log_exc_info_hydra)
+ except:
+ log_error_and_exit(err_type='hydra error', err_msg='', exc_info=log_exc_info_hydra)
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/experiment/util/path_utils.py b/experiment/util/path_utils.py
new file mode 100644
index 00000000..57858aec
--- /dev/null
+++ b/experiment/util/path_utils.py
@@ -0,0 +1,150 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2021 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import logging
+import os
+import re
+from pathlib import Path
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# PATH HELPERS #
+# ========================================================================= #
+
+
+_EXPERIMENT_SEP = '_'
+_EXPERIMENT_RGX = re.compile(f'^([0-9]+)({_EXPERIMENT_SEP}.+)?$')
+
+
+def get_max_experiment_number(root_dir: str, return_path: bool = False) -> Union[int, Tuple[int, Optional[str]]]:
+ """
+ Get the next experiment number in the specified directory. Experiment directories
+ all start with a numerical value.
+ - eg. "1", "00002", "3_name", "00042_name" are all valid subdirectories.
+ - eg. "name", "name_1", "name_00001", "99999_image.png" are all invalid and are
+ ignored. Either their name format is wrong or they are a file.
+
+ If all the above directories are all used as an example, then this function will
+ return the value 42 corresponding to "00042_name"
+ """
+ # check the dirs exist
+ if not os.path.exists(root_dir):
+ raise FileNotFoundError(f'The given experiments directory does not exist: {repr(root_dir)} ({repr(os.path.abspath(root_dir))})')
+ elif not os.path.isdir(root_dir):
+ raise NotADirectoryError(f'The given experiments path exists, but is not a directory: {repr(root_dir)} ({repr(os.path.abspath(root_dir))})')
+ # linear search over each file in the dir
+ max_num, max_path = 0, None
+ for file in os.listdir(root_dir):
+ # skip if not a directory
+ if not os.path.isdir(os.path.join(root_dir, file)):
+ continue
+ # skip if the file name does not match
+ match = _EXPERIMENT_RGX.search(file)
+ if not match:
+ continue
+ # update the maximum number
+ num, _ = match.groups()
+ num = int(num)
+ if num > max_num:
+ max_num, max_path = num, file
+ # done!
+ if return_path:
+ return max_num, max_path
+ return max_num
+
+
+_CURRENT_EXPERIMENT_NUM: Optional[int] = None
+_CURRENT_EXPERIMENT_DIR: Optional[str] = None
+
+
+def get_current_experiment_number(root_dir: str) -> int:
+ """
+ Get the next experiment number from the experiment directory, and cache
+ the result for future calls of this function for the current instance of the program.
+ - The next time the program is run, this value will differ.
+
+ For example, if the `root_dir` contains the directories: "00001_name", "00041", then
+ this function will return the next value which is `42` on all subsequent calls, even
+ if a directory for experiment 42 is created during the current program's lifetime.
+ """
+ global _CURRENT_EXPERIMENT_NUM
+ if _CURRENT_EXPERIMENT_NUM is None:
+ _CURRENT_EXPERIMENT_NUM = get_max_experiment_number(root_dir, return_path=False) + 1
+ return _CURRENT_EXPERIMENT_NUM
+
+
+def get_current_experiment_dir(root_dir: str, name: Optional[str] = None) -> str:
+ """
+ Like `get_current_experiment_number` which computes the next experiment number, this
+ function computes the next experiment path, which appends a name to the computed number.
+
+ The result is cached for the lifetime of the program, however, on subsequent calls of
+ this function, the computed name must always match the original value otherwise an
+ error is thrown! This is to prevent experiments with duplicate numbers from being created!
+ """
+ if name is not None:
+ assert Path(name).name == name, f'The given name is not valid: {repr(name)}'
+ # make the dirname & normalise the path
+ num = get_current_experiment_number(root_dir)
+ dir_name = f'{num:05d}{_EXPERIMENT_SEP}{name}' if name else f'{num:05d}'
+ exp_dir = os.path.abspath(os.path.join(root_dir, dir_name))
+ # cache the experiment name or check against the existing cache
+ global _CURRENT_EXPERIMENT_DIR
+ if _CURRENT_EXPERIMENT_DIR is None:
+ _CURRENT_EXPERIMENT_DIR = exp_dir
+ if exp_dir != _CURRENT_EXPERIMENT_DIR:
+ raise RuntimeError(f'Current experiment directory has already been set: {repr(_CURRENT_EXPERIMENT_DIR)} This does not match what was computed: {repr(exp_dir)}')
+ # done!
+ return _CURRENT_EXPERIMENT_DIR
+
+
+def make_current_experiment_dir(root_dir: str, name: Optional[str] = None) -> str:
+ """
+ Like `get_current_experiment_dir`, but create any of the directories if needed.
+ - Both the `root_dir` and the computed subdir for the current experiment will be created.
+ """
+ root_dir = os.path.abspath(root_dir)
+ # make the root directory if it does not exist!
+ if not os.path.exists(root_dir):
+ log.info(f'root experiments directory does not exist, creating... {repr(root_dir)}')
+ os.makedirs(root_dir, exist_ok=True)
+ # get the current dir
+ current_dir = get_current_experiment_dir(root_dir, name)
+ # make the current dir
+ if not os.path.exists(current_dir):
+ log.info(f'current experiment directory does not exist, creating... {repr(current_dir)}')
+ os.makedirs(current_dir, exist_ok=True)
+ # done!
+ return current_dir
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/requirements-test.txt b/requirements-test.txt
index af5c6933..79de0ba6 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -2,4 +2,21 @@ pytest>=6.2.4
pytest-cov>=2.12.1
# requires pytorch to be installed first (duplicated in requirements-experiment.txt)
+# - we need `nvcc` to be installed first, otherwise GPU kernel extensions will not be
+# compiled and this error will silently be skipped. If you get an error such as
+# $ conda install -c nvidia cuda-nvcc
+# - Make sure that the version of torch corresponds to the version of `nvcc`, torch needs
+# to be compiled with the same version! Install the correct version from:
+# https://pytorch.org/get-started/locally/ By default torch compiled with 10.2 is installed,
+# but `nvcc` will probably want to install 11.
+# CUDA 10.2 (as of 2022-03-15) EITHER OF:
+# $ pip3 install torch torchvision torchaudio
+# $ conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
+# CUDA 11.3 (as of 2022-03-15) EITHER OF:
+# $ pip3 install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+# $ conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
+# - I personally just manage my cuda version manually, installing the correct cudatoolkit from: https://developer.nvidia.com/cuda-toolkit-archive
+# Then making sure that:
+# PATH contains: "/usr/local/cuda/bin"
+# LD_LIBRARY_PATH contains: "/usr/local/cuda/lib64"
torchsort>=0.1.4
diff --git a/setup.py b/setup.py
index 99db0047..216eeb25 100644
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,7 @@
author="Nathan Juraj Michlo",
author_email="NathanJMichlo@gmail.com",
- version="0.3.4",
+ version="0.4.0",
python_requires=">=3.8", # we make use of standard library features only in 3.8
packages=setuptools.find_packages(),
diff --git a/tests/test_experiment.py b/tests/test_experiment.py
index 0ec7593f..fec35bba 100644
--- a/tests/test_experiment.py
+++ b/tests/test_experiment.py
@@ -22,13 +22,15 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+import logging
import os
import os.path
-import hydra
import pytest
import experiment.run as experiment_run
+from experiment.util.hydra_main import hydra_main
+from tests.util import temp_environ
from tests.util import temp_sys_args
@@ -37,19 +39,28 @@
# ========================================================================= #
-@pytest.mark.parametrize('args', [
- ['run_action=skip'],
- ['run_action=prepare_data'],
- ['run_action=train'],
+@pytest.mark.parametrize(('env', 'args'), [
+ # test the standard configs
+ (dict(), ['run_action=skip']),
+ (dict(), ['run_action=prepare_data']),
+ (dict(), ['run_action=train']),
])
-def test_experiment_run(args):
+def test_experiment_run(env, args):
+ # show full errors in hydra
os.environ['HYDRA_FULL_ERROR'] = '1'
- # TODO: why does this not work when config_path is absolute?
- # ie. config_path=os.path.join(os.path.dirname(experiment_run.__file__), 'config')
- with temp_sys_args([experiment_run.__file__, *args]):
- hydra_main = hydra.main(config_path='config', config_name='config_test')(experiment_run.run_action)
- hydra_main()
+ # temporarily set the environment and the arguments
+ with temp_environ(env), temp_sys_args([experiment_run.__file__, *args]):
+ # run the hydra experiment
+ # 1. sets the default search path to `experiment/config`
+ # 2. add to the search path with the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables
+ # 3. enable the ${exit:} and various other resolvers for omegaconf/hydra
+ hydra_main(
+ callback=experiment_run.run_action,
+ config_name='config_test',
+ log_level=logging.DEBUG,
+ )
+
# ========================================================================= #
diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py
index 91c1750c..f454050f 100644
--- a/tests/test_frameworks.py
+++ b/tests/test_frameworks.py
@@ -22,6 +22,7 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+import pickle
from dataclasses import asdict
from functools import partial
@@ -47,7 +48,7 @@
# ========================================================================= #
-@pytest.mark.parametrize(['Framework', 'cfg_kwargs', 'Data'], [
+_TEST_FRAMEWORKS = [
# AE - unsupervised
(Ae, dict(), XYObjectData),
# AE - weakly supervised
@@ -69,8 +70,11 @@
(AdaGVaeMinimal, dict(), XYObjectData),
# VAE - supervised
(TripletVae, dict(), XYObjectData),
- (TripletVae, dict(disable_decoder=True, disable_reg_loss=True, disable_posterior_scale=0.5), XYObjectData),
-])
+ (TripletVae, dict(detach_decoder=True, disable_reg_loss=True), XYObjectData),
+]
+
+
+@pytest.mark.parametrize(['Framework', 'cfg_kwargs', 'Data'], _TEST_FRAMEWORKS)
def test_frameworks(Framework, cfg_kwargs, Data):
DataSampler = {
1: GroundTruthSingleSampler,
@@ -90,9 +94,29 @@ def test_frameworks(Framework, cfg_kwargs, Data):
cfg=Framework.cfg(**cfg_kwargs)
)
+ # test pickling before training
+ pickle.dumps(framework)
+
+ # train!
trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, fast_dev_run=True)
trainer.fit(framework, dataloader)
+ # test pickling after training, something may have changed!
+ pickle.dumps(framework)
+
+
+@pytest.mark.parametrize(['Framework', 'cfg_kwargs', 'Data'], _TEST_FRAMEWORKS)
+def test_framework_pickling(Framework, cfg_kwargs, Data):
+ framework = Framework(
+ model=AutoEncoder(
+ encoder=EncoderLinear(x_shape=(64, 64, 3), z_size=6, z_multiplier=2 if issubclass(Framework, Vae) else 1),
+ decoder=DecoderLinear(x_shape=(64, 64, 3), z_size=6),
+ ),
+ cfg=Framework.cfg(**cfg_kwargs)
+ )
+ # test pickling!
+ pickle.dumps(framework)
+
def test_framework_config_defaults():
# import torch
@@ -102,8 +126,7 @@ def test_framework_config_defaults():
optimizer_kwargs=None,
recon_loss='mse',
disable_aug_loss=False,
- disable_decoder=False,
- disable_posterior_scale=None,
+ detach_decoder=False,
disable_rec_loss=False,
disable_reg_loss=False,
loss_reduction='mean',
@@ -116,8 +139,7 @@ def test_framework_config_defaults():
optimizer_kwargs=None,
recon_loss='bce',
disable_aug_loss=False,
- disable_decoder=False,
- disable_posterior_scale=None,
+ detach_decoder=False,
disable_rec_loss=False,
disable_reg_loss=False,
loss_reduction='mean',
diff --git a/tests/test_math.py b/tests/test_math.py
index 69c2e4cd..e4962f33 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -37,11 +37,16 @@
from disent.nn.functional import torch_cov_matrix
from disent.nn.functional import torch_dct
from disent.nn.functional import torch_dct2
+from disent.nn.functional import torch_dist_hamming
from disent.nn.functional import torch_gaussian_kernel_2d
from disent.nn.functional import torch_idct
from disent.nn.functional import torch_idct2
from disent.nn.functional import torch_mean_generalized
+from disent.nn.functional import torch_norm
+from disent.nn.functional import torch_dist
from disent.dataset.transform import ToImgTensorF32
+from disent.nn.functional import torch_norm_euclidean
+from disent.nn.functional import torch_norm_manhattan
from disent.util import to_numpy
@@ -84,12 +89,72 @@ def test_generalised_mean():
assert torch.allclose(torch_mean_generalized(xs, p=-1), torch.as_tensor(hmean(xs, axis=None))) # scipy default axis is 0
# min max
- assert torch.allclose(torch_mean_generalized(xs, p='maximum', dim=1), torch.max(xs, dim=1).values)
- assert torch.allclose(torch_mean_generalized(xs, p='minimum', dim=1), torch.min(xs, dim=1).values)
- assert torch.allclose(torch_mean_generalized(xs, p=np.inf, dim=1), torch.max(xs, dim=1).values)
- assert torch.allclose(torch_mean_generalized(xs, p=-np.inf, dim=1), torch.min(xs, dim=1).values)
+ assert torch.allclose(torch_mean_generalized(xs, p='maximum', dim=1), torch.amax(xs, dim=1))
+ assert torch.allclose(torch_mean_generalized(xs, p='minimum', dim=1), torch.amin(xs, dim=1))
+ assert torch.allclose(torch_mean_generalized(xs, p=np.inf, dim=1), torch.amax(xs, dim=1))
+ assert torch.allclose(torch_mean_generalized(xs, p=-np.inf, dim=1), torch.amin(xs, dim=1))
+def test_p_norm():
+ xs = torch.abs(torch.randn(2, 1000, 3, dtype=torch.float64))
+
+ inf = float('inf')
+
+ # torch equivalents
+ assert torch.allclose(torch_norm(xs, p=1, dim=-1), torch.linalg.norm(xs, ord=1, dim=-1))
+ assert torch.allclose(torch_norm(xs, p=2, dim=-1), torch.linalg.norm(xs, ord=2, dim=-1))
+ assert torch.allclose(torch_norm(xs, p='maximum', dim=-1), torch.linalg.norm(xs, ord=inf, dim=-1))
+
+ # torch equivalents -- less than zero [FAIL]
+ with pytest.raises(ValueError, match='p-norm cannot have a p value less than 1'):
+ assert torch.allclose(torch_norm(xs, p=0, dim=-1), torch.linalg.norm(xs, ord=0, dim=-1))
+ with pytest.raises(ValueError, match='p-norm cannot have a p value less than 1'):
+ assert torch.allclose(torch_norm(xs, p='minimum', dim=-1), torch.linalg.norm(xs, ord=-inf, dim=-1))
+
+ # torch equivalents -- less than zero
+ assert torch.allclose(torch_dist(xs, p=0, dim=-1), torch.linalg.norm(xs, ord=0, dim=-1))
+ assert torch.allclose(torch_dist(xs, p='minimum', dim=-1), torch.linalg.norm(xs, ord=-inf, dim=-1))
+ assert torch.allclose(torch_dist(xs, p=0, dim=-1), torch.linalg.norm(xs, ord=0, dim=-1))
+ assert torch.allclose(torch_dist(xs, p='minimum', dim=-1), torch.linalg.norm(xs, ord=-inf, dim=-1))
+
+ # test other axes
+ ys = torch.flatten(torch.moveaxis(xs, 0, -1), start_dim=1, end_dim=-1)
+ assert torch.allclose(torch_norm(xs, p=2, dim=[0, -1]), torch.linalg.norm(ys, ord=2, dim=-1))
+ ys = torch.flatten(xs, start_dim=1, end_dim=-1)
+ assert torch.allclose(torch_norm(xs, p=1, dim=[-2, -1]), torch.linalg.norm(ys, ord=1, dim=-1))
+ ys = torch.flatten(torch.moveaxis(xs, -1, 0), start_dim=1, end_dim=-1)
+ assert torch.allclose(torch_dist(xs, p=0, dim=[0, 1]), torch.linalg.norm(ys, ord=0, dim=-1))
+
+ # check equal names
+ assert torch.allclose(torch_dist(xs, dim=1, p='euclidean'), torch_norm_euclidean(xs, dim=1))
+ assert torch.allclose(torch_dist(xs, dim=1, p=2), torch_norm_euclidean(xs, dim=1))
+ assert torch.allclose(torch_dist(xs, dim=1, p='manhattan'), torch_norm_manhattan(xs, dim=1))
+ assert torch.allclose(torch_dist(xs, dim=1, p=1), torch_norm_manhattan(xs, dim=1))
+ assert torch.allclose(torch_dist(xs, dim=1, p='hamming'), torch_dist_hamming(xs, dim=1))
+ assert torch.allclose(torch_dist(xs, dim=1, p=0), torch_dist_hamming(xs, dim=1))
+
+ # check axes
+ assert torch_dist(xs, dim=1, p=2, keepdim=False).shape == (2, 3)
+ assert torch_dist(xs, dim=1, p=2, keepdim=True).shape == (2, 1, 3)
+ assert torch_dist(xs, dim=-1, p=1, keepdim=False).shape == (2, 1000)
+ assert torch_dist(xs, dim=-1, p=1, keepdim=True).shape == (2, 1000, 1)
+ assert torch_dist(xs, dim=0, p=0, keepdim=False).shape == (1000, 3)
+ assert torch_dist(xs, dim=0, p=0, keepdim=True).shape == (1, 1000, 3)
+ assert torch_dist(xs, dim=[0, -1], p=-inf, keepdim=False).shape == (1000,)
+ assert torch_dist(xs, dim=[0, -1], p=-inf, keepdim=True).shape == (1, 1000, 1)
+ assert torch_dist(xs, dim=[0, 1], p=inf, keepdim=False).shape == (3,)
+ assert torch_dist(xs, dim=[0, 1], p=inf, keepdim=True).shape == (1, 1, 3)
+
+ # check norm over all
+ assert torch_dist(xs, dim=None, p=0, keepdim=False).shape == ()
+ assert torch_dist(xs, dim=None, p=0, keepdim=True).shape == (1, 1, 1)
+ assert torch_dist(xs, dim=None, p=1, keepdim=False).shape == ()
+ assert torch_dist(xs, dim=None, p=1, keepdim=True).shape == (1, 1, 1)
+ assert torch_dist(xs, dim=None, p=2, keepdim=False).shape == ()
+ assert torch_dist(xs, dim=None, p=2, keepdim=True).shape == (1, 1, 1)
+ assert torch_dist(xs, dim=None, p=inf, keepdim=False).shape == ()
+ assert torch_dist(xs, dim=None, p=inf, keepdim=True).shape == (1, 1, 1)
+
def test_dct():
x = torch.randn(128, 3, 64, 32, dtype=torch.float64)
@@ -141,7 +206,7 @@ def test_fft_conv2d():
kernel = torch_gaussian_kernel_2d(sigma=i)
out_cnv = torch_conv2d_channel_wise(signal=batch, kernel=kernel)[0]
out_fft = torch_conv2d_channel_wise_fft(signal=batch, kernel=kernel)[0]
- assert torch.max(torch.abs(out_cnv - out_fft)) < 1e-6
+ assert torch.amax(torch.abs(out_cnv - out_fft)) < 1e-6
# ========================================================================= #
diff --git a/tests/test_registry.py b/tests/test_registry.py
index c173d969..1bde1aa3 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -22,8 +22,8 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-from disent.registry import REGISTRIES
+import pytest
+import disent.registry as R
# ========================================================================= #
@@ -32,31 +32,28 @@
COUNTS = {
- 'DATASETS': 6,
+ 'DATASETS': 10,
'SAMPLERS': 8,
'FRAMEWORKS': 10,
- 'RECON_LOSSES': 6,
- 'LATENT_DISTS': 2,
+ 'RECON_LOSSES': 9,
+ 'LATENT_HANDLERS': 2,
'OPTIMIZERS': 30,
'METRICS': 5,
'SCHEDULES': 5,
'MODELS': 8,
+ 'KERNELS': 2,
}
-
-def test_registry_loading():
+@pytest.mark.parametrize('registry_key', COUNTS.keys())
+def test_registry_loading(registry_key):
# load everything and check the counts
- total = 0
- for registry in REGISTRIES:
- count = 0
- for name in REGISTRIES[registry]:
- loaded = REGISTRIES[registry][name]
- count += 1
- total += 1
- assert COUNTS[registry] == count, f'invalid count for: {registry}'
- assert total == sum(COUNTS.values()), f'invalid total'
+ count = 0
+ for example in R.REGISTRIES[registry_key]:
+ loaded = R.REGISTRIES[registry_key][example]
+ count += 1
+ assert count == COUNTS[registry_key], f'invalid count for: {registry_key}'
# ========================================================================= #
diff --git a/tests/test_to_img.py b/tests/test_to_img.py
new file mode 100644
index 00000000..5ecd68ae
--- /dev/null
+++ b/tests/test_to_img.py
@@ -0,0 +1,328 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2022 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+
+from disent.util.visualize.vis_img import _torch_to_images_normalise_args
+from disent.util.visualize.vis_img import numpy_to_images
+from disent.util.visualize.vis_img import numpy_to_pil_images
+from disent.util.visualize.vis_img import torch_to_images
+from disent.util.visualize.vis_img import _ALLOWED_DTYPES
+
+
+# ========================================================================= #
+# Tests #
+# ========================================================================= #
+
+
+def test_torch_to_images_basic():
+ inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32)
+ inp_uint8 = (inp_float * 127 + 63).to(torch.uint8)
+ # check runs
+ out = torch_to_images(inp_float)
+ assert out.dtype == torch.uint8
+ out = torch_to_images(inp_uint8)
+ assert out.dtype == torch.uint8
+ out = torch_to_images(inp_float, in_dtype=None, out_dtype=None)
+ assert out.dtype == inp_float.dtype
+ out = torch_to_images(inp_uint8, in_dtype=None, out_dtype=None)
+ assert out.dtype == inp_uint8.dtype
+
+
+def test_torch_to_images_permutations():
+ inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32)
+ inp_uint8 = (inp_float * 127 + 63).to(torch.uint8)
+
+ # general checks
+ def check_all(inputs, in_dtype=None):
+ float_results, int_results = [], []
+ for out_dtype in _ALLOWED_DTYPES:
+ out = torch_to_images(inputs, in_dtype=in_dtype, out_dtype=out_dtype)
+ stats = torch.stack([out.min().to(torch.float64), out.max().to(torch.float64), out.to(dtype=torch.float64).mean()])
+ (float_results if out_dtype.is_floating_point else int_results).append(stats)
+ for a, b in zip(float_results[:-1], float_results[1:]): assert torch.allclose(a, b)
+ for a, b in zip(int_results[:-1], int_results[1:]): assert torch.allclose(a, b)
+
+ # check type permutations
+ check_all(inp_float, torch.float32)
+ check_all(inp_uint8, torch.uint8)
+
+
+def test_torch_to_images_preserve_type():
+ for dtype in _ALLOWED_DTYPES:
+ tensor = (torch.rand(8, 3, 64, 64) * (1 if dtype.is_floating_point else 255)).to(dtype)
+ out = torch_to_images(tensor, in_dtype=dtype, out_dtype=dtype, clamp_mode='warn')
+ assert out.dtype == dtype
+
+
+def test_torch_to_images_arg_helper():
+ assert _torch_to_images_normalise_args((64, 128, 3), torch.uint8, 'HWC', 'CHW', None, None) == ((-1, -3, -2), torch.uint8, torch.uint8, -3)
+ assert _torch_to_images_normalise_args((64, 128, 3), torch.uint8, 'HWC', 'HWC', None, None) == ((-3, -2, -1), torch.uint8, torch.uint8, -1)
+
+
+def test_torch_to_images_invalid_args():
+ inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32)
+
+ # check tensor
+ with pytest.raises(TypeError, match="images must be of type"):
+ torch_to_images(tensor=None)
+ with pytest.raises(ValueError, match='dim "C", required: 1 or 3'):
+ torch_to_images(tensor=torch.rand(8, 2, 16, 16, dtype=torch.float32))
+ with pytest.raises(ValueError, match='dim "C", required: 1 or 3'):
+ torch_to_images(tensor=torch.rand(8, 16, 16, 3, dtype=torch.float32))
+ with pytest.raises(ValueError, match='images must have 3 or more dimensions corresponding to'):
+ torch_to_images(tensor=torch.rand(16, 16, dtype=torch.float32))
+
+ # check dims
+ with pytest.raises(TypeError, match="in_dims must be of type"):
+ torch_to_images(inp_float, in_dims=None)
+ with pytest.raises(TypeError, match="out_dims must be of type"):
+ torch_to_images(inp_float, out_dims=None)
+ with pytest.raises(KeyError, match="in_dims contains the symbols: 'INVALID', must contain only permutations of: 'CHW'"):
+ torch_to_images(inp_float, in_dims='INVALID')
+ with pytest.raises(KeyError, match="out_dims contains the symbols: 'INVALID', must contain only permutations of: 'CHW'"):
+ torch_to_images(inp_float, out_dims='INVALID')
+ with pytest.raises(KeyError, match="in_dims contains the symbols: 'CHWW', must contain only permutations of: 'CHW'"):
+ torch_to_images(inp_float, in_dims='CHWW')
+ with pytest.raises(KeyError, match="out_dims contains the symbols: 'CHWW', must contain only permutations of: 'CHW'"):
+ torch_to_images(inp_float, out_dims='CHWW')
+
+ # check dtypes
+ with pytest.raises(TypeError, match="images dtype: torch.float32 does not match in_dtype: torch.uint8"):
+ torch_to_images(inp_float, in_dtype=torch.uint8)
+ with pytest.raises(TypeError, match='in_dtype is not allowed'):
+ torch_to_images(inp_float, in_dtype=torch.complex64)
+ with pytest.raises(TypeError, match='out_dtype is not allowed'):
+ torch_to_images(inp_float, out_dtype=torch.complex64)
+ with pytest.raises(TypeError, match='in_dtype is not allowed'):
+ torch_to_images(inp_float, in_dtype=torch.float16)
+ with pytest.raises(TypeError, match='out_dtype is not allowed'):
+ torch_to_images(inp_float, out_dtype=torch.float16)
+
+
+def _check(target_shape, target_dtype, img, m=None, M=None):
+ assert img.dtype == target_dtype
+ assert img.shape == target_shape
+ if isinstance(img, torch.Tensor): img = img.numpy()
+ if m is not None: assert np.isclose(img.min(), m), f'min mismatch: {img.min()} (actual) != {m} (expected)'
+ if M is not None: assert np.isclose(img.max(), M), f'max mismatch: {img.max()} (actual) != {M} (expected)'
+
+
+def test_torch_to_images_adv():
+ # CHW
+ nchw_float = torch.rand(8, 3, 64, 32, dtype=torch.float32)
+ nchw_uint8 = torch.randint(0, 255, (8, 3, 64, 32), dtype=torch.uint8)
+ # HWC
+ nhwc_float = torch.rand(8, 64, 32, 3, dtype=torch.float32)
+ nhwc_uint8 = torch.randint(0, 255, (8, 64, 32, 3), dtype=torch.uint8)
+
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_float)) # make sure default for numpy is CHW
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8)) # make sure default for numpy is CHW
+
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_float, 'CHW'))
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW'))
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_float, 'HWC'))
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC'))
+
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_float, 'CHW', 'HWC'))
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'HWC'))
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_float, 'HWC', 'HWC'))
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'HWC'))
+
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_float, 'CHW', 'CHW'))
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'CHW'))
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_float, 'HWC', 'CHW'))
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'CHW'))
+
+ # random permute
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'CHW'))
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'CHW'))
+ _check((8, 3, 32, 64), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'CWH'))
+ _check((8, 3, 32, 64), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'CWH'))
+
+ _check((8, 64, 3, 32), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'HCW'))
+ _check((8, 64, 3, 32), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'HCW'))
+ _check((8, 32, 3, 64), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'WCH'))
+ _check((8, 32, 3, 64), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'WCH'))
+
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'HWC'))
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'HWC'))
+ _check((8, 32, 64, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'WHC'))
+ _check((8, 32, 64, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'WHC'))
+
+ _check((8, 64, 32, 3), torch.float32, torch_to_images(nchw_float, 'CHW', 'HWC', out_dtype=torch.float32))
+ _check((8, 64, 32, 3), torch.float32, torch_to_images(nchw_uint8, 'CHW', 'HWC', out_dtype=torch.float32))
+ _check((8, 64, 32, 3), torch.float32, torch_to_images(nhwc_float, 'HWC', 'HWC', out_dtype=torch.float32))
+ _check((8, 64, 32, 3), torch.float32, torch_to_images(nhwc_uint8, 'HWC', 'HWC', out_dtype=torch.float32))
+
+ _check((8, 64, 32, 3), torch.float64, torch_to_images(nchw_float, 'CHW', 'HWC', out_dtype=torch.float64))
+ _check((8, 64, 32, 3), torch.float64, torch_to_images(nchw_uint8, 'CHW', 'HWC', out_dtype=torch.float64))
+ _check((8, 64, 32, 3), torch.float64, torch_to_images(nhwc_float, 'HWC', 'HWC', out_dtype=torch.float64))
+ _check((8, 64, 32, 3), torch.float64, torch_to_images(nhwc_uint8, 'HWC', 'HWC', out_dtype=torch.float64))
+
+ # random, but chance of this failing is almost impossible
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_float, 'CHW', 'HWC', out_dtype=torch.uint8, in_min=0.25, in_max=0.75), m=0, M=255)
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'HWC', out_dtype=torch.uint8, in_min=64, in_max=192), m=0, M=255)
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_float, 'HWC', 'HWC', out_dtype=torch.uint8, in_min=0.25, in_max=0.75), m=0, M=255)
+ _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'HWC', out_dtype=torch.uint8, in_min=64, in_max=192), m=0, M=255)
+
+ # random, but chance of this failing is almost impossible
+ _check((8, 64, 32, 3), torch.float32, torch_to_images(nchw_float, 'CHW', 'HWC', out_dtype=torch.float32, in_min=0.25, in_max=0.75), m=0, M=1)
+ _check((8, 64, 32, 3), torch.float32, torch_to_images(nchw_uint8, 'CHW', 'HWC', out_dtype=torch.float32, in_min=64, in_max=192), m=0, M=1)
+ _check((8, 64, 32, 3), torch.float32, torch_to_images(nhwc_float, 'HWC', 'HWC', out_dtype=torch.float32, in_min=0.25, in_max=0.75), m=0, M=1)
+ _check((8, 64, 32, 3), torch.float32, torch_to_images(nhwc_uint8, 'HWC', 'HWC', out_dtype=torch.float32, in_min=64, in_max=192), m=0, M=1)
+
+ # random, but chance of this failing is almost impossible
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_float, 'CHW', 'CHW', out_dtype=torch.uint8, in_min=0.25, in_max=0.75), m=0, M=255)
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'CHW', out_dtype=torch.uint8, in_min=64, in_max=192), m=0, M=255)
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_float, 'HWC', 'CHW', out_dtype=torch.uint8, in_min=0.25, in_max=0.75), m=0, M=255)
+ _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'CHW', out_dtype=torch.uint8, in_min=64, in_max=192), m=0, M=255)
+
+ # random, but chance of this failing is almost impossible
+ _check((8, 3, 64, 32), torch.float32, torch_to_images(nchw_float, 'CHW', 'CHW', out_dtype=torch.float32, in_min=0.25, in_max=0.75), m=0, M=1)
+ _check((8, 3, 64, 32), torch.float32, torch_to_images(nchw_uint8, 'CHW', 'CHW', out_dtype=torch.float32, in_min=64, in_max=192), m=0, M=1)
+ _check((8, 3, 64, 32), torch.float32, torch_to_images(nhwc_float, 'HWC', 'CHW', out_dtype=torch.float32, in_min=0.25, in_max=0.75), m=0, M=1)
+ _check((8, 3, 64, 32), torch.float32, torch_to_images(nhwc_uint8, 'HWC', 'CHW', out_dtype=torch.float32, in_min=64, in_max=192), m=0, M=1)
+
+ # check clamping
+ with pytest.raises(ValueError, match='is outside of the required range'):
+ torch_to_images(nchw_float, 'CHW', out_dtype=torch.float32, clamp_mode='error', in_min=0.25, in_max=0.75)
+ with pytest.raises(ValueError, match='is outside of the required range'):
+ torch_to_images(nchw_uint8, 'CHW', out_dtype=torch.float32, clamp_mode='error', in_min=64, in_max=192)
+ with pytest.raises(ValueError, match='is outside of the required range'):
+ torch_to_images(nhwc_float, 'HWC', out_dtype=torch.float32, clamp_mode='error', in_min=0.25, in_max=0.75)
+ with pytest.raises(ValueError, match='is outside of the required range'):
+ torch_to_images(nhwc_uint8, 'HWC', out_dtype=torch.float32, clamp_mode='error', in_min=64, in_max=192)
+ with pytest.raises(KeyError, match="invalid clamp mode: 'asdf'"):
+ torch_to_images(nhwc_uint8, 'HWC', out_dtype=torch.float32, clamp_mode='asdf', in_min=64, in_max=192)
+
+
+def test_numpy_to_pil_image():
+ # CHW
+ nchw_float = np.random.rand(8, 3, 64, 32)
+ nchw_uint8 = np.random.randint(0, 255, (8, 3, 64, 32), dtype='uint8')
+
+ # HWC
+ nhwc_float = np.random.rand(8, 64, 32, 3)
+ nhwc_uint8 = np.random.randint(0, 255, (8, 64, 32, 3), dtype='uint8')
+
+ with pytest.raises(ValueError, match='images do not have the correct number of channels for dim "C"'):
+ numpy_to_images(nchw_float) # make sure default for numpy is HWC
+ with pytest.raises(ValueError, match='images do not have the correct number of channels for dim "C"'):
+ numpy_to_images(nchw_uint8) # make sure default for numpy is HWC
+
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_float)) # make sure default for numpy is HWC
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_uint8)) # make sure default for numpy is HWC
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_float, 'CHW'))
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_uint8, 'CHW'))
+
+ for pil_images in [
+ numpy_to_pil_images(nchw_float, 'CHW'),
+ numpy_to_pil_images(nhwc_float),
+ ]:
+ assert isinstance(pil_images, np.ndarray)
+ assert pil_images.shape == (8,)
+ for pil_image in pil_images:
+ pil_image: Image.Image
+ assert isinstance(pil_image, Image.Image)
+ assert pil_image.width == 32
+ assert pil_image.height == 64
+
+ # single image should be returned as an array of shape ()
+ pil_image: np.ndarray = numpy_to_pil_images(np.random.rand(64, 32, 3))
+ assert pil_image.shape == ()
+ assert isinstance(pil_image, np.ndarray)
+ assert isinstance(pil_image.tolist(), Image.Image)
+
+ # check arb size
+ pil_images: np.ndarray = numpy_to_pil_images(np.random.rand(4, 5, 2, 16, 32, 3))
+ assert pil_images.shape == (4, 5, 2)
+
+
+def test_numpy_image_min_max():
+ # CHW
+ nchw_float = np.random.rand(8, 3, 64, 32)
+ nchw_uint8 = np.random.randint(0, 255, (8, 3, 64, 32), dtype='uint8')
+
+ # HWC
+ nhwc_float = np.random.rand(8, 64, 32, 3)
+ nhwc_uint8 = np.random.randint(0, 255, (8, 64, 32, 3), dtype='uint8')
+
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_float, 'HWC', in_min=0, in_max=1))
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', in_min=0, in_max=255))
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_float, 'CHW', in_min=0, in_max=1))
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_uint8, 'CHW', in_min=0, in_max=255))
+
+ # OUT: HWC
+
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_float, 'HWC', in_min=(0, 0, 0), in_max=(1, 1, 1)))
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', in_min=(0, 0, 0), in_max=(255, 255, 255)))
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_float, 'CHW', in_min=(0, 0, 0), in_max=(1, 1, 1)))
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_uint8, 'CHW', in_min=(0, 0, 0), in_max=(255, 255, 255)))
+
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_float, 'HWC', in_min=(0,), in_max=(1,))) # should maybe disable this from working?
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', in_min=(0,), in_max=(255,))) # should maybe disable this from working?
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_float, 'CHW', in_min=(0,), in_max=(1,))) # should maybe disable this from working?
+ _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_uint8, 'CHW', in_min=(0,), in_max=(255,))) # should maybe disable this from working?
+
+ _check((8, 64, 32, 1), 'uint8', numpy_to_images(nhwc_float[:, :, :, 0:1], 'HWC', in_min=(0,), in_max=(1,)))
+ _check((8, 64, 32, 1), 'uint8', numpy_to_images(nhwc_uint8[:, :, :, 0:1], 'HWC', in_min=(0,), in_max=(255,)))
+ _check((8, 64, 32, 1), 'uint8', numpy_to_images(nchw_float[:, 0:1, :, :], 'CHW', in_min=(0,), in_max=(1,)))
+ _check((8, 64, 32, 1), 'uint8', numpy_to_images(nchw_uint8[:, 0:1, :, :], 'CHW', in_min=(0,), in_max=(255,)))
+
+ _check((8, 64, 32, 1), 'uint8', numpy_to_images(nhwc_float[:, :, :, 0:1], 'HWC', in_min=0, in_max=1))
+ _check((8, 64, 32, 1), 'uint8', numpy_to_images(nhwc_uint8[:, :, :, 0:1], 'HWC', in_min=0, in_max=255))
+ _check((8, 64, 32, 1), 'uint8', numpy_to_images(nchw_float[:, 0:1, :, :], 'CHW', in_min=0, in_max=1))
+ _check((8, 64, 32, 1), 'uint8', numpy_to_images(nchw_uint8[:, 0:1, :, :], 'CHW', in_min=0, in_max=255))
+
+ # OUT: CHW
+
+ _check((8, 3, 64, 32), 'uint8', numpy_to_images(nhwc_float, 'HWC', 'CHW', in_min=(0, 0, 0), in_max=(1, 1, 1)))
+ _check((8, 3, 64, 32), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', 'CHW', in_min=(0, 0, 0), in_max=(255, 255, 255)))
+ _check((8, 3, 64, 32), 'uint8', numpy_to_images(nchw_float, 'CHW', 'CHW', in_min=(0, 0, 0), in_max=(1, 1, 1)))
+ _check((8, 3, 64, 32), 'uint8', numpy_to_images(nchw_uint8, 'CHW', 'CHW', in_min=(0, 0, 0), in_max=(255, 255, 255)))
+
+ _check((8, 3, 64, 32), 'uint8', numpy_to_images(nhwc_float, 'HWC', 'CHW', in_min=(0,), in_max=(1,))) # should maybe disable this from working?
+ _check((8, 3, 64, 32), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', 'CHW', in_min=(0,), in_max=(255,))) # should maybe disable this from working?
+ _check((8, 3, 64, 32), 'uint8', numpy_to_images(nchw_float, 'CHW', 'CHW', in_min=(0,), in_max=(1,))) # should maybe disable this from working?
+ _check((8, 3, 64, 32), 'uint8', numpy_to_images(nchw_uint8, 'CHW', 'CHW', in_min=(0,), in_max=(255,))) # should maybe disable this from working?
+
+ _check((8, 1, 64, 32), 'uint8', numpy_to_images(nhwc_float[:, :, :, 0:1], 'HWC', 'CHW', in_min=(0,), in_max=(1,)))
+ _check((8, 1, 64, 32), 'uint8', numpy_to_images(nhwc_uint8[:, :, :, 0:1], 'HWC', 'CHW', in_min=(0,), in_max=(255,)))
+ _check((8, 1, 64, 32), 'uint8', numpy_to_images(nchw_float[:, 0:1, :, :], 'CHW', 'CHW', in_min=(0,), in_max=(1,)))
+ _check((8, 1, 64, 32), 'uint8', numpy_to_images(nchw_uint8[:, 0:1, :, :], 'CHW', 'CHW', in_min=(0,), in_max=(255,)))
+
+ _check((8, 1, 64, 32), 'uint8', numpy_to_images(nhwc_float[:, :, :, 0:1], 'HWC', 'CHW', in_min=0, in_max=1))
+ _check((8, 1, 64, 32), 'uint8', numpy_to_images(nhwc_uint8[:, :, :, 0:1], 'HWC', 'CHW', in_min=0, in_max=255))
+ _check((8, 1, 64, 32), 'uint8', numpy_to_images(nchw_float[:, 0:1, :, :], 'CHW', 'CHW', in_min=0, in_max=1))
+ _check((8, 1, 64, 32), 'uint8', numpy_to_images(nchw_uint8[:, 0:1, :, :], 'CHW', 'CHW', in_min=0, in_max=255))
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
diff --git a/tests/util.py b/tests/util.py
index bcff6941..1bc152c8 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -21,10 +21,13 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
import contextlib
import os
import sys
from contextlib import contextmanager
+from typing import Any
+from typing import Dict
# ========================================================================= #
@@ -58,12 +61,35 @@ def temp_wd(new_wd):
@contextlib.contextmanager
def temp_sys_args(new_argv):
+ # TODO: should this copy values?
old_argv = sys.argv
sys.argv = new_argv
yield
sys.argv = old_argv
+@contextmanager
+def temp_environ(environment: Dict[str, Any]):
+ # TODO: should this copy values? -- could use unittest.mock.patch.dict(...)
+ # save the old environment
+ existing_env = {}
+ for k in environment:
+ if k in os.environ:
+ existing_env[k] = os.environ[k]
+ # update the environment
+ os.environ.update(environment)
+ # run the context
+ try:
+ yield
+ finally:
+ # restore the original environment
+ for k in environment:
+ if k in existing_env:
+ os.environ[k] = existing_env[k]
+ else:
+ del os.environ[k]
+
+
# ========================================================================= #
# END #
# ========================================================================= #